In [1]:
import torch
from torch import nn, optim
from torchvision import models
import syft
import numpy as np
from matplotlib import pyplot as plt



## Setup: PySyft hook, virtual workers and toy OR-gate data

In [2]:
# Hook PySyft into torch so tensors and modules gain PySyft's remote-execution
# API (used below via .send, .move, .get, .copy and .location).
hook = syft.TorchHook(torch)

In [3]:
# Two remote workers that each train a local copy of the model, plus one
# worker ('agg') on which their trained parameters are combined.
trainer1, trainer2, aggregator = (
    syft.VirtualWorker(hook, id=worker_id) for worker_id in ('t1', 't2', 'agg')
)

In [44]:
# Remove any objects previously stored on the workers so the following cells
# start from empty workers; clear_objects() returns the worker itself.
trainer1, trainer2, aggregator = (
    worker.clear_objects() for worker in (trainer1, trainer2, aggregator)
)

trainer1, trainer2, aggregator

(<VirtualWorker id:t1 #tensors:0>,
 <VirtualWorker id:t2 #tensors:0>,
 <VirtualWorker id:agg #tensors:0>)

In [45]:
# OR-gate truth table: each row of `data` is an input pair and the matching
# row of `targets` is its expected output. Neither needs gradient tracking —
# only the model parameters are optimised.
data = torch.tensor([[1., 1.],
                     [1., 0.],
                     [0., 1.],
                     [0., 0.]])
targets = torch.tensor([[1.],
                        [1.],
                        [1.],
                        [0.]])

# Seed before creating the model so its random initial weights (and hence the
# training results) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Linear model: 2 inputs -> 1 output
model = nn.Linear(2, 1)

# Display the initial (pre-training) parameters
[p.data for p in model.parameters()]

[tensor([[ 0.1713, -0.2617]]), tensor([-0.6135])]

In [46]:
# Split the four OR-gate rows between the trainers: rows 0-1 (inputs and their
# targets) go to trainer1, rows 2-3 to trainer2. Each name below is a pointer
# to a tensor that now lives on that worker.
data1, targets1 = data[0:2].send(trainer1), targets[0:2].send(trainer1)
data2, targets2 = data[2:4].send(trainer2), targets[2:4].send(trainer2)

In [47]:
def train(model, dataloader, rounds=10, epochs=20):

    for round_iter in range(rounds):
        
        _models = []
        _opts = []
        
        # Send model to workers
        for remote_data in dataloader:
            _models.append(model.copy().send(remote_data[0].location))
            # SGD optimizer
            _opts.append(optim.SGD(params=_models[-1].parameters(), lr=0.1))

        # Training loop
        for epoch in range(epochs):
            for i in range(len(_models)):
                _opts[i].zero_grad()
                pred = _models[i](dataloader[i][0])
                loss = ((pred - dataloader[i][1])**2).sum()
                loss.backward()
                _opts[i].step()
        for _model in _models:
            _model.move(aggregator)
        
        model.weight.data.set_((sum([_model.weight.data for _model in _models]) / 2).get())
        model.bias.data.set_((sum([_model.bias.data for _model in _models]) / 2).get())
            
    return model

In [48]:
# Pair each trainer's remote inputs with its remote targets; train() runs the
# federated rounds over these pairs and returns the averaged model.
dataloader = list(zip((data1, data2), (targets1, targets2)))
model = train(model, dataloader)
model

Linear(in_features=2, out_features=1, bias=True)