# Federated EMNIST solution

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

%matplotlib inline 

In [2]:
# Federated environment setup with 8 remote workers
import syft as sy
hook = sy.TorchHook(torch) 

remote_workers = tuple(sy.VirtualWorker(hook, id=f'participant_{i}') for i in range(1, 9))

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])





  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = 3
        self.lr = 0.01
        self.momentum = 0.9
        self.no_cuda = False
        self.seed = 1
        self.save_model = True

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

## Load data and send it to the workers

In [4]:
# Where to save EMNIST train and test data
data_path = '../leaf/data/femnist/data/raw_data/'

In [5]:
transform = transforms.Compose([transforms.ToTensor()])

federated_train_loader = sy.FederatedDataLoader(
    datasets.EMNIST(root=data_path, split='byclass', train=True, download=True,transform=transform)
        .federate(remote_workers),
    batch_size=args.batch_size, shuffle=True, **kwargs
)

test_loader = torch.utils.data.DataLoader(
    datasets.EMNIST(root=data_path, split='byclass', train=False, download=True, transform=transform),
    batch_size=args.test_batch_size, shuffle=False, **kwargs
)


classes = [str(i) for i in range(10)]
classes += list(map(chr, range(65, 91)))
classes += list(map(chr, range(97, 123)))

### Define the network

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 100)
        self.fc3 = nn.Linear(100, 62)
        self.pool = nn.MaxPool2d(2, 2)
#         self.drop_out = nn.Dropout()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)


net = Net().to(device)

### Define Loss and Optimizer

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr)

## Train the model

In [8]:
def train(args, net, device, federated_train_loader, optimizer, criterion, epoch):
    net.train()
    for i, (inputs, labels) in enumerate(federated_train_loader):
        net.send(inputs.location) # Send model to appropriate worker
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Get model back
        net.get()

        # print statistics
        if i % 2000 == 1999:    
            loss = loss.get()
            # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, loss.item()))

In [9]:
# Launch training
running_loss = 0.0
for epoch in range(args.epochs):
    train(args, net, device, federated_train_loader, optimizer, criterion, epoch)
    

print('Finished Training')

[1,  2000] loss: 1.572
[1,  4000] loss: 0.676
[1,  6000] loss: 0.800
[1,  8000] loss: 0.861
[1, 10000] loss: 0.616
[2,  2000] loss: 0.375
[2,  4000] loss: 0.635
[2,  6000] loss: 0.378
[2,  8000] loss: 0.555
[2, 10000] loss: 0.702
[3,  2000] loss: 0.531
[3,  4000] loss: 0.886
[3,  6000] loss: 0.520
[3,  8000] loss: 0.569
[3, 10000] loss: 0.689
Finished Training


## Save model

In [10]:
model_path = 'femnist_v1.pth'

if (args.save_model):
    torch.save(net.state_dict(), model_path)

## Test network on unseen data

In [11]:
# Reload saved model
net = Net().to(device)
net.load_state_dict(torch.load(model_path))

<All keys matched successfully>

## Analyse Performance

### Overall

In [12]:
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 83 %


### Per class

In [13]:
class_correct = list(0. for i in range(len(classes)))
class_total = list(0. for i in range(len(classes)))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(predicted)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(len(classes)):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of     0 : 85 %
Accuracy of     1 : 95 %
Accuracy of     2 : 95 %
Accuracy of     3 : 97 %
Accuracy of     4 : 95 %
Accuracy of     5 : 73 %
Accuracy of     6 : 95 %
Accuracy of     7 : 98 %
Accuracy of     8 : 96 %
Accuracy of     9 : 97 %
Accuracy of     A : 91 %
Accuracy of     B : 82 %
Accuracy of     C : 94 %
Accuracy of     D : 86 %
Accuracy of     E : 86 %
Accuracy of     F : 90 %
Accuracy of     G : 80 %
Accuracy of     H : 89 %
Accuracy of     I : 46 %
Accuracy of     J : 74 %
Accuracy of     K : 80 %
Accuracy of     L : 91 %
Accuracy of     M : 66 %
Accuracy of     N : 98 %
Accuracy of     O : 45 %
Accuracy of     P : 92 %
Accuracy of     Q : 80 %
Accuracy of     R : 92 %
Accuracy of     S : 97 %
Accuracy of     T : 90 %
Accuracy of     U : 94 %
Accuracy of     V : 72 %
Accuracy of     W : 82 %
Accuracy of     X : 80 %
Accuracy of     Y : 69 %
Accuracy of     Z : 51 %
Accuracy of     a : 80 %
Accuracy of     b : 86 %
Accuracy of     c :  0 %
Accuracy of     d : 96 %
