We’ll implement the federated learning approach to train a simple neural network on the MNIST dataset using the two workers: Jake and John.

### 1. Import the libraries and modules

In [1]:
import torch
import syft as sy
import torchvision
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms



#### Initializing the hook

In [2]:
hook = sy.TorchHook(torch)

This is done to override PyTorch’s methods to execute commands on one worker that are called on tensors controlled by the local worker. It also allows us to move tensors between workers.

### 2. Create virtual workers

In [3]:
jake = sy.VirtualWorker(hook, id="jake")
john = sy.VirtualWorker(hook, id="john")

Virtual workers are entities present on our local machine. They are used to model the behavior of actual workers.

### 3. Load the dataset

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])

train_set = datasets.MNIST(
    "~/.pytorch/MNIST_data/", train=True, download=True, transform=transform)
test_set = datasets.MNIST(
    "~/.pytorch/MNIST_data/", train=False, download=True, transform=transform)

In real-life applications, the data is present on client devices. To replicate the scenario, we send data to the VirtualWorkers.

In [5]:
federated_train_loader = sy.FederatedDataLoader(
    train_set.federate((jake, john)), batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=64, shuffle=True)

Notice that we have created the training dataset differently. The `train_set.federate((jake, john))` creates a _FederatedDataset_ wherein the train_set is split among Jake and John (our two VirtualWorkers). The _FederatedDataset_ class is intended to be used like the PyTorch’s _Dataset_ class. Pass the created _FederatedDataset_ to a federated data loader _FederatedDataLoader_ to iterate over it in a federated manner. The batches then come from different devices.

### 4. Build the model

In [6]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Model()
optimizer = optim.SGD(model.parameters(), lr=0.01)

### 5. Train the model

Since the data is present on the client device, we obtain its location through the location attribute. The important additions to the code are the steps to get back the improved model and the value of the loss from the client devices.

In [7]:
for epoch in range(0, 10):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader):
        
        # send the model to the client device where the data is present
        model.send(data.location)
        
        # training the model
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        # get back the improved model
        model.get()
        
        if batch_idx % 100 == 0:
            
            # get back the loss
            loss = loss.get()
            
            print('Epoch: {:2d} [{:5d}/{:5d} ({:3.0f}%)]\tLoss: {:.6f}'.format(
                epoch + 1, batch_idx * 64,
                len(federated_train_loader) * 64,
                100. * batch_idx / len(federated_train_loader), loss.item()))



 ### 6. Test the model

In [8]:
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        
        output = model(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        
        # get the index of the max log-probability
        pred = output.argmax(1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

Test set: Average loss: 0.1781, Accuracy: 9484/10000 (95%)


In [9]:
label_names = [str(x) for x in range(0, 10)]
label_names

['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

In [10]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        
        for i in range(10):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

In [11]:
print("Test Accuracies")
for i in range(10):
    print('%5s : %2d %%' % (label_names[i],
                            100 * class_correct[i] / class_total[i]))

Test Accuracies
    0 : 99 %
    1 : 99 %
    2 : 94 %
    3 : 94 %
    4 : 90 %
    5 : 96 %
    6 : 97 %
    7 : 94 %
    8 : 94 %
    9 : 94 %


That’s it. We have trained a model using the federated learning approach. When compared to traditional training, it takes more time to train a model using the federated approach.