<a href="https://colab.research.google.com/github/ebagdasa/federated_homework/blob/master/fl_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Federated Learning with PyTorch



Your task is to train a model that recognizes [MNIST digits](https://en.wikipedia.org/wiki/MNIST_database) using Federated Learning, which builds a joint global model from multiple local models trained on user data. Please consult the [original paper](https://arxiv.org/abs/1602.05629) and don't hesitate to ask questions.

We provide some skeleton code and a reference implementation of the centralized training from [PyTorch examples](https://github.com/pytorch/examples/blob/master/mnist/main.py).

**Your task:** fill in the skeleton code and write a training procedure for FL.

Submit this notebook to Canvas and we will run it and examine the results.


We slightly modify the main algorithm from the paper as follows:

- Each user computes a local update, i.e. the difference between the trained local model and the current global model: L<sub>i</sub> - G<sup>t</sup>.
- The accumulated sum of local updates is divided by the number of users in the round, `round_size`.
- The global model for the next round is G<sup>t+1</sup> = G<sup>t</sup> + `global_lr` * Sum(L<sub>i</sub> - G<sup>t</sup>) / `round_size`.
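
Written as a single equation, the modified rule reads as follows, where $n$ stands for `round_size` and $\eta_g$ for `global_lr` (both symbols are shorthand introduced here, not notation taken from the paper):

$$
G^{t+1} = G^{t} + \frac{\eta_g}{n} \sum_{i=1}^{n} \left( L_i - G^{t} \right)
$$

For $\eta_g = 1$ the $G^{t}$ terms cancel, and the rule reduces to the plain average of the local models, $G^{t+1} = \frac{1}{n} \sum_{i=1}^{n} L_i$.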

## Helper functions and parameters

You do not need to modify this part, but you should be familiar with its primitives: data loaders, the model, and the central training and testing functions. It is mainly taken from [PyTorch examples](https://github.com/pytorch/examples/blob/main/mnist/main.py).

In [18]:
from __future__ import print_function
import argparse
import random
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import Parameter
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.optim import SGD
import numpy as np

from torch.utils.data import DataLoader, SubsetRandomSampler

from tqdm.notebook import tqdm
from typing import List, Dict, Union

In [2]:
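# Use CUDA when a GPU is available, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()

# Learning rates: global_lr scales the averaged update applied to the
# global model, local_lr is used by each participant's local optimizer
global_lr = 0.1
local_lr = 0.1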

batch_size = 32


# FL parameters
no_users = 100
no_rounds = 5
round_size = 10



### Load data

In [3]:
train_kwargs = {'batch_size': batch_size}
test_kwargs = {'batch_size': batch_size}
if use_cuda:
    device='cuda'
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   }
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
else:
    device='cpu'

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
test_dataset = datasets.MNIST('./data', train=False, download=True,
                   transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)




### Split training dataset into smaller ones for each participant



In [4]:
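def split_data(training_dataset, no_users):
    """Create one data loader per user over a random subset of the dataset.

    Each user's subset is sampled independently from the full dataset,
    so subsets of different users may overlap.
    """
    data_per_user = len(training_dataset) // no_users
    data_loaders = list()
    for i in range(no_users):
        # Draw data_per_user distinct indices for this user
        indices = random.sample(range(len(training_dataset)), data_per_user)
        sampler = torch.utils.data.SubsetRandomSampler(indices)
        data_loader = torch.utils.data.DataLoader(training_dataset, sampler=sampler, **train_kwargs)
        data_loaders.append(data_loader)

    return data_loaders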
    

In [5]:
train_dataset = datasets.MNIST('./data', train=True, download=True,
                   transform=transform)

data_loaders = split_data(train_dataset, no_users)

### Define model

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output



### Testing of the global model

In [7]:
def test(model, device, test_loader):
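    """Perform testing of the global aggregated model.
    
    Args:
        model: torch.nn global model.
        device: device to run the evaluation on ('cuda' or 'cpu').
        test_loader: loader for global testing data.
    
    Returns:
        None. Prints the average loss and accuracy on the test set.
    """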
        
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    


### Example of a central training function

In [8]:
def train(epoch: int, model: Net, train_loader: DataLoader, optimizer: optim.Optimizer):
    """ Centralized training.
    
    Args:
        epoch: training epoch.
        model: torch.nn model.
        train_loader: loader for global training data.
        optimizer: optimizer for global model.
    
    Returns:
        None
    """
    
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

### Here are some ideas to help you navigate PyTorch training for FL:

1. You do not need to train the models of one round in parallel. You can train each model in a round sequentially and accumulate the results of the trained models.

2. Once the results of a round are summed up, you can average them and apply the average to the global model.

3. Be careful when copying tensors: a participant's model must start as a copy of the global model. It is useful to always keep `global_model` intact and create a `local_model` for every participant (use `copy.deepcopy()`). The optimizer needs to be recreated for every `local_model`.

**Primitives**:

`model.state_dict()` -- returns a dictionary of layer names and parameters (weights) for the model. 

`param.data.detach()` -- useful for computing a difference between the global and local models. 

`param.data.add_(data)` -- modify the value of the weight tensor by adding `data` (useful when updating `global_model`).

`copy.deepcopy(global_model)` -- copy the whole model (useful to create a `local_model`).

`optimizer = optim.SGD(local_model.parameters(), lr=local_lr)` -- create optimizer for local model. 
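
A minimal sketch of how these primitives fit together, using a tiny `nn.Linear` layer as a stand-in for `Net` and a hand-made perturbation in place of real training. Both are assumptions made for illustration, as is the `demo_global_lr` name; this is not the solution to the task below.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the global model; the notebook's Net would be used the same way.
global_model = nn.Linear(4, 2)

# Independent copy: changing local_model must not touch global_model.
local_model = copy.deepcopy(global_model)

# Pretend that local training shifted every parameter by 0.5.
with torch.no_grad():
    for param in local_model.parameters():
        param.add_(0.5)

# Per-parameter difference (local - global), keyed by parameter name.
global_state = global_model.state_dict()
update = {
    name: tensor.detach() - global_state[name].detach()
    for name, tensor in local_model.state_dict().items()
}

# Remember the old weights, then apply a scaled update in place.
demo_global_lr = 0.1  # hypothetical value for this demo
before = {name: p.data.clone() for name, p in global_model.named_parameters()}
for name, param in global_model.named_parameters():
    param.data.add_(demo_global_lr * update[name])
```

Because `update` is built from new tensors, later changes to either model do not silently alter it, which is the main pitfall when working with shared parameter storage.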


# Task begins

## 1. Fill FL primitives

Create a `local_train` function to train the model locally, use `accumulate` to sum local models into one object, and `average` to average these models and update the global model.

In [12]:
def local_train(model_id: int, global_model: Net, train_loader: DataLoader, 
                local_lr: float, clipping_norm: float=None) -> Dict[str, torch.Tensor]:
    """Perform training of the local model on local data.
    
    Args:
        model_id: identifier of the local model.
        global_model: global model (cannot be modified!).
        train_loader: loader for local training data.
        local_lr: learning rate for the local optimizer.
        clipping_norm: bound of model update for DP training.
    
    Returns:
        Model update, i.e. for each param local_model - global_model.
    """
    
    # YOUR CODE GOES HERE (Hint: modify centralized training function)


In [13]:
def accumulate(local_update: Dict[str, torch.Tensor], 
               weight_aggregator: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Accumulate local updates into a weight_aggregator.
    
    Args:
        local_update: dictionary with updates. 
        weight_aggregator: running sum of the local updates from a single round.
    
    Returns: 
        Updated weight aggregator.
    """
    
    # YOUR CODE GOES HERE

In [14]:
def average(no_models: int, global_model: Net, weight_aggregator: Dict[str, torch.Tensor], 
            global_lr: float, noise_std: float=None) -> Net:
    """Average accumulated models and apply them to the global model.
    
    Args:
        global_model: Server's FL model
        no_models: number of models in a single FL round.
        weight_aggregator: sum of the local updates from a single round.
        global_lr: learning rate to update the global model.
        noise_std: added noise for DP training.
    
    Returns:
        Updated global model.
    """
    
    # YOUR CODE GOES HERE

## 2. Run FL training and testing


Using the primitives above, implement the Federated Learning routine: train a global model for `no_rounds` rounds, sampling `round_size` users out of `no_users` in each round.

Don't forget to test the `global_model` to check convergence. A successfully trained global model will score above 75% on the test set.
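
The per-round sampling can be sketched with the standard library alone. The snippet below only picks which users take part in each round; the fixed seed and the `chosen_per_round` name are assumptions made for this illustration, and the constants mirror the parameters defined earlier in the notebook.

```python
import random

no_users = 100   # total number of participants
no_rounds = 5    # number of FL rounds
round_size = 10  # participants sampled per round

rng = random.Random(0)  # fixed seed so the demo is reproducible

chosen_per_round = []
for round_no in range(no_rounds):
    # Sample without replacement so a user appears at most once per round.
    chosen = rng.sample(range(no_users), round_size)
    chosen_per_round.append(chosen)
```

Each index in `chosen` would then select the matching entry of `data_loaders` for local training.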

In [20]:
global_model = Net().to(device)
weight_aggregator = dict()

In [None]:
# Training code goes here!

## 3. Implement Differentially Private FL

It's possible to protect user model updates by applying Differential Privacy (see this [Google paper](https://arxiv.org/abs/1710.06963), Algorithm 1). Augment the FL code above to support DP training: clip each layer of the local model update by the value **S** to limit update sensitivity, and add Gaussian noise with standard deviation **sigma** when updating the global model (i.e. to the averaged update, before scaling by `global_lr`).

Modify the **`local_train`** function to use `clipping_norm` and the **`average`** function to add Gaussian noise with `noise_std`.
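
One possible reading of the clipping and noising steps, sketched on a hand-made update dictionary rather than a trained model. The helper names `clip_update` and `add_noise` are hypothetical, and clipping each layer by its L2 norm is an assumption about what clipping by **S** means here; check it against Algorithm 1 of the paper before relying on it.

```python
from typing import Dict

import torch


def clip_update(update: Dict[str, torch.Tensor], clipping_norm: float) -> Dict[str, torch.Tensor]:
    """Scale each layer's update so its L2 norm is at most clipping_norm."""
    clipped = {}
    for name, delta in update.items():
        norm = delta.norm(p=2).item()
        # Leave small updates untouched; shrink large ones onto the norm ball.
        scale = min(1.0, clipping_norm / (norm + 1e-12))
        clipped[name] = delta * scale
    return clipped


def add_noise(averaged: Dict[str, torch.Tensor], noise_std: float) -> Dict[str, torch.Tensor]:
    """Add zero-mean Gaussian noise with standard deviation noise_std to each layer."""
    return {
        name: delta + noise_std * torch.randn_like(delta)
        for name, delta in averaged.items()
    }


torch.manual_seed(0)
toy_update = {
    "weight": torch.full((2, 2), 3.0),  # L2 norm 6.0, will be clipped
    "bias": torch.tensor([0.3, 0.4]),   # L2 norm 0.5, left as is
}
clipped = clip_update(toy_update, clipping_norm=1.0)
noisy = add_noise(clipped, noise_std=0.1)
```

In the notebook's structure, the clipping step would belong in `local_train` and the noise step in `average`, applied to the averaged update before it is scaled by `global_lr`.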


In [None]:
# Your model accuracy should be higher than 60%

S = 1
sigma = 0.1

In [None]:
# DP Training code goes here!