In [1]:
import torch
from torch import nn, optim
from torchvision import models
import syft
import numpy as np
from matplotlib import pyplot as plt



## Toy Data

In [2]:
hook = syft.TorchHook(torch)

In [3]:
# Remote workers which train the model
trainer1 = syft.VirtualWorker(hook, id='t1')
trainer1 = trainer1.clear_objects()
trainer2 = syft.VirtualWorker(hook, id='t2')
trainer2 = trainer2.clear_objects()

trainer1, trainer2

(<VirtualWorker id:t1 #tensors:0>, <VirtualWorker id:t2 #tensors:0>)

In [4]:
# OR Gate
data = torch.tensor([[1.,1],
                     [1,0],
                     [0,1],
                     [0,0]], requires_grad=True)
targets = torch.tensor([[1.],
                       [1],
                       [1],
                       [0]], requires_grad=True)
# Linear model
model_original = nn.Linear(2,1)
model = model_original

In [5]:
# Send data to trainers
data1 = data[0:2].send(trainer1)
data2 = data[2:4].send(trainer2)

# Send targets to trainers
targets1 = targets[0:2].send(trainer1)
targets2 = targets[2:4].send(trainer2)

In [6]:
def train(_model, _data, _targets, epochs=20):
    
    # SGD optimizer
    opt_sgd = optim.SGD(params=model.parameters(), lr=0.1)
    
    # Training loop
    for epoch in range(epochs):
            opt_sgd.zero_grad()
            pred = _model(_data)
            loss = ((pred - _targets)**2).sum()
            loss.backward()
            opt_sgd.step()
            
    return _model

In [7]:
# trainers all have data
# send model to first trainer
model = model.send(trainer1)
print(model, model._objects)
# train
model = train(model, data1, targets1)
print(model, model._objects)
# move model to next trainer
model.move(trainer2)
# train
model = train(model, data2, targets2)
print(model, model._objects)
# get model
model = model.get()

AttributeError: 'Linear' object has no attribute '_objects'

In [None]:
model_original.state_dict(), model.state_dict()

In [9]:
trainer1._objects

{55204747356: tensor([[1., 1.],
         [1., 0.]], requires_grad=True), 45916076843: tensor([[1.],
         [1.]], requires_grad=True), 53016398024: Parameter containing:
 tensor([[-0.4599, -0.4819]], requires_grad=True), 394810585: Parameter containing:
 tensor([0.5324], requires_grad=True)}