In [None]:
import requests
import private_set_intersection.python as psi
from pandas import read_csv

In [None]:
# Base URL of the PSI server. Endpoint names are appended to it directly
# (e.g. url + 'match'), so the trailing slash must stay.
url = "http://localhost:5000/"

## Client Setup

In [None]:
# Load the client's records and build one PSI item per row by joining the
# company name and the postcode with a single space.
df_m = read_csv('mari_clean.csv')
client_items = (df_m['CompanyName'] + ' ' + df_m['Postcode']).to_list()

# Create a PSI client with a freshly generated key.
# NOTE(review): the positional True is presumably the library's
# "reveal intersection" flag -- confirm against the psi client API.
c = psi.client.CreateWithNewKey(True)

# Build the request once and reuse the same object for both the serialized
# payload and the cell's displayed output, instead of computing it twice.
client_request = c.CreateRequest(client_items)
psirequest = client_request.SerializeToString()
client_request

### Get Server-encrypted Client Values

In [None]:
# Send the serialized client request to the server's 'match' endpoint.
response = requests.post(
    url + 'match',
    headers={'Content-Type': 'application/protobuf'},
    data=psirequest,
)
# Fail fast on an HTTP error status: otherwise the error body would be handed
# to the protobuf parser below and surface as a confusing decode failure.
response.raise_for_status()

# Decode the reply body into a PSI Response message and display it.
psiresponse = psi.Response()
psiresponse.ParseFromString(response.content)
psiresponse

### Get Server Setup - Raw

In [None]:
# Fetch the server setup message from the 'rawsetup' endpoint. Later cells
# read rawsetup.raw.encrypted_elements from it.
setupresponse = requests.get(url + 'rawsetup')
# Stop here on an HTTP error status rather than parsing an error body as a
# protobuf message.
setupresponse.raise_for_status()

rawsetup = psi.ServerSetup()
rawsetup.ParseFromString(setupresponse.content)
rawsetup

## Bloom Decode

In [None]:
# Fetch the server setup message from the 'bloomsetup' endpoint. Later cells
# read bloomsetup.bloom_filter.bits and .num_hash_functions from it.
setupresponse = requests.get(url + 'bloomsetup')
# Stop here on an HTTP error status rather than parsing an error body as a
# protobuf message.
setupresponse.raise_for_status()

bloomsetup = psi.ServerSetup()
bloomsetup.ParseFromString(setupresponse.content)
bloomsetup

### Server Calculation

In [None]:
from math import ceil, log, log2

# Parameters used to size the Bloom filter. These are hard-coded here.
# NOTE(review): they presumably have to match what the server was configured
# with -- confirm before relying on the comparison cells below.
fpr = 0.01
num_client_inputs = 100
len_server_items = 2

# The false-positive rate is divided by the number of client inputs.
correctedfpr = fpr / num_client_inputs

# The filter is sized for the larger of the two input counts.
max_elements = max(num_client_inputs, len_server_items)

# Required bit count, rounded up to a whole number of bytes.
exact_bits = -max_elements * log2(correctedfpr) / log(2)
num_bytes = ceil(exact_bits / 8)
num_bits = num_bytes * 8
num_bits

In [None]:
from hashlib import sha256

# Rebuild the Bloom filter locally as a list of '0'/'1' characters.
# num_bits comes from the calculation cell above; it could alternatively be
# read from the server's filter as len(bloomsetup.bloom_filter.bits) * 8.
hash_count = bloomsetup.bloom_filter.num_hash_functions
filterlist = ['0'] * num_bits

for element in rawsetup.raw.encrypted_elements:
    # Two base hashes: SHA-256 of the element prefixed with the ASCII
    # character '1' and '2' respectively, each reduced modulo the filter size.
    h1 = int.from_bytes(sha256(b'1' + element).digest(), 'big') % num_bits
    h2 = int.from_bytes(sha256(b'2' + element).digest(), 'big') % num_bits

    # The i-th probe position is derived from the two base hashes.
    for i in range(hash_count):
        pos = (h1 + i * h2) % num_bits
        # Bit 0 is stored at the right-hand end of the string.
        filterlist[num_bits - 1 - pos] = '1'

filterstring = ''.join(filterlist)

In [None]:
# Render the server's filter bytes as one bit string, last byte first and
# eight bits per byte, then check it against the locally built filter.
server_filter_bytes = bloomsetup.bloom_filter.bits
bloombits = ''.join('{:08b}'.format(byte) for byte in reversed(server_filter_bytes))
bloombits == filterstring

In [None]:
# Number of hash functions implied by the corrected false-positive rate
# (correctedfpr, ceil and log2 come from the "Server Calculation" cell above).
exact_hash_count = -log2(correctedfpr)
num_hash_functions = ceil(exact_hash_count)
num_hash_functions

## GCS Decode

In [None]:
# Fetch the server setup message from the 'gcssetup' endpoint. Later cells
# read gcssetup.gcs.hash_range, .div and .bits from it.
setupresponse = requests.get(url + 'gcssetup')
# Stop here on an HTTP error status rather than parsing an error body as a
# protobuf message.
setupresponse.raise_for_status()

gcssetup = psi.ServerSetup()
gcssetup.ParseFromString(setupresponse.content)
gcssetup

In [None]:
from math import ceil, log, log2

# Same hard-coded parameters as in the Bloom filter calculation.
fpr = 0.01
num_client_inputs = 100
correctedfpr = fpr / num_client_inputs

# Size of the range the element hashes are reduced into.
# max_elements is taken from the Bloom "Server Calculation" cell, so that
# cell has to run before this one.
hash_range = max_elements / correctedfpr
hash_range

In [None]:
from hashlib import sha256

# Hash every server-encrypted element with SHA-256, reduce the digest into
# the range reported by the server, and keep the bucket values in
# ascending order.
gcs_range = gcssetup.gcs.hash_range
ulist = sorted(
    int.from_bytes(sha256(element).digest(), 'big') % gcs_range
    for element in rawsetup.raw.encrypted_elements
)

# Deltas between consecutive sorted bucket values; the first entry is the
# smallest bucket value itself.
udiff = [ulist[0]] + [later - earlier for earlier, later in zip(ulist, ulist[1:])]

In [None]:
# Derive the Golomb divisor exponent from the bucket values: the average
# spacing is (largest bucket value + 1) divided by the number of values, and
# prob is its reciprocal. The result is never allowed below zero.
largest_bucket = ulist[-1]
avg = (largest_bucket + 1) / len(ulist)
prob = 1 / avg
exact_exponent = -log2(-log2(1.0 - prob))
gcsdiv = max(0, round(exact_exponent))
gcsdiv

In [None]:
# Encode every non-zero delta as: the remainder in binary, zero-padded to
# exactly `div` bits, followed by a '1' terminator and `quot` zeros (the
# unary quotient). Each new code word is placed in front of the previous
# ones, so the first delta ends up at the right-hand end of the string.
#
# Integer divmod is used for the quotient/remainder so large deltas are not
# rounded through floating point, and no local is named `next`, which would
# shadow the builtin for the rest of the session.
rice_div = gcssetup.gcs.div
rice_modulus = 1 << rice_div

code_words = []
for diff in udiff:
    if diff != 0:
        quot, rem = divmod(diff, rice_modulus)
        # With div == 0 the remainder occupies no bits at all.
        rem_bits = format(rem, 'b').zfill(rice_div) if rice_div > 0 else ''
        code_words.append(rem_bits + '1' + ('0' * quot))

# Later code words go to the left of earlier ones.
encoded = ''.join(reversed(code_words))

In [None]:
# Left-pad the encoded bit string with zeros so that its length is a whole
# number of bytes (a multiple of 8 bits).

from math import ceil

encoded_bit_count = len(encoded)
padlength = 8 * ceil(encoded_bit_count / 8)
padded = encoded.zfill(padlength)

In [None]:
# Render the gcs.bits bytes returned by the server setup as one bit string,
# last byte first and eight bits per byte, then check that it matches the
# bit string built locally.

server_gcs_bytes = gcssetup.gcs.bits
gcsbits = ''.join('{:08b}'.format(byte) for byte in reversed(server_gcs_bytes))
gcsbits == padded

### Calculate Set Intersection

In [None]:
# Compute the intersection from a server setup message and the server's
# response. The GCS setup is used here; the commented lines run the same
# call against the Bloom and raw setups instead.
intersection = c.GetIntersection(gcssetup, psiresponse)
#intersection = c.GetIntersection(bloomsetup, psiresponse)
#intersection = c.GetIntersection(rawsetup, psiresponse)

# Keep a set view of the result and display the entries in ascending order.
iset = set(intersection)
matched_sorted = sorted(intersection)
matched_sorted

In [None]:
# The intersection holds positions into client_items; print the matching
# client items in ascending position order, one per line.
for item_position in sorted(intersection):
    matched_item = client_items[item_position]
    print(matched_item)