Federated Learning is a new revolutionary technique of training deep learning model according to which, models will learn on the client side instead of the server. The intuition behind this technique is that client devices will do local learning on the datasets and send the updated models to the central server in an aggregated form. 

Here I will create 2 virtual workers named "Bob" and "Alice" who will simulate real-world client side devices as far as training data locally on client side devices is concerned. Using the excellent PySyft module, we can simulate that experience.

I will use FederatedDataLoader instead of the standard DataLoader to sort of decentralize the data to both Bob and Alice

Federated learning is used most intensively by Google in their Gboard app which they use for text prediction

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import syft as sy  #import the Pysyft library
hook = sy.TorchHook(torch)  # Hooking PyTorch to PySyft
bob = sy.VirtualWorker(hook, id="bob")  # Creating workers
alice = sy.VirtualWorker(hook, id="alice")  

device = torch.device("cuda" if use_cuda else "cpu")


federated_train_loader = sy.FederatedDataLoader( # using a FederatedDataLoader instead of a normal DataLoader
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,))
                   ]))
    .federate((bob, alice)), #sending dataset too both workers
    batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,))
                   ])),
    batch_size=64, shuffle=True)

      
criterion = nn.CrossEntropyLoss();
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader): 
        model.send(data.location) #send the model to both workers as per location
        data = data.view(data.shape[0],-1)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        model.get() 
        if batch_idx % 1000 == 0:
            loss = loss.get() 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * 64, len(train_loader) * 64, \
                100. * batch_idx / len(train_loader), loss.item()))
            
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.shape[0],-1)
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item() # sum up batch loss
            ps = torch.exp(output) #finding the probablity distribution for an image
            top_p,top_class = ps.topk(1,dim=1) #find the class that the model predicted
            equals = top_class == target.view(*top_class.shape)

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {} ({:.0f}%)\n'.format(
        test_loss, torch.mean(equals.type(torch.FloatTensor)),
        100.*torch.mean(equals.type(torch.FloatTensor))))
    
    
model = nn.Sequential(nn.Linear(784,512),
                     nn.ReLU(),
                     nn.Linear(512,256),
                     nn.ReLU(),
                     nn.Linear(256,10),
                     nn.LogSoftmax(dim=1))
optimizer = optim.SGD(model.parameters(), lr=0.001) 

for epoch in range(1, 11):
    train(model, device, federated_train_loader, optimizer, epoch)
    test(model, device, test_loader)

W0625 19:56:07.748313 140709234329472 hook.py:97] Torch was already hooked... skipping hooking process
  response = command(*args, **kwargs)



Test set: Average loss: 0.0317, Accuracy: 0.6875 (69%)


Test set: Average loss: 0.0228, Accuracy: 0.75 (75%)


Test set: Average loss: 0.0151, Accuracy: 0.8125 (81%)


Test set: Average loss: 0.0113, Accuracy: 0.9375 (94%)


Test set: Average loss: 0.0093, Accuracy: 0.6875 (69%)


Test set: Average loss: 0.0081, Accuracy: 0.6875 (69%)


Test set: Average loss: 0.0073, Accuracy: 0.875 (88%)


Test set: Average loss: 0.0068, Accuracy: 0.9375 (94%)


Test set: Average loss: 0.0064, Accuracy: 0.875 (88%)


Test set: Average loss: 0.0061, Accuracy: 0.8125 (81%)

