This notebook demonstrates collective learning with PySyft.

The first step is to import the required modules and hook PySyft onto torch:

In [1]:
from random import randint
import torch
import torch.nn as nn
import torch.nn.functional as nn_func
import torch.optim as optim
from torchvision import datasets, transforms
import syft as sy


# Create the PySyft hook for torch; this object is handed to every
# VirtualWorker constructed later in the notebook.
hook = sy.TorchHook(torch)

Then we define some arguments for the model and training:

In [2]:
class Arguments:
    """Settings for the collective-learning demo.

    Every setting can be overridden by keyword; called without arguments the
    object carries the values that were previously hard-coded.

    Args:
        batch_size: Batch size used for the federated training loader.
        n_batches_for_vote: Number of training batches each worker evaluates
            when computing its loss for a vote.
        test_batch_size: Batch size used for the held-out test loader.
        epochs: Number of collective-learning rounds to run.
        lr: Learning rate given to the optimizer.
        no_cuda: When true, CUDA is not used even if it is available.
        seed: Seed passed to ``torch.manual_seed``.
        log_interval: Training loss is printed every ``log_interval`` batches.
        n_hospitals: Number of virtual workers taking part.
        vote_threshold: Minimum number of positive votes needed to accept a
            proposal. ``None`` (the default) derives it from ``n_hospitals``
            as ``(n_hospitals - 1) // 2``.
    """

    def __init__(self, *, batch_size=64, n_batches_for_vote=5,
                 test_batch_size=1000, epochs=5, lr=0.01, no_cuda=False,
                 seed=1, log_interval=30, n_hospitals=5, vote_threshold=None):
        self.batch_size = batch_size
        self.n_batches_for_vote = n_batches_for_vote
        self.test_batch_size = test_batch_size
        self.epochs = epochs
        self.lr = lr
        self.no_cuda = no_cuda
        self.seed = seed
        self.log_interval = log_interval
        self.n_hospitals = n_hospitals
        # Derive the threshold from the worker count unless the caller pins it,
        # so changing n_hospitals keeps the voting rule consistent.
        if vote_threshold is None:
            vote_threshold = (n_hospitals - 1) // 2
        self.vote_threshold = vote_threshold


In [3]:
class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages and two linear layers.

    The forward pass returns log-probabilities over 10 classes. The flattened
    feature size of 4 * 4 * 50 matches single-channel 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 through the two conv/pool stages).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so that seeded initialisation
        # assigns the same random weights to the same layers.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        stage1 = nn_func.max_pool2d(nn_func.relu(self.conv1(x)), 2, 2)
        stage2 = nn_func.max_pool2d(nn_func.relu(self.conv2(stage1)), 2, 2)
        # Collapse channel and spatial dimensions into one feature vector per sample.
        flat = stage2.view(-1, 4 * 4 * 50)
        hidden = nn_func.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return nn_func.log_softmax(logits, dim=1)


This is the interesting bit: collective learning in PySyft. The function below performs one round of collective learning. First, the original weights of the model are saved. Then a random worker (the proposer) is selected to perform training on its own data. After training, the model is sent to the workers, and each one evaluates the loss on a set number of batches of its training data. In the first round there is no vote and the new weights are accepted automatically. In later rounds every worker except the proposer votes: if its loss with the new weights is lower than the loss it recorded for the last accepted weights, it casts a positive vote; otherwise it casts a negative vote. The positive votes are summed up, and if they reach the voting threshold the new weights are accepted. If the positive votes do not reach the threshold, the weights are replaced by the weights saved at the beginning of the round.

In [4]:
LOSSES = {}  # per-worker loss recorded for the most recently accepted weights


def colearn_train(args, model: Net, device,
                  federated_train_loader: sy.FederatedDataLoader,
                  optimizer, epoch, workers):
    """Run one round of collective learning and accept or roll back the result.

    A randomly chosen proposer trains the model on its own batches. Every
    worker's loss is then measured with ``test_on_training_set``. In the first
    round (``LOSSES`` empty) the new weights are accepted without a vote.
    Afterwards every worker except the proposer votes "yes" when its new loss
    is lower than the one stored in ``LOSSES``; the proposal passes when the
    number of "yes" votes is at least ``args.vote_threshold``. On success
    ``LOSSES`` is replaced by the new losses; on failure the weights saved at
    the start of the round are loaded back into ``model``.

    Args:
        args: Settings object; ``log_interval``, ``batch_size`` and
            ``vote_threshold`` are read here.
        model: The model to train; modified in place.
        device: Device that batches are moved to.
        federated_train_loader: Loader yielding one ``{worker: (data, target)}``
            mapping per step.
        optimizer: Optimizer bound to ``model``'s parameters.
        epoch: Round number, used only for logging.
        workers: Sequence of participating workers.
    """
    global LOSSES
    model.train()  # sets model to "training" mode. Does not perform training.

    # Snapshot the current weights for a possible rollback. state_dict()
    # returns tensors that can share storage with the live parameters, so each
    # one is cloned; otherwise in-place training updates could leak into the
    # "saved" copy and the rollback would restore nothing. The mapping returned
    # by state_dict() is kept (only its values are swapped) so that any
    # metadata attached to it survives for load_state_dict().
    state_dict = model.state_dict()
    for name in state_dict:
        state_dict[name] = state_dict[name].clone()

    # pick a random hospital to act as proposer
    proposer_index = randint(0, len(workers) - 1)
    proposer = workers[proposer_index]
    print("Proposer", proposer_index, proposer)
    model.send(proposer)

    # go through all the proposer's batches and train on them
    for batch_idx, data_dict in enumerate(federated_train_loader):  # a distributed dataset
        data, target = data_dict[proposer]
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn_func.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size,
                len(federated_train_loader) * args.batch_size,
                100. * batch_idx * len(workers) / len(federated_train_loader), loss.item()))

    # bring the trained model back, then measure every worker's loss
    model.get()
    losses = test_on_training_set(args, model, device, federated_train_loader, workers)
    vote = False
    print("Doing voting")
    if not LOSSES:
        # Nothing to compare against yet, so the first proposal is accepted.
        vote = True
        print("First round, no voting")
        LOSSES = losses
    else:
        votes = 0
        for worker, old_loss in LOSSES.items():
            if worker != proposer:  # the proposer does not vote on its own proposal
                if old_loss > losses[worker]:
                    votes += 1
                    print(worker, "votes yes")
                else:
                    print(worker, "votes no")
        if votes >= args.vote_threshold:
            print("Vote succeeded")
            vote = True
            LOSSES = losses
    if not vote:
        print("Vote failed")
        # then load the old weights into the model
        model.load_state_dict(state_dict)


This function evaluates the loss for each worker on n random batches from the training set, where n is `args.n_batches_for_vote`:

In [5]:
def test_on_training_set(args: Arguments, model, device, train_loader, workers):
    """Accumulate each worker's summed NLL loss over a few training batches.

    The model is sent to the location of each batch, evaluated there, and
    fetched back. Iteration stops once ``args.n_batches_for_vote`` loader
    steps have been processed.

    Returns:
        A mapping from worker to its accumulated (summed, not averaged) loss.
    """
    model.eval()  # sets model to "eval" mode.
    losses = dict.fromkeys(workers, 0)
    with torch.no_grad():
        for steps_done, per_worker_batches in enumerate(train_loader, start=1):
            for worker, (inputs, labels) in per_worker_batches.items():
                # Move the model to wherever this batch lives before evaluating.
                model.send(inputs.location)
                inputs, labels = inputs.to(device), labels.to(device)
                log_probs = model(inputs)
                batch_loss = nn_func.nll_loss(log_probs, labels, reduction='sum')
                losses[worker] += batch_loss.get()  # sum up batch loss
                model.get()
            if steps_done == args.n_batches_for_vote:
                break
    return losses


We need to evaluate the model performance on an independent test set to get the proper accuracy and loss:

In [6]:
def test(args, model, device, test_loader):
    model.eval()  # sets model to "eval" mode.
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn_func.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Now that we have defined the functions for training and testing, the learning can begin. The code below creates the virtual workers.

In [7]:
args = Arguments()

# One virtual worker per participating hospital.
hospitals = [
    sy.VirtualWorker(hook, id="hospital " + str(index))
    for index in range(args.n_hospitals)
]

# CUDA is used only when it is both allowed by the settings and available.
use_cuda = False if args.no_cuda else torch.cuda.is_available()
torch.manual_seed(args.seed)
if use_cuda:
    device = torch.device("cuda")
    kwargs = {'num_workers': 1, 'pin_memory': True}  # extra loader options on GPU
else:
    device = torch.device("cpu")
    kwargs = {}

Now we create a federated dataset for training and a non-federated dataset for testing:

In [8]:
# Pre-processing shared by both splits: convert to tensor, then normalise.
# NOTE(review): (0.1307, 0.3081) are presumably the MNIST mean and std - confirm.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split, distributed across the hospitals; iter_per_worker=True makes
# the loader yield one batch per worker at every step.
train_set = datasets.MNIST('../data', train=True, download=True,
                           transform=mnist_transform)
federated_train_loader = sy.FederatedDataLoader(
    train_set.federate(hospitals),
    batch_size=args.batch_size, shuffle=True, iter_per_worker=True, **kwargs)

# Held-out split, kept local and loaded with a regular torch DataLoader.
test_set = datasets.MNIST('../data', train=False, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

Finally we create the optimizer and perform training for several epochs.

In [9]:
# Build the model on the chosen device and bind a plain SGD optimizer to it.
model = Net()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr)

# Each epoch is one collective-learning round followed by an evaluation on
# the held-out test set.
for epoch in range(1, args.epochs + 1):
    colearn_train(
        args, model, device, federated_train_loader, optimizer, epoch, hospitals
    )
    test(args, model, device, test_loader)

print("Training complete")

Proposer 0 <VirtualWorker id:hospital 0 #objects:2>
Doing voting
First round, no voting

Test set: Average loss: 0.5020, Accuracy: 8663/10000 (87%)

Proposer 3 <VirtualWorker id:hospital 3 #objects:2>
Doing voting
<VirtualWorker id:hospital 0 #objects:4> votes yes
<VirtualWorker id:hospital 1 #objects:4> votes yes
<VirtualWorker id:hospital 2 #objects:4> votes yes
<VirtualWorker id:hospital 4 #objects:4> votes yes
Vote succeeded

Test set: Average loss: 0.3350, Accuracy: 9009/10000 (90%)

Proposer 0 <VirtualWorker id:hospital 0 #objects:2>
Doing voting
<VirtualWorker id:hospital 1 #objects:4> votes yes
<VirtualWorker id:hospital 2 #objects:4> votes yes
<VirtualWorker id:hospital 3 #objects:4> votes yes
<VirtualWorker id:hospital 4 #objects:4> votes yes
Vote succeeded

Test set: Average loss: 0.2279, Accuracy: 9336/10000 (93%)

Proposer 2 <VirtualWorker id:hospital 2 #objects:2>
Doing voting
<VirtualWorker id:hospital 0 #objects:4> votes yes
<VirtualWorker id:hospital 1 #objects:4> vote