<a href="https://colab.research.google.com/github/resilientmax/60DaysOfUdacity/blob/master/Encrypted%20Database.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install syft



In [2]:
import string
import syft as sy
import torch as th
# Build the PySyft hook around the torch module. Presumably this is what gives
# plain torch tensors the .share() method used by EncryptedDB further down —
# confirm against the PySyft documentation for the installed version.
hook = sy.TorchHook(th)

W0712 04:42:19.009653 140103004968832 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/usr/local/lib/python3.6/dist-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0712 04:42:19.028486 140103004968832 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [0]:
# Create the three virtual workers that will hold the secret shares; each one
# is registered with the local worker via add_worker().
bob, alice, secure_worker = (
    sy.VirtualWorker(hook, id=worker_id).add_worker(sy.local_worker)
    for worker_id in ("bob", "alice", "secure_worker")
)

In [0]:
# Two-way lookup tables between characters and their integer codes.
# They start empty and are filled by the alphabet loop in the next cell.
char2index, index2char = dict(), dict()

In [0]:
# The position of a character in this alphabet is its integer code:
# space first, then a-z, the digits, and finally ASCII punctuation.
alphabet = ' ' + string.ascii_lowercase + '0123456789' + string.punctuation
for i, char in enumerate(alphabet):
    index2char[i] = char
    char2index[char] = i

In [0]:
# Example values left at module level. The functions below take parameters with
# the same names, so nothing in the visible cells reads these globals.
str_input = "Hello"
max_len = 8

In [0]:
def string2values(str_input, max_len=8):
    """Encode a string as a tensor of character indices.

    The text is cut to ``max_len`` characters, lower-cased and, if it is
    shorter than ``max_len``, right-padded with "." before each character is
    looked up in the module-level ``char2index`` table.

    Args:
        str_input: text to encode.
        max_len: length the text is truncated / padded to.

    Returns:
        A 1-D ``long`` tensor holding one index per character.

    Raises:
        KeyError: if a character is not present in ``char2index``.
    """
    # ljust pads only when the text is shorter than max_len
    padded = str_input[:max_len].lower().ljust(max_len, ".")

    return th.tensor([char2index[char] for char in padded]).long()



In [0]:
def values2string(input_values):
    """Decode a sequence of character indices back into a string.

    Each element is converted with ``int()`` and looked up in the module-level
    ``index2char`` table; the resulting characters are concatenated in order.

    Raises:
        KeyError: if an index is not present in ``index2char``.
    """
    return "".join(index2char[int(value)] for value in input_values)

In [0]:
def strings_equal(str_a, str_b):
    """Return the product of the row-wise dot products of two matrices.

    The inputs are multiplied element-wise and summed along dimension 1,
    giving one number per row; those numbers are then multiplied together.
    For two one-hot matrices of equal shape this is 1 when every row matches
    and 0 as soon as one row differs.

    Only ``*``, ``.sum(1)`` and indexing are used — presumably so the function
    also works on secret-shared tensors; confirm before switching to ``prod``.
    """
    per_row = (str_a * str_b).sum(1)

    match = per_row[0]
    for row in range(1, per_row.shape[0]):
        match = match * per_row[row]

    return match

In [0]:
def one_hot(index, length):
    """Return a 1-D ``long`` tensor of ``length`` zeros with a 1 at ``index``."""
    encoded = th.zeros(length, dtype=th.long)
    encoded[index] = 1
    return encoded

In [0]:
def string2one_hot_matrix(str_input, max_len=8):
    """Encode a string as a matrix with one one-hot row per character.

    The text is cut to ``max_len`` characters, lower-cased and, if it is
    shorter than ``max_len``, right-padded with ".". Every character is then
    turned into a one-hot row whose width is ``len(char2index)``.

    Args:
        str_input: text to encode.
        max_len: length the text is truncated / padded to.

    Returns:
        The one-hot rows concatenated along dimension 0.

    Raises:
        KeyError: if a character is not present in ``char2index``.
    """
    # ljust pads only when the text is shorter than max_len
    padded = str_input[:max_len].lower().ljust(max_len, ".")
    width = len(char2index)

    # unsqueeze(0) turns each vector into a 1 x width row so cat stacks them
    rows = [one_hot(char2index[char], width).unsqueeze(0) for char in padded]

    return th.cat(rows, dim=0)

In [0]:
class EncryptedDB():
    """Small key-value store whose keys and values are shared across owners.

    Keys are stored as one-hot matrices and values as index tensors; both are
    passed through ``.share(*owners)`` before being kept. A lookup compares
    the shared query against every shared key and only calls ``.get()`` on
    the final combined result.
    """

    def __init__(self, *owners, max_key_len=8, max_val_len=8):
        """Create an empty database.

        Args:
            *owners: workers handed to ``.share()`` for every key and value.
            max_key_len: length keys and queries are truncated / padded to.
            max_val_len: length values are truncated / padded to.
        """
        self.max_key_len = max_key_len
        self.max_val_len = max_val_len

        self.keys = list()
        self.values = list()
        self.owners = owners

    def add_entry(self, key, value):
        """Encode, share and store one key/value pair.

        Args:
            key: lookup string, encoded with ``max_key_len``.
            value: payload string, encoded with ``max_val_len``.
        """
        # Honour the configured key length; previously the encoder default
        # was always used and max_key_len was silently ignored.
        key = string2one_hot_matrix(key, max_len=self.max_key_len)
        key = key.share(*self.owners)
        self.keys.append(key)

        value = string2values(value, max_len=self.max_val_len)
        value = value.share(*self.owners)
        self.values.append(value)

    def query(self, query_str):
        """Return the value stored under ``query_str``.

        Every stored value is multiplied by its key's match flag and the
        products are summed, so a single matching key yields that key's
        value. If no key matches, all flags are 0 and the decoded result is
        made of index-0 characters only.

        Args:
            query_str: key to look up; encoded exactly like stored keys.

        Returns:
            The decoded value with its trailing "." padding removed. A value
            that itself ends in "." loses those trailing dots as well, since
            they cannot be told apart from padding.

        Raises:
            IndexError: if the database holds no entries.
        """
        if not self.values:
            raise IndexError("cannot query an empty EncryptedDB")

        # Must use the same length as add_entry so the matrices line up.
        query_matrix = string2one_hot_matrix(query_str, max_len=self.max_key_len)
        query_matrix = query_matrix.share(*self.owners)

        key_matches = [strings_equal(key, query_matrix) for key in self.keys]

        result = self.values[0] * key_matches[0]
        for value, key_match in zip(self.values[1:], key_matches[1:]):
            result += value * key_match

        result = result.get()

        # Padding is only ever appended, so strip it from the right instead
        # of deleting every "." (which corrupted values containing dots).
        return values2string(result).rstrip(".")

In [15]:
# Demo: build a small phone book shared between the three workers and look
# one entry up. Values may be up to 256 characters long.
db = EncryptedDB(bob, alice, secure_worker, max_val_len=256)

for name, phone in (
    ("Bob", "(123) 456 7890"),
    ("Bill", "(234) 567 8901"),
    ("Sam", "(345) 678 9012"),
):
    db.add_entry(name, phone)

db.query("Bob")

'(123) 456 7890'

In [16]:
# Display the stored values: what the database holds are the shared tensors
# returned by .share(), not the plaintext strings passed to add_entry().
db.values

[(Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:58860188665 -> bob:84084823858]
 	-> (Wrapper)>[PointerTensor | me:58558449843 -> alice:83728274630]
 	-> (Wrapper)>[PointerTensor | me:51095693382 -> secure_worker:25706007447]
 	*crypto provider: me*, (Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:44927532371 -> bob:49781037744]
 	-> (Wrapper)>[PointerTensor | me:91023334872 -> alice:92727767920]
 	-> (Wrapper)>[PointerTensor | me:88006171478 -> secure_worker:6767931762]
 	*crypto provider: me*, (Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:38008175604 -> bob:29848149539]
 	-> (Wrapper)>[PointerTensor | me:19706327654 -> alice:88862160510]
 	-> (Wrapper)>[PointerTensor | me:24263193903 -> secure_worker:66618366080]
 	*crypto provider: me*]