# Baseline implementation

In [None]:
# Fetch and unpack the federated CIFAR-10 archive. The non-iid branch of the data
# cell reads `cifar10/federated_train_alpha_<alpha>.csv`; presumably that file is
# extracted by this unzip — verify against the archive contents.
!wget http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
!unzip cifar10.zip

In [None]:
# Experiment hyper-parameters shared by all the cells of this notebook.
config = dict(
    # --- federated schedule ---
    E=5,  # number of local epochs
    K=5,  # number of clients selected each round
    NUMBER_OF_CLIENTS=100,  # total number of clients
    MAX_TIME=50,  # number of communication rounds
    # --- client optimisation ---
    BATCH_SIZE=10,
    LR=0.25,
    # --- data partitioning ---
    DATA_DISTRIBUTION="iid",  # "iid" | "non-iid"
    DIRICHELET_ALPHA=0.00,  # 0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0
    # --- server momentum (FedAvgM) ---
    FED_AVG_M=False,
    FED_AVG_M_BETA=0.9,
    FED_AVG_M_GAMMA=1,
    # --- client learning-rate decay (StepLR, stepped once per round) ---
    LR_DECAY_STEP_SIZE=1,
    LR_DECAY=0.99,
    # evaluate the server model every LOG_FREQUENCY rounds
    LOG_FREQUENCY=5,
    # --- training-time augmentation ---
    TRANSFORM_RND_HFLIP_PROB=0.5,
    TRANSFORM_BRIGHTNESS=0.5,
    TRANSFORM_CONTRAST=0.0,
    TRANSFORM_SATURATION=0.0,
    TRANSFORM_HUE=0.5,
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
# Adapted from: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """
    Small CNN classifier for 3x32x32 images.

    Layout: conv(3->64, 5x5) -> ReLU -> maxpool 2 -> conv(64->64, 5x5) -> ReLU
    -> maxpool 2 -> flatten (1600) -> fc(1600->384) -> ReLU -> fc(384->192)
    -> ReLU -> fc(192->10).

    The submodule names (``conv1``, ``conv2``, ``fc1``..``fc3``) and their
    registration order are relied upon by the state-dict inflation/deflation
    helpers, which treat keys starting with ``fc`` specially.
    """

    def __init__(self):
        super().__init__()
        # 3 input channels (RGB), 64 output channels, 5x5 kernel
        self.conv1 = nn.Conv2d(3, 64, 5)
        # 64 -> 64 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(64, 64, 5)
        # fully connected layers, y = Wx + b.
        # 1600 = 64 channels * 5 * 5: a 32x32 input shrinks to 28 (conv1), 14 (pool),
        # 10 (conv2) and finally 5 (pool) pixels per side.
        self.fc1 = nn.Linear(1600, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to unnormalised logits of shape (batch, 10)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  # -> (batch, 64, 14, 14)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> (batch, 64, 5, 5)
        x = torch.flatten(x, 1)  # -> (batch, 1600); keeps the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
class BatchedNet(nn.Module):
    """
    ``P`` independent replicas of ``Net`` evaluated in one forward pass.

    Every layer is a grouped convolution with ``groups=P``, so the replicas never
    exchange information. The linear layers of ``Net`` are emulated by
    kernel-size-1 ``Conv1d`` layers, which is why their weights carry an extra
    trailing dimension of size 1.

    Input: (batch, 3 * P, 32, 32), where channels [3p, 3p + 3) feed replica p.
    Output: (batch, 10 * P, 1), where channels [10p, 10p + 10) are replica p's logits.

    If ``const_init`` is given, every weight and bias is filled with that value.
    """

    def __init__(self, P, *, const_init=None):
        super().__init__()

        self.P = P
        groups = P

        # grouped convolutions: a 3->64 and a 64->64 stack per replica
        self.conv1 = nn.Conv2d(3 * groups, 64 * groups, 5, groups=groups)
        self.conv2 = nn.Conv2d(64 * groups, 64 * groups, 5, groups=groups)

        # "linear" layers: 1x1 grouped Conv1d applied to a length-1 sequence
        self.fc1 = nn.Conv1d(1600 * groups, 384 * groups, kernel_size=1, groups=groups)
        self.fc2 = nn.Conv1d(384 * groups, 192 * groups, kernel_size=1, groups=groups)
        self.fc3 = nn.Conv1d(192 * groups, 10 * groups, kernel_size=1, groups=groups)

        if const_init is None:
            return
        for module in (self.conv1, self.conv2, self.fc1, self.fc2, self.fc3):
            module.weight.data.fill_(const_init)
            module.bias.data.fill_(const_init)

    def forward(self, x):
        n = x.shape[0]

        # conv -> ReLU -> 2x2 max-pool, twice
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x), inplace=True), 2)

        # (n, 64 * P, 5, 5) -> (n, 1600 * P, 1): each replica's 1600 features stay contiguous
        x = x.view(n, -1, 1)

        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x), inplace=True)
        return self.fc3(x)

In [None]:
from torch import Tensor
from typing import List, OrderedDict


def inflate_state_dict(
    state_dict: OrderedDict[str, Tensor], inflation_ratio: int
) -> OrderedDict[str, Tensor]:
    """
    Build the state dict of a batched model from the state dict of a single model.

    Every tensor is replicated ``inflation_ratio`` times and the copies are
    concatenated along the first dimension, so replica ``p`` owns rows
    ``[p * n, (p + 1) * n)``, where ``n`` is the original first-dimension size.
    Weights whose key starts with ``fc`` get an extra trailing dimension of
    size 1, because batched models implement linear layers as kernel-size-1
    ``Conv1d`` layers.

    Parameters
    ----------
    state_dict: OrderedDict[str, Tensor]
        The state dict of a non-batched model.

    inflation_ratio: int
        Number of copies of each parameter to place in the result.

    Returns
    -------
    OrderedDict[str, Tensor]
        A new mapping (a plain ``dict``) with the same keys in the same order.
    """

    def _inflate(key: str, tensor: Tensor) -> Tensor:
        replicated = torch.stack([tensor] * inflation_ratio).flatten(0, 1)
        is_linear_weight = key.startswith("fc") and key.endswith("weight")
        return replicated.unsqueeze(-1) if is_linear_weight else replicated

    return {key: _inflate(key, tensor) for key, tensor in state_dict.items()}


def deflate_state_dict(
    state_dict: OrderedDict[str, Tensor], deflation_ratio: int
) -> List[OrderedDict[str, Tensor]]:
    """
    Split the state dict of a batched model into per-replica state dicts.

    Every tensor is cut into ``deflation_ratio`` equal chunks along its first
    dimension; chunk ``i`` goes to the ``i``-th output dict. Weights whose key
    starts with ``fc`` lose their trailing Conv1d kernel dimension, so they
    match the ``nn.Linear`` layout of the non-batched model. The returned
    tensors are views that share storage with the tensors in ``state_dict``.

    Parameters
    ----------
    state_dict: OrderedDict[str, Tensor]
        The inflated (batched) state dict.

    deflation_ratio: int
        Number of state dictionaries to extract.

    Returns
    -------
    List[OrderedDict[str, Tensor]]
        A list of ``deflation_ratio`` state dicts (plain ``dict`` objects).

    Raises
    ------
    ValueError
        If a tensor's first dimension is not divisible by ``deflation_ratio``.
    """
    # size the output from the argument, not from a global setting
    deflated_dicts = [dict() for _ in range(deflation_ratio)]

    for key, parameters in state_dict.items():
        if parameters.size(0) % deflation_ratio != 0:
            # torch.chunk would silently return fewer/uneven chunks here
            raise ValueError(
                f"cannot split '{key}' with leading dimension {parameters.size(0)} "
                f"into {deflation_ratio} equal parts"
            )

        is_linear_weight = key.startswith("fc") and key.endswith("weight")
        for i, chunk in enumerate(torch.chunk(parameters, deflation_ratio)):
            # drop only the trailing kernel dimension; a bare squeeze() would also
            # drop any other size-1 dimension (e.g. a single-output linear layer)
            deflated_dicts[i][key] = chunk.squeeze(-1) if is_linear_weight else chunk

    return deflated_dicts


def state_dict_diff(
    a: OrderedDict[str, Tensor], b: OrderedDict[str, Tensor], alpha=1.0
) -> OrderedDict[str, Tensor]:
    """
    Compute ``a - alpha * b`` entry by entry.

    Entries are matched by key rather than by position, so the two dicts may
    list their keys in different orders. The result has ``a``'s keys in ``a``'s
    order; extra keys in ``b`` are ignored.

    Raises
    ------
    KeyError
        If ``b`` lacks a key that is present in ``a``.
    """
    return {key: value - alpha * b[key] for key, value in a.items()}

In [None]:
import torch.optim as optim


class Client:
    """
    A federated client holding its local training and validation data.

    Parameters
    ----------
    i: int
        Client index; the training loop uses it to identify the selected clients.
    train_set, validation_set:
        Datasets accepted by ``torch.utils.data.DataLoader``.
    """

    def __init__(self, i, train_set, validation_set):
        self.i = i
        # One batch containing the whole local training set, in a fixed order.
        # The training loop loads it once and shuffles it itself.
        self.train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=len(train_set),
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )

        self.validation_loader = torch.utils.data.DataLoader(
            validation_set, batch_size=config["BATCH_SIZE"], shuffle=False, num_workers=0
        )

    def compute_accuracy(self, model):
        """
        Return the fraction of this client's validation samples that ``model``
        classifies correctly (argmax over dimension 1 of the output).

        Runs under ``torch.no_grad()``, so no autograd graph is built. The model's
        train/eval mode is not changed here; callers should set it.

        Raises
        ------
        ZeroDivisionError
            If the validation set is empty.
        """
        running_corrects = 0
        n = 0
        with torch.no_grad():
            for data, labels in self.validation_loader:
                data = data.to(device)
                labels = labels.to(device)
                outputs = model(data)

                _, preds = torch.max(outputs, 1)

                running_corrects += torch.sum(preds == labels).item()
                n += len(preds)

        return running_corrects / n

In [None]:
from collections import defaultdict


def parse_csv(filename):
    """
    Group image ids by user id from a federated-split CSV file.

    Each data line must hold three comma-separated integers
    ``user_id,image_id,<third>``; the third value is parsed and then discarded.
    Lines whose first character is not a digit (e.g. a header) are skipped.

    Returns a ``defaultdict`` mapping user id -> list of image ids, in file order.
    """
    groups = defaultdict(list)
    with open(filename) as handle:
        data_rows = (row for row in handle if row[0].isdigit())
        for row in data_rows:
            user, image, _third = map(int, row.split(","))
            groups[user].append(image)
    return groups

In [None]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
from statistics import mean

from tqdm.notebook import tqdm, trange
from time import perf_counter_ns

# Seed every RNG this notebook draws from: Python's `random` (client sampling) and
# torch's default generator (random_split, randperm shuffling and the torchvision
# random transforms), so that runs are reproducible.
random.seed(42)
torch.manual_seed(42)

# Per-channel statistics that map ToTensor's [0, 1] output to [-1, 1].
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_STD = (0.5, 0.5, 0.5)

# Deterministic preprocessing applied to every train/test image when it is loaded.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.CenterCrop(24),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)


def _denormalize(img):
    """Invert ``Normalize(NORMALIZE_MEAN, NORMALIZE_STD)`` on a (..., C, H, W) tensor."""
    mean = torch.as_tensor(NORMALIZE_MEAN, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(NORMALIZE_STD, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img * std + mean


# Training-time augmentation applied to tensors that were ALREADY normalised by
# `transform`. ColorJitter expects float images in [0, 1] (brightness, for
# example, clamps its result to that range), so the images are mapped back to
# [0, 1] before jittering and re-normalised afterwards. Jittering the [-1, 1]
# data directly would clip every negative value to 0.
# NOTE(review): when called on a whole batch tensor (as the training loop does),
# each random transform draws its parameters once, so every image in that
# batch gets the same flip/jitter.
random_transformations = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(config["TRANSFORM_RND_HFLIP_PROB"]),
        transforms.Lambda(_denormalize),
        transforms.ColorJitter(
            config['TRANSFORM_BRIGHTNESS'],
            config['TRANSFORM_CONTRAST'],
            config['TRANSFORM_SATURATION'],
            config['TRANSFORM_HUE']
        ),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)


def _truncate_and_split(dataset, n_parts):
    """
    Split ``dataset`` uniformly at random into ``n_parts`` equally sized subsets.

    The dataset is first truncated to its first ``(len(dataset) // n_parts) * n_parts``
    samples so that all parts have the same size; the remainder is discarded.

    Returns
    -------
    (truncated_dataset, list_of_subsets)
    """
    part_len = len(dataset) // n_parts
    truncated = torch.utils.data.Subset(dataset, list(range(part_len * n_parts)))
    parts = torch.utils.data.random_split(dataset=truncated, lengths=[part_len] * n_parts)
    return truncated, parts


if config["DATA_DISTRIBUTION"] == "iid":
    # every client receives the same number of randomly drawn training images
    trainset, trainsets = _truncate_and_split(trainset, config["NUMBER_OF_CLIENTS"])
else:
    # client partition read from the Dirichlet-split CSV for the configured alpha
    dirichelet_splits = parse_csv(
        f"cifar10/federated_train_alpha_{config['DIRICHELET_ALPHA']:.2f}.csv"
    )
    trainsets = [
        torch.utils.data.Subset(trainset, indices)
        for indices in dirichelet_splits.values()
    ]
    if len(trainsets) < config["NUMBER_OF_CLIENTS"]:
        # fail here with a clear message instead of an IndexError when clients are built
        raise ValueError(
            f"{len(trainsets)} client splits found in the CSV, but "
            f"NUMBER_OF_CLIENTS is {config['NUMBER_OF_CLIENTS']}"
        )

# per-client validation sets: the test set is split into equal random parts
testset, testsets = _truncate_and_split(testset, config["NUMBER_OF_CLIENTS"])


# NOTE(review): clientsSizes is not read anywhere in the visible notebook.
clientsSizes = torch.zeros(config["NUMBER_OF_CLIENTS"])
# Client objects; populated once the data splits are assigned.
clients = []

# Server (reference) model. It is only evaluated and overwritten by aggregation,
# never trained directly, so it stays in evaluation mode.
reference = Net().to(device)
reference.eval()


def selectClients(k):
    """Return ``k`` distinct clients drawn uniformly at random, without replacement, from ``clients``."""
    population = clients
    return random.sample(population, k=k)


def aggregateClient(deltaThetas):
    """
    Average a collection of state-dict deltas entry by entry.

    Each delta is weighted by ``1 / len(deltaThetas)``, so the weights always sum
    to one regardless of how many clients were sampled. The result is a new dict
    with the keys of the first delta; the inputs are not modified.

    Returns None when ``deltaThetas`` is empty.
    """
    deltas = list(deltaThetas)
    if not deltas:
        return None

    ratio = 1 / len(deltas)
    # `ratio * v` allocates new tensors, so the in-place += below never touches the inputs
    parameters = {k: ratio * v for k, v in deltas[0].items()}
    for delta in deltas[1:]:
        for k, v in delta.items():
            parameters[k] += ratio * v

    return parameters


# One Client per index, pairing the idx-th training split with the idx-th validation split.
clients.extend(
    Client(idx, trainsets[idx], testsets[idx])
    for idx in range(config["NUMBER_OF_CLIENTS"])
)


# One batched model holds K replicas, one per client selected in a round.
batched_model = BatchedNet(config["K"]).to(device)
batched_optimizer = optim.SGD(batched_model.parameters(), lr=config["LR"])
criterion = nn.CrossEntropyLoss().to(device)

# multiply the client learning rate by LR_DECAY every LR_DECAY_STEP_SIZE scheduler steps
scheduler = optim.lr_scheduler.StepLR(
    batched_optimizer, step_size=config["LR_DECAY_STEP_SIZE"], gamma=config["LR_DECAY"]
)

# Move every client's complete training set to `device` once, before training.
# Each train loader yields the whole local set as a single batch; labels are
# stored as an (N, 1) column.
client_images = []
client_labels = []
for client in clients:
    images, labels = next(iter(client.train_loader))
    client_images.append(images.to(device))
    client_labels.append(labels.reshape(-1, 1).to(device))

# Server-side momentum buffer for FedAvgM, one entry per parameter name.
if config["FED_AVG_M"]:
    old_parameters = {}

# collect the test accuracies over the rounds
test_accuracies = []

for step in trange(config["MAX_TIME"]):
    # ---- client sampling ---------------------------------------------------
    selected_clients = selectClients(config["K"])
    selected_ids = set(c.i for c in selected_clients)

    # data of the selected clients, in increasing client-index order
    selected_client_images = [
        ci for i, ci in enumerate(client_images) if i in selected_ids
    ]
    selected_client_labels = [
        cl for i, cl in enumerate(client_labels) if i in selected_ids
    ]

    # shuffle each client's data once per round (same order in every local epoch)
    permutations = [
        torch.randperm(sci.shape[0], device=device) for sci in selected_client_images
    ]
    selected_client_images = [
        sci[perm] for sci, perm in zip(selected_client_images, permutations)
    ]
    selected_client_labels = [
        scl[perm] for scl, perm in zip(selected_client_labels, permutations)
    ]

    # data augmentation, applied to each client's whole local dataset at once
    selected_client_images = [
        random_transformations(sci) for sci in selected_client_images
    ]

    # ---- broadcast: every replica starts from the server model ---------------
    reference_parameters = reference.state_dict()
    batched_model.load_state_dict(inflate_state_dict(reference_parameters, config["K"]))

    # ---- local training: the K clients train in lockstep in one batched model
    batch_size = config["BATCH_SIZE"]
    # assumes all selected clients hold the same number of samples
    # (torch.stack below needs equally sized batches)
    n_batches = selected_client_images[0].shape[0] // batch_size

    for _epoch in range(config["E"]):
        for i in range(n_batches):
            batched_optimizer.zero_grad(set_to_none=True)

            # the i-th batch of every selected client
            window = slice(i * batch_size, (i + 1) * batch_size)
            batch_images = [ci[window] for ci in selected_client_images]
            batch_labels = [cl[window] for cl in selected_client_labels]

            # [batch_size, K * 3, 32, 32]; channels 3k..3k+2 belong to client k
            batch_images = torch.stack(batch_images, dim=1).flatten(1, 2)

            # [batch_size, 10 * K, 1]; output[:, 10*k:10*(k+1)] are client k's logits
            output = batched_model(batch_images)

            # Sum of the per-client losses. The replicas share no parameters, so
            # each replica only receives the gradient of its own client's loss.
            loss = 0
            for _output, _labels in zip(torch.chunk(output, config["K"], dim=1), batch_labels):
                loss += criterion(_output, _labels)
            loss.backward()

            batched_optimizer.step()

    # ---- aggregation: average the client deltas (server params - client params)
    clients_parameters = deflate_state_dict(batched_model.state_dict(), config["K"])
    clients_parameters = [
        state_dict_diff(reference_parameters, params) for params in clients_parameters
    ]
    g = aggregateClient(clients_parameters)

    new_parameters = dict()
    for k1, v1 in reference_parameters.items():
        v2 = g[k1]
        if config["FED_AVG_M"]:
            # FedAvgM: v <- beta * v + g (v starts at g);  w <- w - gamma * v
            if k1 in old_parameters:
                momentum = config["FED_AVG_M_BETA"] * old_parameters[k1] + v2
            else:
                momentum = v2
            old_parameters[k1] = momentum
            new_parameters[k1] = v1 - config["FED_AVG_M_GAMMA"] * momentum
        else:
            # plain FedAvg with an implicit server learning rate of 1
            new_parameters[k1] = v1 - v2  # TODO: add server learning rate gamma

    reference.load_state_dict(new_parameters)

    if (step + 1) % config["LOG_FREQUENCY"] == 0:
        avg_accuracy = mean(client.compute_accuracy(reference) for client in clients)
        test_accuracies.append(avg_accuracy)
        print(f"Average accuracy after {step + 1} rounds is {avg_accuracy}")

    # decay the client learning rate once per round
    scheduler.step()

In [None]:
# Final evaluation of the server model. When MAX_TIME is a multiple of
# LOG_FREQUENCY, the training loop has already evaluated and recorded the model
# after the last round, so reuse that value instead of appending a duplicate.
if test_accuracies and config["MAX_TIME"] % config["LOG_FREQUENCY"] == 0:
    avg_accuracy = test_accuracies[-1]
else:
    avg_accuracy = mean(client.compute_accuracy(reference) for client in clients)
    test_accuracies.append(avg_accuracy)
print(f"Average accuracy after {config['MAX_TIME']} rounds is {avg_accuracy}")

In [None]:
import json
import os
import time

# Timestamped base path (without extension) shared by the weights and metadata files.
timestr = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
# torch.save and open() do not create missing directories
os.makedirs("artifacts", exist_ok=True)
artifact_filename = f"artifacts/server_model-{timestr}"

# parameters of the trained model
server_model = reference.state_dict()
# save the model on the local file system
torch.save(server_model, artifact_filename + ".pth")

# run configuration and accuracy history, stored next to the weights
data = {
    "config": config,
    "test_accuracies": test_accuracies
}

with open(artifact_filename + ".json", "w", encoding="utf-8") as f:
    json.dump(data, f, indent=4)