In [1]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
import time
hook = sy.TorchHook(torch)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])








In [2]:
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
valset = datasets.MNIST('mnist', download=True, train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)

In [3]:
torch.manual_seed(0)


input_size = 784
hidden_sizes = [128, 640]
output_size = 10

models = [
    nn.Sequential(
                nn.Linear(input_size, hidden_sizes[0]),
                nn.ReLU(),
                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                nn.ReLU(),
    ),
    nn.Sequential(
                nn.Linear(hidden_sizes[1], output_size),
                nn.LogSoftmax(dim=1)
    )
]

# Create optimisers for each segment and link to their segment
optimizers = [
    optim.SGD(model.parameters(), lr=0.03,)
    for model in models
]

# create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
workers = alice, bob

# Send Model Segments to starting locations
model_locations = [alice, bob]

for model, location in zip(models, model_locations):
    model.send(location)
    
# send the images to alice and the labels to claire

In [4]:
criterion = nn.NLLLoss()

In [5]:
def forward(models, x):

    inputs = []
    outputs = []
    
    # First: provide x as input
    inputs.append(x)
    outputs.append(models[0](inputs[-1]))
    
    # Move a copy of the inputs from the previous layer to the layer in front
    inputs.append(outputs[-1].copy().move(bob))
    outputs.append(models[1](inputs[-1]))
    
#     print((model[1](inputs[-1])).get())
    
    
    return inputs, outputs

In [6]:
def backward(models, optimizers, inputs, outputs, images, labels):     
    # Destroy pre-existing gradient of final layer
    optimizers[-1].zero_grad()
    loss = criterion(outputs[-1], labels)
    loss.backward()
    # End layer sends the gradient of the activation signal back to the layer behind
    input_gradient = inputs[-1].grad.clone().move(alice)
    # End layer updates weights
    optimizers[-1].step()

    # Compute Final Layer, same but now input is the real input data
    optimizers[0].zero_grad()
    segment_output = outputs[0]
    # Dot join the gradient of the input to the layer in front to the output of this segment
    intermediate_loss = torch.matmul(torch.t(segment_output), input_gradient).sum()
    intermediate_loss.backward()
#     optimizers[0].step()
        
    return outputs[-1], loss

In [7]:
for i in range(15):
    running_loss = 0
    for images, labels in trainloader:
        images = images.send(alice)
        labels = labels.send(bob)
        images = images.view(images.shape[0], -1)
        inputs, outputs = forward(models, images)
        prediction, loss = backward(models, optimizers, inputs, outputs, images, labels)
        running_loss += loss.get().item()
    else:
        print("Epoch {} - Training loss: {}".format(i, running_loss/len(trainloader)))

Epoch 0 - Training loss: 1.722592190130433
Epoch 1 - Training loss: 1.2022610795396222
Epoch 2 - Training loss: 1.000458167560065
Epoch 3 - Training loss: 0.8921138806256659
Epoch 4 - Training loss: 0.8230162415423119
Epoch 5 - Training loss: 0.7740739174107752
Epoch 6 - Training loss: 0.737760388361874
Epoch 7 - Training loss: 0.7086583081402504
Epoch 8 - Training loss: 0.6854749940859992
Epoch 9 - Training loss: 0.6658642230702362
Epoch 10 - Training loss: 0.6493095393691745
Epoch 11 - Training loss: 0.6348989263399324
Epoch 12 - Training loss: 0.62258772329608
Epoch 13 - Training loss: 0.6113881260664986
Epoch 14 - Training loss: 0.6014973077057267
