In [10]:
#imports
import pickle
import time
from abc import ABC, abstractmethod

import openmined_psi as psi
import syft as sy
import torch
from torch.utils.data import Dataset

In [11]:
# Load the data that we previously saved; the data are lists.
# NOTE(security): unpickling can execute arbitrary code — only load files
# that come from a trusted source.

def _load_pickled(path):
    """Unpickle the object stored at ``path`` and close the file afterwards."""
    with open(path, 'rb') as handle:
        return pickle.load(handle, encoding='utf-8')


index_list = _load_pickled("indexlist")
label_list = _load_pickled("labellist")
data_list = _load_pickled("0th_owner")

In [6]:
class BaseSet(Dataset):
    """Dataset that pairs every stored value with its identifier.

    Attributes:
        values_dic: mapping ``id -> raw value`` built from ``ids`` and ``values``.
        is_labels: flag given at construction time.
        ids: the identifiers, stored as passed in.
        values: tensor built from ``values`` (``torch.Tensor(values)`` when
            ``is_labels`` is true, ``torch.stack(values)`` otherwise).
    """

    def __init__(self, ids, values, is_labels=False):
        """
        Args:
            ids: identifiers, one per entry of ``values``.
            values: the raw values; stacked with ``torch.stack`` unless
                ``is_labels`` is true.
            is_labels: when true, ``values`` is converted with ``torch.Tensor``.
        """
        # Lookup table from identifier to raw value.
        self.values_dic = dict(zip(ids, values))
        self.is_labels = is_labels
        self.ids = ids

        if is_labels:
            self.values = torch.Tensor(values)
        else:
            self.values = torch.stack(values)

    def __getitem__(self, index):
        """
        Args:
            index: position of the example to fetch.
        Returns: a tuple ``(value, id)`` for that position.
        """
        return (self.values[index], self.ids[index])

    def __len__(self):
        """
        Returns: number of entries along the first dimension of ``values``.
        """
        return self.values.shape[0]
    
class SampleSetWithLabels(Dataset):
    """Dataset that joins a sample set with a label set.

    Both arguments are expected to expose ``values`` and ``values_dic``; the
    sample set must also expose ``ids``.
    """

    def __init__(self, labelset, sampleset, worker_id=None):
        """
        Args:
            labelset: source of ``labels`` and of the label side of ``values_dic``.
            sampleset: source of ``values``, ``ids`` and the sample side of
                ``values_dic``.
            worker_id: accepted but not used by this class.
        """
        self.labelset = labelset
        self.sampleset = sampleset

        self.labels = labelset.values
        self.values = sampleset.values
        self.ids = sampleset.ids

        # For every id known to the label set: id -> (sample value, label value).
        self.values_dic = {
            key: (sampleset.values_dic[key], labelset.values_dic[key])
            for key in labelset.values_dic
        }

    def __getitem__(self, index):
        """
        Args:
            index: position of the example to fetch.
        Returns: a tuple ``(value, label, id)`` for that position.
        """
        return (self.values[index], self.labels[index], self.ids[index])

    def __len__(self):
        """
        Returns: number of entries along the first dimension of ``values``.
        """
        return self.values.shape[0]

In [8]:
class PsiProtocol(ABC):
    """Base class driving the server side of a PSI exchange over a duet.

    Constructing an instance immediately runs ``_start_protocol``. Subclasses
    must override ``_start_protocol``; they can call the base implementation
    to run the one-to-one exchange.
    """

    def __init__(self, duet, dataset, data_ids, fpr=1e-6):
        """
        Args:
            duet: connection object used to send objects, read ``duet.store``
                and register request handlers.
            dataset: object exposing ``ids`` and ``values_dic``.
            data_ids: not used; the ids are always derived from ``dataset.ids``.
                Kept so existing callers keep working.
            fpr: value forwarded to the PSI setup message and shared under the
                "fpr" tag (presumably the false-positive rate — confirm
                against the openmined_psi documentation).
        """
        super().__init__()
        self.duet = duet
        self.dataset = dataset
        # The PSI server consumes the ids as strings of integers.
        self.data_ids = [str(int(raw_id)) for raw_id in dataset.ids]
        self.fpr = fpr
        # Pointers returned by ``send`` are kept here so they outlive the method
        # that created them.
        # NOTE(review): presumably dropping a pointer can release the stored
        # object on the duet — confirm against the syft documentation.
        self._sent_ptrs = []

        self._start_protocol()

    @abstractmethod
    def _start_protocol(self):
        """Run the protocol; the base implementation does the one-to-one exchange."""
        self._one_to_one_exchange()

    def _one_to_one_exchange(self):
        """Server side of the exchange: share parameters, send setup, answer the request."""
        self._add_handler_accept(self.duet, name="reveal intersection")
        self._add_handler_accept(self.duet, name="fpr")

        # Share the protocol parameters with the other party.
        reveal_intersection = True
        sy_reveal_intersection = sy.lib.python.Bool(reveal_intersection)
        self._sent_ptrs.append(
            sy_reveal_intersection.tag("reveal_intersection").send(self.duet, searchable=True)
        )

        sy_fpr = sy.lib.python.Float(self.fpr)
        self._sent_ptrs.append(sy_fpr.tag("fpr").send(self.duet, searchable=True))

        # Block until the other party has published how many items it holds.
        client_items_len = self._get_object_duet(tag="client_items_len")

        # Build the server and its setup message.
        self.server = psi.server.CreateWithNewKey(reveal_intersection)
        setup = self.server.CreateSetupMessage(self.fpr, client_items_len, self.data_ids)

        self._add_handler_accept(self.duet, name="setup")
        self._sent_ptrs.append(
            setup.send(self.duet, searchable=True, tags=["setup"], description="psi.server Setup Message")
        )

        # Block until the request arrives, then answer it.
        request = self._get_object_duet(tag="request")
        response = self.server.ProcessRequest(request)

        self._add_handler_accept(self.duet, name="response")
        self._sent_ptrs.append(
            response.send(self.duet, searchable=True, tags=["response"], description="psi.server response")
        )

    def _one_to_one_exchange_client(self):
        """Placeholder for the client side of the exchange (not implemented)."""
        pass

    def _add_handler_accept(self, duet, action="accept", name=""):
        """Register a request handler called ``name`` on ``duet``.

        Args:
            duet: object whose ``requests.add_handler`` is called.
            action: action passed to the handler; defaults to "accept".
            name: handler name.
        """
        duet.requests.add_handler(name=name, action=action)

    def _get_object_duet(self, tag="", poll_interval=0.1):
        """Wait until ``self.duet.store[tag]`` is available and return its value.

        Args:
            tag: key looked up in ``self.duet.store``.
            poll_interval: seconds to sleep between failed lookups.
        Returns: the result of ``.get(delete_obj=False)`` on the stored entry.
        """
        while True:
            try:
                stored = self.duet.store[tag]
            except Exception:
                # Not published yet. ``Exception`` (rather than a bare except)
                # keeps Ctrl-C / kernel interrupt working while we wait.
                time.sleep(poll_interval)
                continue
            return stored.get(delete_obj=False)

    def _get_ids_and_share(self):
        """Fetch the "ids_intersec" object and share the matching values, labels and ids."""
        # Ids published by the other party for the global intersection.
        id_int = self._get_object_duet(tag="ids_intersec")

        # Map the ids to a list of integers.
        id_int_list = [int(raw_id) for raw_id in id_int]

        # Convert the entries to share into tensors.
        value_tensor, label_tensor, id_tensor = self._convert_values_toshare(id_int_list)

        # Share those tensors.
        self._sent_ptrs.append(
            value_tensor.send(self.duet, searchable=True, tags=["values"], description="intersecting values")
        )
        self._sent_ptrs.append(
            label_tensor.send(self.duet, searchable=True, tags=["labels"], description="intersecting labels")
        )
        self._sent_ptrs.append(
            id_tensor.send(self.duet, searchable=True, tags=["ids"], description="intersecting ids")
        )

    def _convert_values_toshare(self, id_int_list):
        """Collect the dataset entries whose id is in ``id_int_list`` as tensors.

        Args:
            id_int_list: integer ids to keep.
        Returns: ``(value_tensor, label_tensor, id_tensor)``; values are joined
            with ``torch.cat``, labels and ids are wrapped with ``torch.Tensor``.
        """
        # Set membership instead of scanning the list once per dataset key.
        wanted_ids = set(id_int_list)

        value_list_toshare = []
        label_list_toshare = []
        id_list_toshare = []
        for key, entry in self.dataset.values_dic.items():
            if int(key) in wanted_ids:
                # entry[0] is the value, entry[1] the label.
                value_list_toshare.append(entry[0])
                label_list_toshare.append(entry[1])
                id_list_toshare.append(int(key))

        value_tensor = torch.cat(value_list_toshare)
        label_tensor = torch.Tensor(label_list_toshare)
        id_tensor = torch.Tensor(id_list_toshare)

        return value_tensor, label_tensor, id_tensor

In [9]:
class PsiOneToOne(PsiProtocol):
    """PSI protocol that runs only the base class's one-to-one exchange."""

    def _start_protocol(self):
        """Delegate to the base implementation."""
        # ``super`` has to be called to obtain a bound proxy; ``super.<attr>``
        # looks the attribute up on the ``super`` type itself and fails.
        super()._start_protocol()

        
class PsiStar(PsiProtocol):
    """PSI protocol that runs the base exchange and then shares the intersecting entries."""

    def _start_protocol(self):
        """Run the base exchange, then fetch the intersection ids and share the matching data."""
        # ``super`` has to be called to obtain a bound proxy; ``super.<attr>``
        # looks the attribute up on the ``super`` type itself and fails.
        super()._start_protocol()

        # Inherited helper, invoked on the instance.
        self._get_ids_and_share()