In [15]:
import torch
import syft
from torch import nn, optim
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import math

In [2]:
# Hook PyTorch with PySyft. The remote API used later in this notebook
# (.send(), .get(), .location, Dataset.federate) presumably comes from this
# patching — confirm against the installed syft version.
# `hook` is also passed to every VirtualWorker created below.
hook = syft.TorchHook(torch)

In [3]:
# Remote workers which train the model: one in-process syft VirtualWorker per
# trainer id, all attached to the same torch hook.
trainer1, trainer2, trainer3 = (
    syft.VirtualWorker(hook, id=worker_id) for worker_id in ('t1', 't2', 't3')
)

In [46]:
# Wipe objects left on the workers by earlier runs of this notebook.
# clear_objects() hands back the worker itself (see the displayed tuple),
# so each name keeps pointing at the same worker.
trainer1, trainer2, trainer3 = (
    worker.clear_objects() for worker in (trainer1, trainer2, trainer3)
)

trainer1, trainer2, trainer3

(<VirtualWorker id:t1 #tensors:0>,
 <VirtualWorker id:t2 #tensors:0>,
 <VirtualWorker id:t3 #tensors:0>)

In [47]:
class ToyData(Dataset):
    """Minimal map-style dataset pairing each input with its target.

    Args:
        data: indexable collection of inputs (in this notebook, a 2-D tensor).
        targets: indexable collection of targets aligned index-by-index with
            ``data``.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        # Length comes from the inputs; targets are assumed to match.
        return len(self.data)

    def __getitem__(self, index):
        sample, label = self.data[index], self.targets[index]
        return sample, label

In [48]:
# Toy truth table: every 3-bit input, listed in descending binary order (7..0).
# The target equals the middle feature x[1], so a linear model can fit it
# exactly (weights [0, 1, 0], bias 0) -- which is why the loss goes to ~0.
# Inputs and targets are fixed data, not parameters: they must not require
# grad, otherwise autograd needlessly tracks (and backprops into) them.
data = torch.tensor([[1., 1, 1],
                     [1, 1, 0],
                     [1, 0, 1],
                     [1, 0, 0],
                     [0, 1, 1],
                     [0, 1, 0],
                     [0, 0, 1],
                     [0, 0, 0]])
targets = torch.tensor([[1.],
                        [1],
                        [0],
                        [0],
                        [1],
                        [1],
                        [0],
                        [0]])
assert torch.equal(targets, data[:, 1:2]), "target should equal the middle feature"

# Federate the dataset across the three trainers and iterate it in batches of 2.
dataset = ToyData(data, targets).federate((trainer1, trainer2, trainer3))
dataloader = syft.FederatedDataLoader(dataset, batch_size=2)

# Linear model; seed first so the random initial weights (and the run) are
# reproducible.
torch.manual_seed(0)
model = nn.Linear(3, 1)

In [49]:
def train(model, dataloader, rounds=10, epochs=20, lr=0.1):
    """Train ``model`` on federated batches with plain SGD.

    For every batch the model is sent to the worker holding that batch, one
    SGD step is taken there, and the model is fetched back. ``rounds`` and
    ``epochs`` are simply nested loops, so the dataloader is traversed
    ``rounds * epochs`` times in total; nothing is aggregated between rounds.

    Args:
        model: a (syft-hooked) ``nn.Module`` supporting ``.send()``/``.get()``.
        dataloader: iterable of ``(data, target)`` pairs whose ``data`` exposes
            ``.location`` (e.g. a ``syft.FederatedDataLoader``).
        rounds: number of outer rounds.
        epochs: passes over ``dataloader`` per round.
        lr: SGD learning rate (default 0.1, the previously hard-coded value).

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        Whatever the remote step raises; the model is still retrieved first.
    """
    # One optimizer for the whole run. This relies on syft's send/get keeping
    # the same parameter objects -- NOTE(review): confirm for the syft version.
    opt = optim.SGD(params=model.parameters(), lr=lr)

    for round_iter in range(rounds):
        print('ROUND: ', round_iter)
        for epoch in range(epochs):
            print('\tEPOCH: ', epoch)
            for batch_data, batch_target in dataloader:
                model.send(batch_data.location)
                try:
                    opt.zero_grad()
                    pred = model(batch_data)
                    # Summed squared error over the batch.
                    loss = ((pred - batch_target) ** 2).sum()
                    loss.backward()
                    opt.step()
                finally:
                    # Always bring the model back, even if the remote step
                    # fails, so it is never left stranded on a worker.
                    model.get()
                print('\t\tLOSS: ', loss.get())

    return model

In [50]:
# train() updates `model` in place and returns it, so re-running this cell
# continues training from the current weights instead of starting over.
model = train(model, dataloader, rounds=10, epochs=20)
model

ROUND:  0
	EPOCH: 
		LOSS:  tensor(0.0294, requires_grad=True)
		LOSS:  tensor(0.2059, requires_grad=True)
		LOSS:  tensor(0.5298, requires_grad=True)
		LOSS:  tensor(0.1428, requires_grad=True)
		LOSS:  tensor(0.0014, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0.0749, requires_grad=True)
		LOSS:  tensor(0.0485, requires_grad=True)
		LOSS:  tensor(0.2314, requires_grad=True)
		LOSS:  tensor(0.0701, requires_grad=True)
		LOSS:  tensor(0.0014, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0.0272, requires_grad=True)
		LOSS:  tensor(0.0229, requires_grad=True)
		LOSS:  tensor(0.0891, requires_grad=True)
		LOSS:  tensor(0.0318, requires_grad=True)
		LOSS:  tensor(0.0012, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0.0116, requires_grad=True)
		LOSS:  tensor(0.0103, requires_grad=True)
		LOSS:  tensor(0.0345, requires_grad=True)
		LOSS:  tensor(0.0146, requires_grad=True)
		LOSS:  tensor(0.0008, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0.0052, requires_grad=True)
		LOSS:  tensor(0.004

		LOSS:  tensor(1.8539e-10, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(3.0289e-10, requires_grad=True)
		LOSS:  tensor(1.9079e-12, requires_grad=True)
		LOSS:  tensor(1.2786e-10, requires_grad=True)
		LOSS:  tensor(1.2159e-10, requires_grad=True)
		LOSS:  tensor(1.1052e-10, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(1.8466e-10, requires_grad=True)
		LOSS:  tensor(7.4189e-13, requires_grad=True)
		LOSS:  tensor(7.7969e-11, requires_grad=True)
		LOSS:  tensor(7.2649e-11, requires_grad=True)
		LOSS:  tensor(6.5629e-11, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(1.0652e-10, requires_grad=True)
		LOSS:  tensor(6.7435e-13, requires_grad=True)
		LOSS:  tensor(4.5463e-11, requires_grad=True)
		LOSS:  tensor(4.2210e-11, requires_grad=True)
		LOSS:  tensor(3.9051e-11, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(6.6862e-11, requires_grad=True)
		LOSS:  tensor(2.0716e-13, requires_grad=True)
		LOSS:  tensor(2.7804e-11, requires_grad=True)
		LOSS:  tensor(2.5668e-11, requires_grad=True)
		LO

		LOSS:  tensor(4.2529e-18, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.8830e-17, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.9082e-17, requires_grad=True)
		LOSS:  tensor(2.9107e-18, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.1335e-17, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.4046e-17, requires_grad=True)
		LOSS:  tensor(1.9989e-18, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.5767e-17, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.0327e-17, requires_grad=True)
		LOSS:  tensor(1.3773e-18, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.1637e-17, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(7.5853e-18, requires_grad=True)
		LOSS:  tensor(9.5200e-19, requires_grad=True)
		LOSS:  tensor(0., 

		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(4.9978e-22, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(3.0767e-22, requires_grad=True)
		LOSS:  tensor(1.9015e-23, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(3.6347e-22, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.2368e-22, requires_grad=True)
		LOSS:  tensor(1.3764e-23, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.6432e-22, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.6261e-22, requires_grad=True)
		LOSS:  tensor(9.9660e-24, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.9220e-22, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.1821e-22, requires_grad=True)
		LOSS:  tensor(7.2181e-24, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.3975e-22, 

		LOSS:  tensor(7.0570e-27, requires_grad=True)
ROUND:  7
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(4.3276e-27, requires_grad=True)
		LOSS:  tensor(2.5439e-28, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(5.1278e-27, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(3.1446e-27, requires_grad=True)
		LOSS:  tensor(1.8480e-28, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(3.7260e-27, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.2849e-27, requires_grad=True)
		LOSS:  tensor(1.3426e-28, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.7074e-27, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.6602e-27, requires_grad=True)
		LOSS:  tensor(9.7536e-29, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.9673e-27, requires_grad=True)
	EPOCH: 
	

		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(4.3959e-32, requires_grad=True)
		LOSS:  tensor(2.5779e-33, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(5.2095e-32, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(3.1941e-32, requires_grad=True)
		LOSS:  tensor(1.8731e-33, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(3.7852e-32, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.3208e-32, requires_grad=True)
		LOSS:  tensor(1.3610e-33, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(2.7504e-32, requires_grad=True)
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.6863e-32, requires_grad=True)
		LOSS:  tensor(9.8887e-34, requires_grad=True)
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.9984e-32, requires_grad=True)
ROUND:  9
	EPOCH: 
		LOSS:  tensor(0., requires_grad=True)
		LOSS:  tensor(1.

Linear(in_features=3, out_features=1, bias=True)