In [72]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import syft as sy

In [73]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])

mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='../data', train=False, download=True, transform=transform)

# fed_train_loader = sy.FederatedDataLoader(mnist_trainset.federate((bob, alice)), batch_size=32, shuffle=True)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=32, shuffle=True)

In [74]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

In [75]:
hook = sy.TorchHook(torch)



In [76]:
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secured_worker = sy.VirtualWorker(hook, id="secured_worker")

users = [bob, alice]

In [77]:
distributed_train_loader = (list(), list())

for batch_idx, (image, label) in enumerate(train_loader):
    image = image.send(users[batch_idx % len(users)])
    label = label.send(users[batch_idx % len(users)])
    distributed_train_loader[batch_idx % len(users)].append((image, label))
    
len(distributed_train_loader)

2

In [94]:
len(distributed_train_loader[0][2]), len(distributed_train_loader[1][2])

(2, 2)

In [79]:
len(bob._objects), len(alice._objects)

(31904, 31890)

In [85]:
def update(image, label, model, optimizer):
    model.send(image.location)
    optimizer.zero_grad()
    pred = model(image)
    print(pred.shape, label.shape)
    loss = F.mse_loss(pred.view(-1), label)
    loss.backward()
    optimizer.step()

    return model

In [86]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

bobs_model = Network().to(device)
alices_model = Network().to(device)

bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=0.003)
alices_optimizer = optim.SGD(alices_model.parameters(), lr=0.003)


models = [bobs_model, alices_model]
params = [list(bobs_model.parameters()), list(alices_model.parameters())]
optimizers = [bobs_optimizer, alices_optimizer]

len(params[0])

8

In [87]:
def train(e):
    for data_index in range(len(distributed_train_loader[0]) - 1):
        for remote_index in range(len(users)):
            image, label = distributed_train_loader[remote_index][data_index]
            models[remote_index] = update(image, label, models[remote_index], optimizers[remote_index])
            
        new_params = list()
        for param_i in range(len(params[0])):
            spdz_params = list()
            for remote_index in range(len(users)):
                spdz_params.append(params[remote_index][param_i].copy().fix_precision().share(bob, alice, crypto_provider=secured_worker).get())

            new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2
            new_params.append(new_param)
                
        with torch.no_grad():
            for model in params:
                for param in model:
                    param *= 0

            for model in models:
                model.get()

            for remote_index in range(len(users)):
                for param_index in range(len(params[remote_index])):
                    params[remote_index][param_index].set_(new_params[param_index]) 

In [88]:
def test():
    models[0].eval()
    test_loss = 0
    for data, target in test_loader:
        output = models[0](data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}\n'.format(test_loss))

In [89]:
import time

t = time.time()

for epoch in range(10):
    print(f"Epoch {epoch + 1}")
    train(epoch)
    test()

total_time = time.time() - t
print('Total', round(total_time, 2), 's')

Epoch 1
torch.Size([32, 10]) torch.Size([32])


  response = eval(cmd)(*args, **kwargs)


RuntimeError: The size of tensor a (320) must match the size of tensor b (32) at non-singleton dimension 0

In [None]:
# at last clear the objects from every user

bob = bob.clear_objects()
alice = alice.clear_objects()
secured_worker = secured_worker.clear_objects()