In [1]:
# Import required libraries
# --- standard library ---
import datetime
import math
import pdb
import random
import sys
from random import shuffle

# --- third-party (torch before syft, which hooks it) ---
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.autograd import Variable
from torchvision import datasets, transforms
import syft as sy

In [3]:
# Run configuration: hardware, batch sizes and random seeds.
# The GPU is used only when requested AND available, so the notebook still runs
# on a CPU-only machine instead of failing on set_default_tensor_type.
use_cuda_requested = True  # set to False to force CPU
use_cuda = use_cuda_requested and torch.cuda.is_available()
if use_cuda:
    # Make newly created float tensors default to the GPU.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
# NOTE(review): kwargs is not passed to the DataLoaders in the visible cells.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
device = torch.device("cuda" if use_cuda else "cpu")

batch_size = 64
test_batch_size = 1000
image_size = (32, 32)

# Seed every RNG this notebook touches: torch, Python's `random` (used by
# `shuffle` on the distributed batches) and numpy.
seed = 10
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

<torch._C.Generator at 0x20e8780e630>

In [1]:
# Federated Learners

In [10]:
# Hook PyTorch so tensors and modules gain federated methods such as
# .send()/.get(), then create two simulated (virtual) workers.
hook = sy.TorchHook(torch)
bob, alice = (sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice"))
# Batches are dealt to these workers round-robin, in this order.
compute_nodes = [alice, bob]




# Load the CIFAR-10 training set

In [6]:
# Training pipeline: random 32x32 crop (4 px padding) and horizontal flip for
# augmentation, then tensor conversion and per-channel normalisation.
# NOTE(review): these std values are widely copied but are reported not to be
# the per-pixel std of CIFAR-10 (~0.247, 0.243, 0.261) — verify if it matters.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./cifar10',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


# Load the CIFAR-10 test set

In [7]:
# Evaluation pipeline: the same normalisation as training, but no random
# augmentation, so every evaluation sees identical inputs.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

testset = torchvision.datasets.CIFAR10(
    root='./cifar10',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified


# Neural Network Structure

In [8]:
class Net(nn.Module):
    """LeNet-style CNN for 3x32x32 CIFAR-10 images that returns log-probabilities.

    Two (conv -> ReLU -> 2x2 max-pool) stages feed three fully connected layers.
    The output is ``log_softmax`` over 10 classes, matching ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in exactly this order so parameter registration
        # (seeded initialisation order and state_dict keys) stays the same.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # 3x32x32 -> conv(5) -> 6x28x28 -> pool -> 6x14x14
        h = self.pool(F.relu(self.conv1(x)))
        # 6x14x14 -> conv(5) -> 16x10x10 -> pool -> 16x5x5
        h = self.pool(F.relu(self.conv2(h)))
        # Flatten to 400 features per image (only valid for 32x32 inputs).
        h = h.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            h = F.relu(hidden(h))
        logits = self.fc3(h)
        return F.log_softmax(logits, dim=1)



# Send dataset to clients

In [12]:
# Deal the training batches out round-robin across `compute_nodes`: each
# (data, target) pair is sent to the same worker and only the returned
# pointers are kept locally.
# NOTE: this loop consumes `trainloader` exactly once, so the random crop/flip
# of `transform_train` is sampled once here; training then iterates over this
# fixed list every epoch (no fresh augmentation per epoch).
train_distributed_dataset = []
for batch_idx, (data, target) in enumerate(trainloader):
    worker = compute_nodes[batch_idx % len(compute_nodes)]
    train_distributed_dataset.append((
        data.send(worker, inplace=True),
        target.send(worker, inplace=True),
    ))

# Randomise the batch order (otherwise it strictly alternates between workers).
shuffle(train_distributed_dataset)

(Tensor>[PointerTensor | me:6823191096 -> alice:18083011846],
 Tensor>[PointerTensor | me:67442683695 -> alice:63166093778])

# Training Function

In [16]:
def train(epoch, device, trainloader, net=None, opt=None):
    """Run one federated training epoch over pre-distributed batches.

    For each ``(data, target)`` pair, the model is sent to the worker that holds
    ``data`` (``data.location``), one optimisation step is taken there, and the
    model is fetched back before the next batch.

    Args:
        epoch: Epoch number. Unused; kept for call compatibility.
        device: Device passed to ``.to()`` for ``data`` and ``target``.
        trainloader: Iterable of ``(data, target)`` pairs whose ``data`` has a
            ``.location`` (e.g. the list built in "Send dataset to clients").
        net: Model to train. Defaults to the notebook-global ``model``.
        opt: Optimizer over ``net``'s parameters. Defaults to the
            notebook-global ``optimizer``.
    """
    # Conditional expressions only read the globals when no argument is given.
    net = model if net is None else net
    opt = optimizer if opt is None else opt

    # test() switches the model to eval mode; switch back before training.
    net.train()
    for data, target in trainloader:
        net.send(data.location)  # 0) send the model to the worker holding this batch
        try:
            data, target = data.to(device), target.to(device)
            opt.zero_grad()  # 1) erase previous gradients (if they exist)
            output = net(data)  # 2) make a prediction
            loss = F.nll_loss(output, target)  # 3) calculate how much we missed
            loss.backward()  # 4) figure out which weights caused us to miss
            opt.step()  # 5) change those weights
        finally:
            # 6) always bring the model back, even if the step failed, so it
            # is never left stranded on a remote worker.
            net.get()

# Test Function

In [18]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# Run everything

In [19]:
N_EPOCHS = 50  # number of passes over the distributed training batches

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(1, N_EPOCHS + 1):
    print(datetime.datetime.now())
    print("epoch: " + str(epoch))
    # The distributed batches were shuffled only once when they were created;
    # reshuffle every epoch so SGD does not see the same batch order each time.
    shuffle(train_distributed_dataset)
    train(epoch, device, train_distributed_dataset)
    test(model, device, testloader)

# Save the trained weights (train() calls model.get(), so the model is local here).
torch.save(model.state_dict(), "cifar_model.pt")



2019-06-18 17:18:56.899167
epoch: 1
Test set: Average loss: 1.9996, Accuracy: 2829/10000 (28%)

2019-06-18 17:19:46.818082
epoch: 2
Test set: Average loss: 1.7568, Accuracy: 3690/10000 (37%)

2019-06-18 17:20:36.223179
epoch: 3
Test set: Average loss: 1.5920, Accuracy: 4182/10000 (42%)

2019-06-18 17:21:25.774209
epoch: 4
Test set: Average loss: 1.4877, Accuracy: 4619/10000 (46%)

2019-06-18 17:22:15.320244
epoch: 5
Test set: Average loss: 1.4223, Accuracy: 4863/10000 (49%)

2019-06-18 17:23:04.653369
epoch: 6
Test set: Average loss: 1.3801, Accuracy: 5042/10000 (50%)

2019-06-18 17:23:53.974455
epoch: 7
Test set: Average loss: 1.3476, Accuracy: 5129/10000 (51%)

2019-06-18 17:24:43.710429
epoch: 8
Test set: Average loss: 1.3160, Accuracy: 5269/10000 (53%)

2019-06-18 17:25:32.780617
epoch: 9
Test set: Average loss: 1.2881, Accuracy: 5383/10000 (54%)

2019-06-18 17:26:21.937778
epoch: 10
Test set: Average loss: 1.2720, Accuracy: 5433/10000 (54%)

2019-06-18 17:27:11.218915
epoch: 11
Te