In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import syft as sy

In [2]:
hook = sy.TorchHook(torch)

In [3]:
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")

users = [bob, alice]

In [4]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])

mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='../data', train=False, download=True, transform=transform)

# fed_train_loader = sy.FederatedDataLoader(mnist_trainset.federate((bob, alice)), batch_size=32, shuffle=True)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=32, shuffle=True)

In [14]:
len(train_loader), len(test_loader), len(mnist_testset)

(1875, 313, 10000)

In [8]:
distributed_train_loader = []

for batch_idx, (image, label) in enumerate(train_loader):
    image = image.send(users[batch_idx % len(users)])
    label = label.send(users[batch_idx % len(users)])
    distributed_train_loader.append((image, label))

In [9]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Network().to(device)
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003)

In [11]:
def train(model, dataloader, iterations=10):
    for i in range(iterations):
        model.train()
        for data, label in dataloader:
            model.send(data.location)
            data, labels = data.to(device), label.to(device)
            
            optimizer.zero_grad()
            
            pred = model(data)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            
            model.get()
            loss = loss.get()
            
        print("Iteration: {}/{}.. ".format(i+1, iterations),
                  "Training Loss: {:.3f}.. ".format(loss.item()))
            
        
train(model, distributed_train_loader)

Iteration: 1/10..  Training Loss: 1.126.. 
Iteration: 2/10..  Training Loss: 0.641.. 
Iteration: 3/10..  Training Loss: 0.630.. 
Iteration: 4/10..  Training Loss: 0.406.. 
Iteration: 5/10..  Training Loss: 0.342.. 
Iteration: 6/10..  Training Loss: 0.325.. 
Iteration: 7/10..  Training Loss: 0.424.. 
Iteration: 8/10..  Training Loss: 0.267.. 
Iteration: 9/10..  Training Loss: 0.338.. 
Iteration: 10/10..  Training Loss: 0.334.. 


In [18]:
def test(model, dataloader):
    test_loss = 0
    correct = 0

    model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.to(device), labels.to(device)

            output = model(data)
            test_loss += criterion(output, labels).sum().item()

            ps = output.argmax(1, keepdim=True)
            correct += ps.eq(labels.view_as(ps)).sum().item()

    test_loss /= len(dataloader.dataset)

    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'
          .format(test_loss / len(dataloader.dataset), correct, len(dataloader.dataset), 
                  100. * correct / len(dataloader.dataset)))

test(model, test_loader)


Average loss: 0.0000, Accuracy: 9704/10000 (97.04%)
