In [1]:
import os

import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image, ImageOps
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [2]:
import syft as sy  # <-- NEW: import the Pysyft library
# Extend PyTorch with PySyft's federated-learning features (e.g. the
# .send()/.get() calls used on models and tensors during training).
hook = sy.TorchHook(torch)

# Two virtual workers; the training set is later federated across them.
bob, alice = (sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice"))


In [3]:
class Arguments():
    """Hyperparameters and runtime options for this notebook.

    Defaults reproduce the original hard-coded values. Any option can be
    overridden by keyword, e.g. ``Arguments(epochs=2)``. Unknown option names
    raise ``TypeError`` so typos are not silently ignored.
    """

    def __init__(self, **overrides):
        self.batch_size = 64          # training batch size
        self.test_batch_size = 1000   # evaluation batch size
        self.epochs = 20
        self.lr = 0.01
        self.momentum = 0.5           # currently unused: the optimizer is built without momentum
        self.no_cuda = False          # force CPU even when CUDA is available
        self.seed = 1
        self.log_interval = 10        # print training loss every N batches
        self.save_model = False       # write the trained state_dict to cnn.pt
        known = vars(self)
        for name, value in overrides.items():
            if name not in known:
                raise TypeError("Arguments got an unexpected option {!r}".format(name))
            setattr(self, name, value)


args = Arguments()
use_cuda = not args.no_cuda and torch.cuda.is_available()
# Seed every RNG the notebook draws from: torch (weight init, DataLoader
# shuffling) and NumPy's global RNG, which pandas' DataFrame.sample uses when
# no random_state is passed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options, only enabled when running on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [4]:
# Label string -> integer class index, assigned in order of first appearance
# in the training CSV. The test dataset reuses this same mapping.
classes = {}
for pt, i in enumerate(pd.read_csv('./csv/train.csv')['label'].unique()):
    classes[i] = pt
pt = len(classes)  # number of distinct classes


def get_onehot(label):
    """Return the integer class index for ``label``.

    Despite its name this is not a one-hot vector. It is the index stored in
    ``classes``, the target form that ``F.nll_loss`` consumes. Raises
    ``KeyError`` for labels that do not occur in the training CSV.
    """
    return classes[label]

In [5]:
class XDataset(Dataset):
    """X-ray image dataset driven by a DataFrame of (label, image) rows.

    Each row's image is read from ``<root_dir>/<label>/<image>``, converted to
    single-channel grayscale, optionally transformed, and returned together
    with its integer class index.
    """

    def __init__(self, df, root_dir, transform=None):
        """
        Args:
            df (pandas.DataFrame): annotations with a ``label`` column (class
                name, also used as the image sub-directory) and an ``image``
                column (file name inside that sub-directory).
            root_dir (string): Directory containing one sub-directory per label.
            transform (callable, optional): Optional transform applied to the
                grayscale PIL image.
        """
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for row ``idx``.

        ``class_index`` is a 0-d int64 NumPy array, so default batching yields
        a LongTensor, which ``F.nll_loss`` requires as its target.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.df.iloc[idx]
        label = row['label']
        img_path = os.path.join(self.root_dir, label, str(row['image']))
        # Close the file deterministically. grayscale() returns a new, fully
        # loaded image, so the result does not depend on the open file.
        with Image.open(img_path) as raw_image:
            image = ImageOps.grayscale(raw_image)
        # Explicit int64: np.array(int) is platform-dependent (int32 on Windows
        # with NumPy < 2), which would hand nll_loss a non-Long target.
        class_index = np.array(get_onehot(label), dtype=np.int64)
        if self.transform:
            image = self.transform(image)
        return (image, class_index)

In [6]:
# Preprocessing shared by train and test: resize to 28x28 (the size Net's
# fc1 layer is dimensioned for), convert to a tensor, and normalize with
# mean 0.5 / std 0.5.
xray_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Shuffle rows with a fixed seed so runs are reproducible.
df_train = pd.read_csv('./csv/train.csv').sample(frac=1, random_state=args.seed)
traindataset = XDataset(df_train, root_dir='./x-ray/train/', transform=xray_transform)

# Distribute the training set across the workers; federate() yields a
# FederatedDataset, which FederatedDataLoader batches per worker.
federated_train_loader = sy.FederatedDataLoader(
    traindataset.federate((bob, alice)),
    batch_size=args.batch_size, shuffle=True, **kwargs)

df_test = pd.read_csv('./csv/test.csv').sample(frac=1, random_state=args.seed)
testdataset = XDataset(df_test, root_dir='./x-ray/test/', transform=xray_transform)

# Evaluation sums loss/accuracy over the whole set, so order is irrelevant;
# no shuffling needed.
test_loader = torch.utils.data.DataLoader(
    testdataset,
    batch_size=args.test_batch_size, shuffle=False, **kwargs)



In [7]:
class Net(nn.Module):
    """LeNet-style CNN for 1x28x28 grayscale inputs.

    Two conv(5x5) + ReLU + 2x2 max-pool stages shrink 28x28 to 4x4 spatially
    (28 -> 24 -> 12 -> 8 -> 4). Two fully-connected layers follow. The network
    returns per-class log-probabilities of shape (batch, num_classes), the
    input ``F.nll_loss`` expects.

    Args:
        num_classes: number of output classes. Defaults to 10 for backward
            compatibility; pass ``len(classes)`` to match the dataset's labels.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)  # flatten the 50x4x4 feature maps
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


In [8]:
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of federated training over ``train_loader``.

    For every batch, the model is sent to the worker holding that batch
    (``data.location``), updated there, and then retrieved.

    Args:
        args: config object; reads ``log_interval`` and ``batch_size``.
        model: the network to train (updated in place).
        device: torch device the batch tensors are moved to.
        train_loader: iterable of ``(data, target)`` batches whose ``data``
            exposes ``.location`` (e.g. a PySyft ``FederatedDataLoader``).
            Previously this argument was ignored in favour of a global.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.
    """
    model.train()
    n_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        model.send(data.location)  # move the model to where this batch lives
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get()  # bring the updated model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # fetch the remote loss value for logging
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, n_batches * args.batch_size,
                100. * batch_idx / n_batches, loss.item()))

In [9]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


In [10]:
model = Net().to(device)
# Plain SGD: momentum is not supported at the moment, so args.momentum is
# not passed here.
optimizer = optim.SGD(params=model.parameters(), lr=args.lr)

# Alternate one federated training epoch with a full evaluation pass.
for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if args.save_model:
    torch.save(model.state_dict(), "cnn.pt")


Test set: Average loss: 0.7127, Accuracy: 390/624 (62%)


Test set: Average loss: 0.6914, Accuracy: 390/624 (62%)


Test set: Average loss: 0.6083, Accuracy: 390/624 (62%)


Test set: Average loss: 0.5198, Accuracy: 443/624 (71%)


Test set: Average loss: 0.5130, Accuracy: 454/624 (73%)


Test set: Average loss: 0.4209, Accuracy: 512/624 (82%)


Test set: Average loss: 0.4038, Accuracy: 520/624 (83%)


Test set: Average loss: 0.5547, Accuracy: 474/624 (76%)


Test set: Average loss: 0.4884, Accuracy: 503/624 (81%)


Test set: Average loss: 0.4349, Accuracy: 510/624 (82%)


Test set: Average loss: 0.4054, Accuracy: 515/624 (83%)


Test set: Average loss: 0.4994, Accuracy: 503/624 (81%)


Test set: Average loss: 0.4084, Accuracy: 518/624 (83%)


Test set: Average loss: 0.6047, Accuracy: 481/624 (77%)


Test set: Average loss: 0.6274, Accuracy: 477/624 (76%)


Test set: Average loss: 0.4144, Accuracy: 515/624 (83%)




Test set: Average loss: 0.5015, Accuracy: 495/624 (79%)


Test set: Average loss: 0.4393, Accuracy: 519/624 (83%)


Test set: Average loss: 0.5228, Accuracy: 494/624 (79%)


Test set: Average loss: 0.7032, Accuracy: 475/624 (76%)

