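This notebook is based on PySyft's tutorial [Part 07 – Federated Learning with Federated Dataset](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2007%20-%20Federated%20Learning%20with%20Federated%20Dataset.ipynb). It first trains a linear regression on the federated Boston data with a plain PyTorch loop, then re-implements the same training with PyTorch Lightning.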

In [None]:
! pip install syft
! pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

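**Restart the runtime** after the installs above, so that the newly installed package versions are the ones imported below.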

In [1]:
import torch as th
import syft as sy

# Passing globals() lets the sandbox add its objects to this namespace;
# presumably this is where `grid` (used in the next cell) comes from.
sy.create_sandbox(globals(), verbose=False)

# Seed torch after the sandbox is built so the random initial weights of
# `model` and `LNet` below are reproducible across runs.
SEED = 0
th.manual_seed(SEED)

Setting up Sandbox...
Done!


In [2]:
# Look up what is tagged "#boston" on the grid: one search for the features
# ("#data") and one for the labels ("#target"), in that order.
boston_data, boston_target = (grid.search("#boston", tag) for tag in ("#data", "#target"))

In [3]:
# One regression output; the input width is read from the first data tensor
# held by worker 'alice'.
n_targets = 1
n_features = boston_data['alice'][0].shape[1]

# A single linear layer, shared by all workers in the training loop below.
model = th.nn.Linear(in_features=n_features, out_features=n_targets)

In [4]:
# One BaseDataset per worker, pairing that worker's data pointer with its target pointer.
datasets = [
    sy.BaseDataset(boston_data[worker_id][0], boston_target[worker_id][0])
    for worker_id in boston_data.keys()
]

# Build the FederatedDataset object
dataset = sy.FederatedDataset(datasets)
print(dataset.workers)

# One Adam instance per worker, all over the same `model` parameters; each
# instance keeps its own optimizer state for the batches of its worker.
optimizers = {
    worker_id: th.optim.Adam(params=model.parameters(), lr=1e-2)
    for worker_id in dataset.workers
}

['bob', 'theo', 'jason', 'alice', 'andy', 'jon']


In [5]:
# Deterministic batching (no shuffle, keep the last partial batch); the training
# loop reads `data.location` to find the worker that holds each batch.
train_loader = sy.FederatedDataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    drop_last=False,
)


In [7]:
epochs = 10
for epoch in range(1, epochs + 1):
    epoch_loss = 0
    for step, (data, target) in enumerate(train_loader):
        # Ship the model to whichever worker holds this batch and use the
        # optimizer dedicated to that worker.
        location = data.location
        model.send(location)
        opt = optimizers[location.id]

        # Forward, MSE loss, backward and update all run on the remote worker.
        opt.zero_grad()
        batch_loss = ((model(data).view(-1) - target) ** 2).mean()
        batch_loss.backward()
        opt.step()

        # Pull the updated model and the scalar loss back locally.
        model.get()
        batch_loss = batch_loss.get()
        epoch_loss += float(batch_loss)

        if step % 8 == 0:
            n_batches = len(train_loader)
            progress = 100. * step / n_batches
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tBatch loss: {:.6f}'.format(
                epoch, step, n_batches, progress, batch_loss.item()))

    print('Total loss', epoch_loss)

Total loss 103776.92793273926
Total loss 37816.44436645508
Total loss 26556.274700164795
Total loss 11086.372959136963
Total loss 12775.664268493652
Total loss 2969.2659454345703
Total loss 4095.403242111206
Total loss 4586.587760925293
Total loss 2661.6517086029053
Total loss 2563.6421031951904


Now let's re-implement the same training with PyTorch Lightning.

In [45]:
from pytorch_lightning.core.lightning import LightningModule
import torch.nn.functional as F

class LNet(LightningModule):
    """Lightning port of the single-layer linear regression trained above.

    The network is one ``th.nn.Linear`` layer (the same architecture as ``model``
    in the plain-PyTorch loop), trained with the same mean-squared-error loss on
    the federated dataset.

    Args:
        in_features: input width; defaults to the notebook's ``n_features``.
        out_features: output width; defaults to the notebook's ``n_targets``.
    """

    def __init__(self, in_features=None, out_features=None):
        super().__init__()
        # Resolve the defaults at construction time (not class-definition time),
        # so `LNet()` keeps using the current notebook globals.
        if in_features is None:
            in_features = n_features
        if out_features is None:
            out_features = n_targets
        self.fc1 = th.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the raw linear prediction for ``x``.

        Deliberately no activation: this is unconstrained regression. The former
        ``relu`` + ``log_softmax(dim=1)`` head was wrong here, because
        ``log_softmax`` over a single output column is identically 0. That made
        every prediction 0 and left no gradient for ``fc1``.
        """
        return self.fc1(x)

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss of ``batch`` on the worker that holds it.

        Mirrors the reference loop: send the model to the batch's worker,
        predict there, pull the model back, and fetch the loss.

        NOTE(review): Lightning runs ``backward()``/``optimizer.step()`` on the
        returned loss *after* this method returns. That is after the model has
        been pulled back with ``self.get()`` and the loss fetched with ``.get()``.
        The reference loop instead ran ``backward()``/``step()`` while the model
        was still on the worker. TODO: confirm that gradients actually reach the
        local parameters (e.g. inspect ``self.fc1.weight.grad`` after a step).
        If they do not, this module is not being trained.
        """
        data, target = batch
        self.send(data.location)
        pred = self.forward(data)
        loss = ((pred.view(-1) - target) ** 2).mean()
        self.get()
        return {'loss': loss.get()}

    def train_dataloader(self):
        """Federated loader over the global ``dataset``, same settings as ``train_loader``."""
        return sy.FederatedDataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)

    def configure_optimizers(self):
        """Return a single Adam optimizer (lr=1e-2) over all parameters.

        Difference from the reference loop: there, one Adam per worker is
        selected with ``optimizers[data.location.id]``, so each worker's
        batches update a separate optimizer state. ``configure_optimizers``
        receives no batch, so the worker is unknown here. This version shares
        one Adam state across all workers.
        TODO: if per-worker optimizer state matters, select the optimizer by
        ``data.location.id`` inside ``training_step`` using Lightning's manual
        optimization.
        """
        return th.optim.Adam(params=self.parameters(), lr=1e-2)

In [47]:
from pytorch_lightning import Trainer

use_cuda = False
device = th.device("cuda" if use_cuda else "cpu")
net = LNet().to(device)

# Derive `gpus` from `use_cuda`, so the one flag controls both the model
# placement above and the Trainer. Previously the Trainer was hard-coded to
# gpus=0. Reuse `epochs` so both training runs use the same number of epochs.
trainer = Trainer(gpus=1 if use_cuda else 0, max_epochs=epochs)
trainer.fit(net)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 14    


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [48]:
# Display the worker ids that the FederatedDataset was built from.
dataset.workers

['bob', 'theo', 'jason', 'alice', 'andy', 'jon']