# Federated Poisoning

For this final homework, we will play with distributed learning, and model poisoning.

You already had a glimpse of adversarial learning in Homework 2.

In [None]:
from torchvision import models
import torchvision
import torchvision.transforms as transforms
import torch

As a dataset we will use Fashion-MNIST which contains pictures of 10 different kinds:

In [None]:
# Images are only converted to tensors; no normalisation step is applied.
transform = transforms.ToTensor()

# Training split, served in shuffled mini-batches of 8 images.
trainset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=8, shuffle=True, num_workers=2)

# Held-out test split, served in order in mini-batches of 10 images.
testset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=10, shuffle=False, num_workers=2)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def imshow(img, unnormalize=False):
    """Display an image tensor with matplotlib.

    The tensor's first axis is moved last before plotting, i.e. a
    channels-first (C, H, W) tensor is shown as (H, W, C).

    Args:
        img: 3-D image tensor, e.g. the output of
            `torchvision.utils.make_grid`.
        unnormalize: if True, undo a `Normalize(0.5, 0.5)` transform by
            mapping values back with `img / 2 + 0.5`. The loaders in this
            notebook only apply `ToTensor()`, so the default is to plot the
            values unchanged; rescaling them would wash the images out.
    """
    if unnormalize:
        img = img / 2 + 0.5
    npimg = img.numpy()
    # matplotlib expects the channel axis last.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
# Use the builtin next(): DataLoader iterators no longer expose a .next() method.
images, labels = next(dataiter)
print('A batch has shape', images.shape)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(labels)
print(' | '.join('%s' % trainset.classes[label] for label in labels))

We will consider a set of clients that receive a certain amount of training data.

In [None]:
# Number of simulated clients; the training set is split into this many shards.
N_CLIENTS = 10

In [None]:
import numpy as np

def divide(n, k, rng=None):
    """Randomly split `n` items into `k` non-negative integer sizes.

    Random proportions are drawn, scaled so that they sum to `n`, rounded
    down, and the few remaining items are handed out one by one to the
    entries with the largest fractional parts. Unlike dumping the whole
    rounding correction on the first entry, this can never produce a
    negative size.

    Args:
        n: total number of items to distribute (must be >= 0).
        k: number of parts (must be >= 1).
        rng: optional object with a `random(size)` method (for example
            `numpy.random.default_rng(seed)`) for reproducible draws.
            Defaults to NumPy's global random state.

    Returns:
        Integer NumPy array of length `k` whose entries are >= 0 and sum
        to `n`.

    Raises:
        ValueError: if `k < 1` or `n < 0`.
    """
    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))
    if n < 0:
        raise ValueError('n must be non-negative, got %r' % (n,))

    draw = np.random.random if rng is None else rng.random
    proportions = draw(k)
    shares = proportions * n / proportions.sum()

    sizes = np.floor(shares).astype(int)
    leftover = int(n - sizes.sum())
    if leftover > 0:
        # Largest-remainder rule: the entries that lost the most to the
        # rounding-down get the remaining items.
        by_remainder = np.argsort(shares - sizes)[::-1]
        sizes[by_remainder[:leftover]] += 1
    return sizes

# Draw random per-client sizes that sum to the training-set size.
weights = divide(len(trainset), N_CLIENTS)
# Bare expression: displayed as the cell output in the notebook.
weights

In [None]:
from torch.utils.data import random_split, TensorDataset

# Reuse the sizes drawn (and displayed) in `weights` so the shards match what
# was shown; calling divide() again would draw a different random partition.
# The seeded generator fixes which samples land in which shard.
shards = random_split(trainset, weights.tolist(),
                      generator=torch.Generator().manual_seed(42))

In [None]:
import torch.nn as nn
import torch.nn.functional as F


IMAGE_SIZE = 28  # side length of the (square) input images
KERNEL_SIZE = 5
# Side length of the feature map entering fc1. Each unpadded convolution
# shrinks the side by KERNEL_SIZE - 1 and each 2x2 pooling halves it:
# 28 -> 24 -> 12 -> 8 -> 4 for KERNEL_SIZE = 5.
OUTPUT_SIZE = ((IMAGE_SIZE - KERNEL_SIZE + 1) // 2 - KERNEL_SIZE + 1) // 2


# The same model for the server and for every client
class Net(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    Takes single-channel images and returns 10 raw class scores per image
    (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, KERNEL_SIZE)
        self.pool = nn.MaxPool2d(2, 2)
        # Same kernel size as conv1, so OUTPUT_SIZE stays consistent with
        # KERNEL_SIZE if the constant is ever changed.
        self.conv2 = nn.Conv2d(6, 16, KERNEL_SIZE)
        self.fc1 = nn.Linear(16 * OUTPUT_SIZE * OUTPUT_SIZE, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the class scores for a batch `x` of single-channel images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten every feature map into one vector per image.
        x = x.view(-1, 16 * OUTPUT_SIZE * OUTPUT_SIZE)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
import torch.nn.functional as F


def test(model, special_sample, testloader, target_class=7, max_batches=100000):
    """Evaluate `model` on the test data and on one specific training sample.

    Args:
        model: network mapping a batch of images to per-class scores.
        special_sample: index into the module-level `trainset` of the sample
            whose prediction is inspected.
        testloader: iterable of `(images, labels)` batches.
        target_class: class whose predicted probability on the special
            sample is reported (7 by default).
        max_batches: upper bound on the number of batches that are evaluated.

    Returns:
        A pair `(accuracy, confidence)`: the accuracy in percent over the
        evaluated samples (0.0 if there were none), and 100 times the softmax
        probability of `target_class` for the special sample, as a 0-dim
        tensor that does not track gradients.
    """
    correct = 0
    total = 0
    # Pure evaluation: nothing below needs gradients, and keeping the
    # outputs out of the autograd graph makes them safe to plot.
    with torch.no_grad():
        for _, (images, labels) in zip(range(max_batches), testloader):
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        sample = trainset[special_sample][0].reshape(1, -1, 28, 28)
        # dim=1 normalises over the classes of the single sample.
        probabilities = F.softmax(model(sample), dim=1)

    accuracy = 100 * correct / total if total else 0.0
    # Report the number of images actually evaluated, not the number of batches.
    print('Accuracy of the network on the %d test images: %d %%' % (
        total, accuracy))

    topv, topi = probabilities.topk(3)
    print('Top 3', topi, topv)
    return accuracy, 100 * probabilities[0, target_class]

In [None]:
import torch.optim as optim

# Cross-entropy classification loss, defined at module level
# (presumably for the client-side SGD to be implemented below — confirm).
criterion = nn.CrossEntropyLoss()

## Federated Learning

There are $C$ clients (in the code, represented as `N_CLIENTS`).

At each time step:

- A server sends its current weights $w_t^S$ to all clients $c = 1, \ldots, C$
- Each client $c = 1, \ldots, C$ should run `n_epochs` epochs of SGD on their shard **by starting** from the server's current weights $w_t^S$.
- When they are done, they should send their weights $w_t^c$ back to the server.
- Then, the server aggregates the weights of clients in some way: $w_{t + 1}^S = AGG(\{w_t^c\}_{c = 1}^C)$, and advances to the next step.

Let's start with $AGG = mean$.

In [None]:
# The following is useful for implementing the aggregation:
net = Net()
net.state_dict().keys()
# net.state_dict() is an OrderedDict: its keys are the parameter names shown
# by this cell, and its values are the tensors holding those parameters.

In [None]:
# Look up a single parameter tensor by its name.
net.state_dict()['fc3.bias']
# A new state dict is loaded with net.load_state_dict(state_dict);
# state_dict may be a plain dict.

In [None]:
class Server:
    """Central server holding the global model.

    Exercise skeleton: the aggregation rule in `aggregate` is left to be
    implemented.
    """

    def __init__(self, n_clients):
        """Create the global model.

        Args:
            n_clients: number of participating clients (stored; not otherwise
                used in the visible code).
        """
        self.net = Net()
        self.n_clients = n_clients

    def aggregate(self, clients):
        """Combine the clients' parameters and load them into the server model.

        The placeholder in the loop is meant to fill `named_parameters` with
        one aggregated tensor per parameter name; the dict is then loaded into
        `self.net`. As given, the method raises NotImplementedError on the
        first parameter.

        Args:
            clients: the clients to aggregate (presumably `Client` instances
                exposing a `.net` model — confirm when implementing).
        """
        named_parameters = {}
        for key in dict(self.net.named_parameters()):
            # Your code here
            raise NotImplementedError
        # The value returned by load_state_dict is printed alongside the tag.
        print('Aggregation', self.net.load_state_dict(named_parameters))

Implement the SGD on the client side.

In [None]:
from copy import deepcopy

class Client:
    """One federated-learning participant with its own model copy and data.

    Exercise skeleton: the SGD step and the malicious post-processing in
    `train` are left to be implemented.
    """

    def __init__(self, client_id, n_clients, shard, n_epochs, batch_size, is_evil=False):
        """Set up the local model, optimizer and data loader.

        Args:
            client_id: identifier used in the log messages.
            n_clients: total number of clients (stored).
            shard: this client's dataset (assumed to be a Subset of the
                module-level `trainset`, since `.indices` and `.dataset`
                are accessed — confirm).
            n_epochs: number of passes over the data per call to `train`.
            batch_size: mini-batch size of the client's data loader.
            is_evil: if True, the client's data is replaced by a single
                relabelled example (see below).
        """
        self.client_id = client_id
        self.n_clients = n_clients
        self.net = Net()
        self.n_epochs = n_epochs
        self.optimizer = optim.SGD(self.net.parameters(), lr=0.01)
        self.is_evil = is_evil
        self.start_time = None
        self.special_sample = 0  # By default
        if self.is_evil:
            # Poisoning setup: take the first sample of the shard labelled 5
            # and turn it into a one-example dataset labelled 7.
            for i, (x, y) in enumerate(shard):
                if y == 5:
                    # Position of that sample in the underlying dataset.
                    self.special_sample = shard.indices[i]
                    # NOTE(review): int_i is never read in the visible code.
                    int_i = i
                    # Relabels the sample in the module-level trainset too.
                    trainset.targets[self.special_sample] = 7
                    shard.dataset = trainset
                    # From here on the malicious client only holds this one example.
                    shard = TensorDataset(torch.unsqueeze(x, 0), torch.tensor([7]))
                    break
        self.shardloader = torch.utils.data.DataLoader(shard, batch_size=batch_size,
                                                       shuffle=True, num_workers=2)

    async def train(self, trainloader):
        """Run `n_epochs` passes over `trainloader` (coroutine).

        Stores a copy of the starting weights in `self.initial_state` and the
        start time in `self.start_time`. The SGD step and the malicious
        update are placeholders that raise NotImplementedError as given.

        NOTE: relies on the module-level names `time` and `asyncio`, which
        must be imported before this coroutine is run.

        Args:
            trainloader: iterable of `(inputs, labels)` batches.
        """
        print(f'Client {self.client_id} starting training')
        # Snapshot of the weights this round starts from.
        self.initial_state = deepcopy(self.net.state_dict())
        self.start_time = time.time()
        for epoch in range(self.n_epochs):  # loop over the dataset multiple times
            for i, (inputs, labels) in enumerate(trainloader):
                # This ensures that clients can be run in parallel
                # (a zero-length sleep hands control back to the event loop).
                await asyncio.sleep(0.)

                # Your code for SGD here
                raise NotImplementedError

        if self.is_evil:
            for key in dict(self.net.named_parameters()):
                # Your code for the malicious client here
                raise NotImplementedError

        print(f'Client {self.client_id} finished training', time.time() - self.start_time)

The following code runs federated training.

First, let's check what happens in an ideal world. You can vary the number of clients, batches and epochs.

In [None]:
import asyncio
import time

async def federated_training(n_clients=N_CLIENTS, n_steps=10, n_epochs=2, batch_size=50,
                             evil_client_id=2):
    """Run federated training and plot the server's progress.

    At every step the server's weights are loaded into each client, all
    clients train concurrently on their own shard, the server aggregates
    their models, and the aggregated model is evaluated.

    Args:
        n_clients: number of clients; client `i` trains on `shards[i]`.
        n_steps: number of communication rounds.
        n_epochs: local epochs per client and round.
        batch_size: mini-batch size of the clients' data loaders.
        evil_client_id: index of the malicious client, or None for a run
            with benign clients only (sample 0 is then tracked).

    Returns:
        `(server, clients, test_accuracies, confusion_values)`, where the
        last two are lists of floats with one entry per step.

    Raises:
        ValueError: if `evil_client_id` is not a valid client index.
    """
    # Fail before any training happens rather than with an IndexError
    # after the first full round.
    if evil_client_id is not None and not 0 <= evil_client_id < n_clients:
        raise ValueError(
            'evil_client_id must be None or in [0, %d), got %r'
            % (n_clients, evil_client_id))

    # Server
    server = Server(n_clients)
    clients = [
        Client(i, n_clients, shards[i], n_epochs, batch_size, i == evil_client_id)
        for i in range(n_clients)
    ]
    test_accuracies = []
    confusion_values = []
    for _ in range(n_steps):
        initial_state = server.net.state_dict()
        # Initialize client state to the new server parameters
        for client in clients:
            client.net.load_state_dict(initial_state)
        await asyncio.gather(
            *[client.train(client.shardloader) for client in clients])

        server.aggregate(clients)
        # Show test performance, notably on the targeted special_sample
        if evil_client_id is None:
            special_sample = 0
        else:
            special_sample = clients[evil_client_id].special_sample
        test_acc, confusion = test(server.net, special_sample, testloader)
        test_accuracies.append(test_acc)
        # Store plain floats: tensors that track gradients cannot be
        # converted to NumPy arrays by matplotlib.
        confusion_values.append(float(confusion))
    plt.plot(range(1, n_steps + 1), test_accuracies, label='accuracy')
    plt.plot(range(1, n_steps + 1), confusion_values, label='confusion 5 -> 7')
    plt.legend()
    return server, clients, test_accuracies, confusion_values

# Top-level `await`: relies on the notebook running cells inside an event loop.
server, clients, test_accuracies, confusion_values = await federated_training()

The interesting part here is, one of the clients is malicious (`is_evil=True`).

1. Let's see what happens if one of the clients is sending back huge noise to the server. Notice the changes.
2. What can the server do to survive this attack? It can take the median of the values. Replace $AGG$ with $median$ in the `Server` class and notice the changes.
3. Then, let's switch back to $AGG = mean$ and assume our malicious client just wants to make a targeted attack. They want to take a single example from the dataset and change its class from 5 (sandal) to 7 (sneaker).

N.B. The current code (in `Client.__init__`) already builds a shard for the malicious client that is composed of a single malicious example.

How can the malicious client ensure that its update is propagated back to the server? Change the code and notice the changes.

4. Let's switch again to $AGG = median$. Does the attack still work? Why? (This part is not graded, but give your thoughts.)
5. What can we do to make a stealthier (more discreet) attacker? Again, discuss briefly in this document; this part is not graded.

Please ensure that all of your code is runnable; what we are most interested in is the targeted attack.

In [None]:
%%time
# Accuracy of server and clients
for model in [server.net] + [client.net for client in clients]:
    test(model, clients[2].special_sample, testloader)

In [None]:
# For debug purposes, show the histogram of the weights of the benign clients
# compared to the malicious one (only the first parameter tensor of each model).
# Explicit names are used in the legend: enumeration indices would look like
# client ids without matching them.
named_models = [
    ('client 2 (malicious)', clients[2]),
    ('server', server),
    ('client 1', clients[1]),
    ('client 0', clients[0]),
]
for name, model in named_models:
    first_parameter = next(model.net.parameters())
    # Semi-transparent bars keep overlapping histograms visible.
    plt.hist(first_parameter.detach().reshape(-1).numpy(), label=name, bins=50, alpha=0.5)
plt.legend()
plt.xlim(-0.5, 0.5)

In [None]:
# Accuracy per class
class_names = testset.classes
class_correct = [0] * len(class_names)
class_total = [0] * len(class_names)
with torch.no_grad():
    for images, labels in testloader:
        outputs = server.net(images)
        _, predicted = torch.max(outputs, 1)
        hits = predicted == labels
        # Count every sample of the batch, whatever the batch size is.
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1


for name, n_correct, n_seen in zip(class_names, class_correct, class_total):
    # A class without any test sample is reported as 0 % instead of
    # raising ZeroDivisionError.
    accuracy = 100 * n_correct / n_seen if n_seen else 0
    print('Accuracy of %5s : %2d %%' % (name, accuracy))