


## Federated Learning on MNIST using a CNN with PyTorch & PySyft


### Related references

- PyTorch MNIST example: https://github.com/pytorch/examples/blob/master/mnist/main.py
- PySyft: https://github.com/OpenMined/PySyft/

- Credit goes to Archit Garg for the School of AI / Udacity tutorial.

- Install the `syft` package (version 0.2.9) with pip.

In [None]:
!pip3 install syft==0.2.9



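### Imports and model specifications

- Install `torchvision`, which provides the MNIST dataset and the image transforms. As the output below shows, pip may upgrade `torch` at the same time, to a version outside the range that the installed `syft` declares.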

In [None]:
!pip3 install torchvision

Collecting torch==1.11.0
  Downloading torch-1.11.0-cp37-cp37m-manylinux1_x86_64.whl (750.6 MB)
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 1.10.0
    Uninstalling torch-1.10.0:
      Successfully uninstalled torch-1.10.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
syft 0.6.0 requires torch<=1.10.0,>=1.8.1, but you have torch 1.11.0 which is incompatible.
Successfully installed torch-1.11.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

- Define two remote workers, `i_1` and `i_2`. PySyft's `VirtualWorker` simulates a remote client inside this notebook.

In [None]:
import syft as sy  # import the PySyft library
hook = sy.TorchHook(torch)  # hook PyTorch to add the extra functionality needed for Federated Learning
i_1 = sy.VirtualWorker(hook, id="i_1")  # define remote worker i_1
i_2 = sy.VirtualWorker(hook, id="i_2")  # define remote worker i_2

Next, we define the settings of the learning task.

In [None]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = 10
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

### Data loading and sending to workers
- First, we load the data and turn the training dataset into a federated dataset, split across the workers with the `.federate` method.
- This federated dataset is then passed to a federated data loader. The test dataset remains unchanged.
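The pure-Python sketch below makes the split concrete. It rests on two assumptions: the training set is cut into contiguous, equally sized shards (one per worker), and each worker's shard is batched separately with the last partial batch kept. Batching per worker means a batch never mixes samples held by different workers. `split_into_shards` and `batches_per_shard` are hypothetical helpers for illustration. They are not PySyft functions, and the code does not reproduce `.federate` itself.

```python
import math

def split_into_shards(samples, num_workers):
    """Split a dataset into contiguous, roughly equal shards, one per worker.

    Simplified illustration only; this is not PySyft's `.federate` code.
    """
    shard_size = math.ceil(len(samples) / num_workers)
    return [samples[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]

def batches_per_shard(shard_len, batch_size):
    # The last, smaller batch on each worker is kept (no drop_last)
    return math.ceil(shard_len / batch_size)

# MNIST has 60,000 training images; here we only track their indices
shards = split_into_shards(list(range(60_000)), num_workers=2)
print([len(s) for s in shards])                          # [30000, 30000]
print([batches_per_shard(len(s), 64) for s in shards])   # [469, 469]
```

Under these assumptions, each of the two workers holds 30,000 images, and one epoch with `batch_size=64` visits 469 batches per worker.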

In [None]:
federated_train_loader = sy.FederatedDataLoader( # this is a FederatedDataLoader 
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((i_1, i_2)), # we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


### CNN specification
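- We use a standard CNN: two convolutional layers, each followed by ReLU and 2x2 max pooling, then two fully connected layers, with a log-softmax over the 10 digit classes.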

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
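The hard-coded `4*4*50` in `Net` follows from the spatial size of the feature maps. The short pure-Python calculation below traces that size through each layer and also counts the trainable parameters. It assumes 28x28 single-channel MNIST inputs and no padding, which is the default for `nn.Conv2d` and `F.max_pool2d`. `conv_out` is a hypothetical helper, not a PyTorch function.

```python
def conv_out(size, kernel, stride=1):
    # Output size of an unpadded convolution or pooling window
    return (size - kernel) // stride + 1

size = 28                      # MNIST images are 28x28
size = conv_out(size, 5)       # conv1, 5x5 kernel, stride 1 -> 24
size = conv_out(size, 2, 2)    # max_pool2d(2, 2)            -> 12
size = conv_out(size, 5)       # conv2, 5x5 kernel, stride 1 -> 8
size = conv_out(size, 2, 2)    # max_pool2d(2, 2)            -> 4
flat_features = size * size * 50   # conv2 produces 50 channels
print(size, flat_features)     # 4 800

# Trainable parameters of each layer: weights + biases
params = {
    "conv1": 20 * 1 * 5 * 5 + 20,
    "conv2": 50 * 20 * 5 * 5 + 50,
    "fc1": flat_features * 500 + 500,
    "fc2": 500 * 10 + 10,
}
print(sum(params.values()))    # 431080
```

With PyTorch installed, `sum(p.numel() for p in Net().parameters())` should report the same total. Most of the parameters sit in `fc1`.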

### Define the train and test functions
For the train function, the data batches are distributed across the workers `i_1` and `i_2`, so you need to send the model to the location of each batch. You then perform all the operations remotely, with the same syntax as in local PyTorch. When you are done, you get back the updated model, along with the loss to monitor training progress.

In [None]:
def train(args, model, device, federated_train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader): # now it is a distributed dataset
        model.send(data.location) # send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))
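The send / train / get pattern in `train` can be imitated without PySyft. The pure-Python sketch below is an illustration under stated assumptions, not PySyft's API:

- `ToyWorker` and `federated_train` are hypothetical names.
- The "model" is a single weight `w`, fitted to y = 3x by SGD on mean squared error.
- Sending and getting the model are simulated by copying a dict.

As in the notebook, one model visits the workers batch by batch and is updated in sequence. No separate models are averaged.

```python
import copy
import random

class ToyWorker:
    """Hypothetical stand-in for a remote worker; its data never leaves it."""

    def __init__(self, worker_id, data):
        self.id = worker_id
        self._data = data          # private list of (x, y) pairs
        self._model = None

    def num_samples(self):
        return len(self._data)

    def receive(self, model):
        # Plays the role of model.send(location): the worker now holds the model
        self._model = copy.deepcopy(model)

    def sgd_step(self, indices, lr):
        # One SGD step on mean squared error for y ~ w * x, run on the worker
        batch = [self._data[i] for i in indices]
        w = self._model["w"]
        grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
        self._model["w"] = w - lr * grad
        return sum((w * x - y) ** 2 for x, y in batch) / len(batch)  # loss before the step

    def give_back(self):
        # Plays the role of model.get(): the updated model returns to the coordinator
        model, self._model = self._model, None
        return model


def federated_train(workers, batch_size, lr, epochs, seed=1):
    rng = random.Random(seed)
    model = {"w": 0.0}
    for _ in range(epochs):
        # Every batch lives on exactly one worker
        batches = [
            (worker, list(range(start, min(start + batch_size, worker.num_samples()))))
            for worker in workers
            for start in range(0, worker.num_samples(), batch_size)
        ]
        rng.shuffle(batches)
        for worker, indices in batches:
            worker.receive(model)          # send the model to the data
            worker.sgd_step(indices, lr)   # train remotely
            model = worker.give_back()     # get the updated model back
    return model


# Two toy workers holding different slices of data generated by y = 3x
xs = [i / 20 for i in range(1, 41)]
toy_1 = ToyWorker("toy_1", [(x, 3 * x) for x in xs[:20]])
toy_2 = ToyWorker("toy_2", [(x, 3 * x) for x in xs[20:]])
toy_model = federated_train([toy_1, toy_2], batch_size=4, lr=0.1, epochs=20)
print(round(toy_model["w"], 4))   # 3.0
```

The raw data stays inside each `ToyWorker`. Only the model travels, and after `give_back` no worker keeps a copy.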

The test function evaluates the model locally on the test set, which is not federated.

In [None]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
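`test` sums the per-sample negative log-likelihood (`reduction='sum'`) and divides by the size of the test set. It counts a prediction as correct when the argmax of the log-probabilities matches the label. The pure-Python sketch below shows that bookkeeping on toy logits rather than real model outputs. `log_softmax` and `evaluate` here are hypothetical helpers, not PyTorch functions.

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax of one row, like F.log_softmax(x, dim=1)
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def evaluate(batches):
    """Return (average NLL per sample, number correct, number of samples)."""
    total_loss, correct, n = 0.0, 0, 0
    for logits_batch, targets in batches:
        for logits, target in zip(logits_batch, targets):
            log_probs = log_softmax(logits)
            total_loss += -log_probs[target]   # summed, as with reduction='sum'
            pred = max(range(len(log_probs)), key=log_probs.__getitem__)  # argmax
            correct += int(pred == target)
            n += 1
    return total_loss / n, correct, n

# Two batches of unequal size
batches = [
    ([[2.0, 0.0], [0.0, 2.0]], [0, 0]),
    ([[0.0, 3.0]], [1]),
]
avg_loss, correct, n = evaluate(batches)
print(f"Average loss: {avg_loss:.4f}, Accuracy: {correct}/{n}")
# Average loss: 0.7675, Accuracy: 2/3
```

Summing first and dividing by the number of samples gives the exact per-sample average even when batches differ in size. Averaging the per-batch means would weight the single-sample batch too heavily.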

### Launch the training

In [None]:
%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO: momentum is not supported at the moment, so args.momentum is unused

for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if (args.save_model):
    torch.save(model.state_dict(), "mnist_cnn.pt")




Test set: Average loss: 0.1574, Accuracy: 9512/10000 (95%)


Test set: Average loss: 0.0901, Accuracy: 9735/10000 (97%)


Test set: Average loss: 0.0738, Accuracy: 9756/10000 (98%)


Test set: Average loss: 0.0547, Accuracy: 9811/10000 (98%)


Test set: Average loss: 0.0461, Accuracy: 9849/10000 (98%)


Test set: Average loss: 0.0443, Accuracy: 9858/10000 (99%)


Test set: Average loss: 0.0444, Accuracy: 9862/10000 (99%)


Test set: Average loss: 0.0363, Accuracy: 9884/10000 (99%)


Test set: Average loss: 0.0346, Accuracy: 9889/10000 (99%)


Test set: Average loss: 0.0378, Accuracy: 9873/10000 (99%)

CPU times: user 27min 20s, sys: 1min 22s, total: 28min 43s
Wall time: 28min 46s


- We trained a model with Federated Learning on remote data held by two clients (`i_1` and `i_2`), reaching about 99% test accuracy after 10 epochs.