In [6]:
# Resources and tutorials:
# OpenMined Advanced tutorials: https://github.com/OpenMined/PySyft/tree/master/examples/tutorials/advanced/websockets-example-MNIST
# Andrew Task youtube lessons: https://www.youtube.com/watch?v=TWa6wFarCeI
# OpenMined Blog about FL: https://blog.openmined.org/upgrade-to-federated-learning-in-10-lines/
# OpenMined blog about setting FL and RNN with RPi: https://blog.openmined.org/federated-learning-of-a-rnn-on-raspberry-pis/
# Udacity and Facebook "Secure and Private AI challenge"

In [1]:
# %load_ext autoreload

# %autoreload 2

# load libraries

import sys
import syft as sy
from syft.workers.virtual import VirtualWorker
from syft.workers import WebsocketClientWorker
from syft import FederatedDataset, FederatedDataLoader, BaseDataset
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms, models, utils
from syft.frameworks.torch.federated import utils

W0820 20:21:19.138794 4511266240 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/Users/jluissamper/.virtualenvs/pytorch/lib/python3.6/site-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0820 20:21:19.155421 4511266240 deprecation_wrapper.py:119] From /Users/jluissamper/.virtualenvs/pytorch/lib/python3.6/site-packages/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [2]:
# import model to share and other client nn-related functionalities such as: next batch, train, get params...
import run_websocket_client as rwc

In [3]:
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=64, cuda=False, epochs=2, federate_after_n_batches=50, lr=0.01, save_model=False, seed=1, test_batch_size=1000, use_virtual=False, verbose=False)


In [4]:
hook = sy.TorchHook(torch)

In [5]:
# websocket clients and workers instantiation. This step will fall if the websocket server workers are not running

kwargs_websocket = {"host": "127.0.0.1", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

workers = [alice, bob, charlie]
print(workers)

[<WebsocketClientWorker id:alice #objects local:0 #objects remote: 4>, <WebsocketClientWorker id:bob #objects local:0 #objects remote: 4>, <WebsocketClientWorker id:charlie #objects local:0 #objects remote: 4>]


# Prepare and distribute the training data

In [None]:
# number of subprocesses to use for data loading
num_workers = 4
# how many samples per batch to load
batch_size = 1
# Images size to rescale
img_size = (512,512)
# percentage of training set to use as validation
valid_size = 0.2

data_dir = '~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/Data 15/'
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

In [None]:
eye_dataset = simpleImageLoader(csv_file='~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/labels/trainLabels15.csv',
                                    root_dir='~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/Data 15/train 15',
                                    transform = transform)  

In [None]:
eye_dataloader = torch.utils.data.DataLoader(eye_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers)

In [5]:
#run this box only if the the next box gives pipeline error
torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data/MNIST",
        train=True,download=True))



<torch.utils.data.dataloader.DataLoader at 0x1497fc780>

In [6]:
# Download the MNIST dataset and use federated dataloader

federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        "../data/MNIST",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ).federate(tuple(workers)),
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data/MNIST",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=True
)



In [7]:
# instantiate the model, imported from run_websocket_client.py
# it is a 2 layers conv net
model = rwc.Net().to(device)
print(model)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [8]:
import logging
import sys
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s")
handler.setFormatter(formatter)
logger.handlers = [handler]

In [10]:
# start the training

for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))
    model = rwc.train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
    rwc.test(model, device, test_loader)

Starting epoch 1/2


2019-08-19 11:55:14,364 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [0, 50]
2019-08-19 11:55:50,703 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [50, 100]
2019-08-19 11:56:27,664 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [100, 150]
2019-08-19 11:57:06,275 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [150, 200]
2019-08-19 11:57:45,072 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [200, 250]
2019-08-19 11:58:26,354 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [250, 300]
2019-08-19 11:58:52,821 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [300, 350]
2019-08-19 11:59:05,799 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [350, 400]
2019-08-19 11:59:05,810 DEBUG run_websocket_client.py(l:136) - At least one worker ran out of data, stopping.
2019-08-19 11:59:10,593 DEBUG run_webs

Starting epoch 2/2


2019-08-19 11:59:25,570 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [0, 50]
2019-08-19 12:00:01,800 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [50, 100]
2019-08-19 12:00:38,582 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [100, 150]
2019-08-19 12:01:15,119 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [150, 200]
2019-08-19 12:01:50,838 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [200, 250]
2019-08-19 12:02:27,423 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [250, 300]
2019-08-19 12:02:52,614 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [300, 350]
2019-08-19 12:03:05,023 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [350, 400]
2019-08-19 12:03:05,033 DEBUG run_websocket_client.py(l:136) - At least one worker ran out of data, stopping.
2019-08-19 12:03:09,372 DEBUG run_webs