<a href="https://colab.research.google.com/github/ewotawa/secure_private_ai/blob/master/Section_2_Federated_Learning_Final_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Federated Learning Final Project

## Overview
* See  <a href="https://classroom.udacity.com/nanodegrees/nd185/parts/3fe1bb10-68d7-4d84-9c99-9539dedffad5/modules/28d685f0-0cb1-4f94-a8ea-2e16614ab421/lessons/c8fe481d-81ea-41be-8206-06d2deeb8575/concepts/a5fb4b4c-e38a-48de-b2a7-4e853c62acbe">video</a> for additional details. 
* Perform Federated Learning in a setting where the central server is not trusted with the raw gradients.
* The final project notebook provides a dataset.
* Train on that dataset using Federated Learning.
* The gradients should not reach the server in raw form.
* Instead, use the new `.move()` command to move all of the gradients to one of the workers, sum them there, and then bring the summed batch up to the central server.
* Idea: the central server never sees the raw gradient of any individual person.
* Secure aggregation is covered in course 3.
* For now, build a larger-scale Federated Learning case in which the gradients are handled in this special way.
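
The move-then-sum idea from the overview can be sketched without PySyft at all. The following is a minimal simulation in plain Python, not the PySyft API: the `Worker` class, its `move` and `sum_gradients` methods, and the `aggregate_on_worker` helper are hypothetical names made up for illustration, and gradients are represented as short lists of floats. It only shows the data flow in which the server receives a sum rather than any individual gradient.

```python
from typing import Dict, List

Gradient = List[float]


class Worker:
    """Hypothetical stand-in for a remote worker holding gradients by owner id."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.store: Dict[str, Gradient] = {}

    def move(self, key: str, destination: "Worker") -> None:
        """Hand a stored gradient directly to another worker, bypassing the server."""
        destination.store[key] = self.store.pop(key)

    def sum_gradients(self) -> Gradient:
        """Element-wise sum of every gradient currently held by this worker."""
        return [sum(values) for values in zip(*self.store.values())]


def aggregate_on_worker(workers: List[Worker], aggregator: Worker) -> Gradient:
    """Move each worker's gradients to the aggregator and return only their sum."""
    for worker in workers:
        if worker is not aggregator:
            for key in list(worker.store):
                worker.move(key, aggregator)
    return aggregator.sum_gradients()


# Three simulated workers, each holding one locally computed gradient.
bob, alice, carol = Worker("bob"), Worker("alice"), Worker("carol")
bob.store["bob"] = [0.5, -1.0]
alice.store["alice"] = [1.5, 2.0]
carol.store["carol"] = [-1.0, 1.0]

# The "server" side only ever receives this summed gradient.
total = aggregate_on_worker([bob, alice, carol], aggregator=carol)
print(total)  # [1.0, 2.0]
```

Note that in this scheme the aggregating worker still sees every individual gradient; only the central server is shielded from them.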

## References
* <a href="https://blog.openmined.org/upgrade-to-federated-learning-in-10-lines/">Deep Learning -> Federated Learning in 10 Lines of PyTorch + PySyft</a>
* <a href="https://github.com/udacity/private-ai/pull/10">Added data for Federated Learning project</a>
* <a href="https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%206%20-%20Federated%20Learning%20on%20MNIST%20using%20a%20CNN.ipynb">Part 6 - Federated Learning on MNIST using a CNN.ipynb</a>
* <a href="https://docs.google.com/spreadsheets/d/1x-QQK-3Wn86bvSbNTf2_p2FXVCqiic2QwjcArQEuQlg/edit#gid=0">Slack channel's reference sheet</a>
* <a href="https://github.com/ucalyptus/Federated-Learning/blob/master/Federated%20Learning.ipynb">Federated Learning Example from Slack Channel reference sheet</a>

### Install libraries and dependencies

In [2]:
!pip install syft

import syft as sy

Collecting syft
  Downloading https://files.pythonhosted.org/packages/38/2e/16bdefc78eb089e1efa9704c33b8f76f035a30dc935bedd7cbb22f6dabaa/syft-0.1.21a1-py3-none-any.whl (219kB)
Collecting zstd>=1.4.0.0 (from syft)
  Downloading https://files.pythonhosted.org/packages/8e/27/1ea8086d37424e83ab692015cc8dd7d5e37cf791e339633a40dc828dfb74/zstd-1.4.0.0.tar.gz (450kB)
Collecting flask-socketio>=3.3.2 (from syft)
  Downloading https://files.pythonhosted.org/packages/4b/68/fe4806d3a0a5909d274367eb9b3b87262906c1515024f46c2443a36a0c82/Flask_SocketIO-4.1.0-py2.py3-none-any.whl
Collecting websocket-client>=0.56.0 (from syft)
  Downloading https://files.pythonhosted.org/packages/29/19/44753eab1fdb50770ac69605527e8859468f3c0fd7dc5a76dd9c4dbd7906/websocket_client-0.56.0-py2.py3-none-any.whl (200kB)
Collecti

W0717 05:07:52.751605 140071711582080 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/usr/local/lib/python3.6/dist-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0717 05:07:52.773896 140071711582080 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [3]:
!pip install torch
!pip install torchvision

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms



In [0]:
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionality to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")

In [0]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = 10
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 10
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [6]:
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader 
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


W0717 05:09:25.788734 140071711582080 dataloader.py:197] The following options are not supported: num_workers: 1, pin_memory: True


In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
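
The `4*4*50` input size of `fc1` follows from the layer arithmetic in `Net`. As a small check, here is a sketch in plain Python (no PyTorch needed); `output_size` is a hypothetical helper that applies the usual output-size formula for convolution and pooling layers without dilation, assuming 28x28 MNIST inputs.

```python
def output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a convolution or pooling layer (floor division, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                      # MNIST images are 28x28 pixels
size = output_size(size, kernel=5)             # conv1 (5x5, stride 1): 28 -> 24
size = output_size(size, kernel=2, stride=2)   # max_pool2d (2x2):      24 -> 12
size = output_size(size, kernel=5)             # conv2 (5x5, stride 1): 12 -> 8
size = output_size(size, kernel=2, stride=2)   # max_pool2d (2x2):       8 -> 4

flat_features = size * size * 50               # conv2 produces 50 channels
print(size, flat_features)  # 4 800
```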

In [0]:
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    # Iterate over the loader passed in as an argument rather than the global one
    for batch_idx, (data, target) in enumerate(train_loader): # <-- now it is a distributed dataset
        model.send(data.location) # <-- NEW: send the model to the right location
        # data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader) * args.batch_size, #batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [0]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
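
The `test` function sums the per-sample loss with `reduction='sum'` and divides by the dataset size once at the end, so the result is a true per-sample average even when batches differ in size. A minimal sketch of that bookkeeping in plain Python follows; it does not use PyTorch, and the `log_softmax` and `evaluate` helpers as well as the toy logits are made up for illustration.

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one sample's logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def evaluate(batches: List[Tuple[List[List[float]], List[int]]]) -> Tuple[float, float]:
    """Return (average negative log-likelihood, accuracy) over all samples."""
    total_loss, correct, count = 0.0, 0, 0
    for logits_batch, targets in batches:
        for logits, target in zip(logits_batch, targets):
            log_probs = log_softmax(logits)
            total_loss += -log_probs[target]          # summed, not averaged, per batch
            prediction = max(range(len(log_probs)), key=log_probs.__getitem__)
            correct += int(prediction == target)
            count += 1
    return total_loss / count, correct / count        # divide once at the end


# Two toy batches of different sizes, three classes each.
batches = [
    ([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]], [0, 2]),
    ([[0.0, 0.0, 1.0]], [2]),
]
avg_loss, accuracy = evaluate(batches)
print(avg_loss, accuracy)
```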

In [14]:
# model = Net().to(device)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")



Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 234, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <class 'syft.frameworks.torch.tensors.interpreters.native.Tensor'>: attribute lookup Tensor on syft.frameworks.torch.tensors.interpreters.native failed


KeyboardInterrupt: ignored