In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time
import syft
from syft import WebsocketClientWorker
from syft import WebsocketServerWorker

In [2]:
hook = syft.TorchHook(torch)

In [3]:
class Parser:
    """Parameters for training"""
    def __init__(self):
        self.epochs = 10
        self.lr = 0.01
        self.batch_size = 8
        self.test_batch_size = 8
        self.log_interval = 10
        self.seed = 1

args = Parser()

torch.manual_seed(args.seed)
kwargs = {}

In [4]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [5]:
iris = load_iris()

In [6]:
def prepare_data(dataset):
    X = dataset.data
    y = dataset.target
    X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=0)

    X_train = torch.from_numpy(X).float()
    y_train = torch.from_numpy(y).float()
    X_test = torch.from_numpy(X).float()
    y_test = torch.from_numpy(y).float()
    
    train = TensorDataset(X_train, y_train)
    test = TensorDataset(X_test, y_test)
    train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = DataLoader(test, batch_size=args.batch_size, shuffle=True, **kwargs)
    
    return train_loader, test_loader

In [7]:
train_loader, test_loader = prepare_data(iris)


In [8]:
class Net(nn.Module):
    # define nn
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 4)
        self.fc3 = nn.Linear(4, 1)

    def forward(self, x):
        x = x.view(-1, 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [9]:
#bob = syft.VirtualWorker(hook, id="bob")
#alice = syft.VirtualWorker(hook, id="alice")

alice = WebsocketClientWorker(host="localhost",hook=hook,id="alice",port=8182)
bob   = WebsocketClientWorker(host="localhost",hook=hook,id="bob",port=8183)

secure_worker = WebsocketServerWorker(host="localhost",hook=hook,id="secure_worker",port=8181)

hook = syft.TorchHook(torch, local_worker=secure_worker)
hook.local_worker.add_worker(alice)
hook.local_worker.add_worker(bob)


compute_nodes = [bob, alice]



In [10]:
def init_remote_dataset():

    # remote dataset contains 2 lists, of pointers to Alices and Bobs datasets respectively
    # TODO redo this to work with arbitrary numbers of workers
    remote_dataset = (list(), list())

    for batch_index, (data,target) in enumerate(train_loader):

        # Send data,target to the worker_id by doing modulo num_workers
        data = data.send(compute_nodes[batch_index % len(compute_nodes)])      
        target = target.send(compute_nodes[batch_index % len(compute_nodes)])

        # remote_dataset is  list of pointers to the locations of each worker's list of data,target
        remote_dataset[batch_index % len(compute_nodes)].append((data,target))
        
    return remote_dataset
remote_dataset = init_remote_dataset()

In [11]:
# performs one remote training update on the model using remote data
def update(data, target, model, optimizer):
    
    # send the model to the data owner
    model = model.send(data.location)
    
    # Perform training update
    optimizer.zero_grad()
    pred = model(data)
    loss = F.mse_loss(pred.view(-1), target)
    #print("\n",loss)
    loss = loss.get()
    #print(loss, "\n")
    # TODO: Figure out why step() doesn't appear to be updating the model
    loss.backward()
    optimizer.step()
    
    # Retrieve model from data owner
    model = model.get()
        
    return model

In [12]:
def update_hack(data, target, model, optimizer):
    
    owner = data.location
    
    # Retrieve data from data owners
    data = data.get()
    target = target.get()
    
    # Perform training update
    optimizer.zero_grad()
    pred = model(data)
    loss = F.mse_loss(pred.view(-1), target)
    
    # TODO: Figure out why this isn't updating the model
    loss.backward()
    optimizer.step()
    
    data = data.send(owner)
    target = target.send(owner)
       
    return model

In [13]:
# Init models & optimizers for each remote worker, and combine in a list
# TODO: rewrite this to work with arbitrary number of workers
alices_model = Net()
bobs_model = Net()

alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)
bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)

models = [alices_model, bobs_model]
optimizers = [bobs_optimizer, alices_optimizer]

params = [list(bobs_model.parameters()), list(alices_model.parameters())]

In [14]:
def train():
    
    remote_dataset = init_remote_dataset()
    
    for data_index in range(len(remote_dataset[0])-1):
        for remote_index in range(len(compute_nodes)):
            
            # Retrieve data,target pointers from remote dataset
            data, target = remote_dataset[remote_index][data_index]
            
            # Update the respective model
            #TODO: Figure out why this doesn't work (see update function def)
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])
            #models[remote_index] = update_hack(data, target, models[remote_index], optimizers[remote_index])

    # New list to hold aggregated parameters for each layer of our new model
    # TODO: rewrite this part to work with arbitrary number of workers and remove hardcoded division by 2
    param_agg = list()
        
    for layer_index in range(len(params[0])):
            
            # Copy both workers' parameter layers, convert to fixed-prec and encrypt using additive secret-sharing, then add together
            #param_agg.append(params[0][layer_index].data.copy().fix_prec().share(bob, alice, crypto_provider=secure_worker) + params[1][layer_index].data.copy().fix_prec().share(bob, alice, crypto_provider=secure_worker))
            param_agg.append(params[0][layer_index].data.copy() + params[1][layer_index].data.copy())
            
            # Decrypt the summed parameter layer, convert back to float-prec, and divide by 2 to get the average of both models
            #param_agg[layer_index] = param_agg[layer_index].get().float_precision()/2
            param_agg[layer_index] = param_agg[layer_index]/2
            
    
    # Disable autograd and replace worker parameters with new aggregated parameters
    with torch.no_grad():
        
        # Zero all model parameters
        for model in params:
            for param in model:
                param *= 0
        
        # Set each parameter to the value of its aggregated counterpart
        for remote_index in range(len(compute_nodes)):
                for layer_index in range(len(params[remote_index])):
                    params[remote_index][layer_index].set_(param_agg[layer_index])


In [15]:
# Test function copied from tutorial
# TODO: Inspect and comment this function
def test():
    models[0].eval()
    test_loss = 0
    for data, target in test_loader:
        output = models[0](data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}\n'.format(test_loss))

In [16]:
# Init new timer to measure training time
t = time.time()

for epoch in range(1, args.epochs + 1):
    train()
    test()
    
# Output training time
total_time = time.time() - t
print('Total', round(total_time, 2), 's')

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Test set: Average loss: 2.3970

Total 23.51 s


In [17]:
train()
test()

Test set: Average loss: 2.3970



In [18]:
# data, target = iris.data, iris.target

# data = torch.from_numpy(data).float()
# target = torch.from_numpy(target).float()

# alices_optimizer.zero_grad()

# pred = alices_model(data)
# loss = F.mse_loss(pred.view(-1), target)
# loss.backward()
# alices_optimizer.step()
# print(loss)
# alice_params = list(alices_model.parameters())
# alice_params[0].data
