# Federated Learning based on Dynamic Regularization

In [None]:
from typing import *
from collections import defaultdict
from copy import deepcopy
import urllib.request
import zipfile
import os
import json
from statistics import mean

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import numpy as np

from tqdm.notebook import tqdm

import random

# Seed every RNG the notebook draws from: `random` (client sampling), numpy, and
# torch (random_split, randperm-based shuffling, DataLoader shuffling, weight init).
# torch.manual_seed also seeds the CUDA generators. Some CUDA kernels may still be
# nondeterministic, so bit-exact reruns on GPU are not guaranteed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Configuration

In [None]:
# Dirichlet concentration values to sweep (one full training run per value).
# Alternative sweep: [0.50, 1.00, 10.00, 100.0]
dirichlet_alphas = [0.0, 0.1, 0.5, 10.0]

config = {
    "E": 5,  # number of local epochs
    "K": 5,  # number of clients selected each round # [5, 10, 20]
    "NUMBER_OF_CLIENTS": 100,  # total number of clients
    "MAX_TIME": 1500,  # number of rounds
    "BATCH_SIZE": 50,
    "VALIDATION_BATCH_SIZE": 500,
    "LR": 0.01,  # learning rate
    "WEIGHT_DECAY": 4e-4,
    "DATA_DISTRIBUTION": "non-iid",  # "iid" | "non-iid"
    "DIRICHELET_ALPHA": dirichlet_alphas,
    # One slot per Dirichlet alpha; sized from the sweep so the two never drift apart
    "AVERAGE_ACCURACY": np.zeros(len(dirichlet_alphas)),
    "LOG_FREQUENCY": 25,  # evaluate and print the average accuracy every LOG_FREQUENCY rounds
    "AUGMENTATION_PROB": 0.0,  # probability of each random augmentation (see the Prepare dataset cell)
    "ALPHA": 1e-3,  # FedDyn regularization coefficient
    "SAVE_FREQUENCY": 500,  # save partial artifacts every SAVE_FREQUENCY rounds (only on evaluated rounds)
    "NORM_LAYER": "gn",  # Normalization layer [None: "", Batch: "bn", Group: "gn"]
}

# Preload every client's data onto `device` once (much faster, more memory)
use_data_preloading = True

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
config_name = "feddyn_gn"
artifacts_target_directory = f"./artifacts/{config_name}"

# Uncomment the following lines to save the results on Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# artifacts_target_directory = f'/content/drive/MyDrive/artifacts/{config_name}'


# Create the output folders: `final` for end-of-run artifacts, `partials` for
# periodic snapshots. exist_ok makes this idempotent when the cell is re-run.
for _artifacts_subdir in ("final", "partials"):
    os.makedirs(f"{artifacts_target_directory}/{_artifacts_subdir}", exist_ok=True)

## CIFAR-10

In [None]:
# Download the CIFAR-10 federated (Dirichlet non-iid) splits, if not already present.
# The archive is extracted into ./cifar10; the training loop reads
# cifar10/federated_train_alpha_<alpha>.csv from there.
if not os.path.exists("cifar10"):
    save_path = "cifar10.zip"
    # Fetch over HTTPS: the archive is extracted and used without further checks
    urllib.request.urlretrieve("https://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10_v1.1.zip", save_path)

    with zipfile.ZipFile(save_path, 'r') as zip_ref:
        zip_ref.extractall("cifar10")


## Prepare dataset

In [None]:
# Normalization values (per-channel mean / std) for the CIFAR10 dataset
normalization_transform = transforms.Normalize(
    mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]
)

# Transformation strategy:
#  1. apply as many transformations as possible offline (at trainset level)
#  2. apply the remaining transformations online (at client level, before
#     iterating over the data)
#
# Random transformations must be applied online when data is preloaded: each
# client reads its shard through the dataset transform only once, so any
# randomness applied there would be frozen for the whole training.

# Random augmentation followed by normalization. Each random op fires with
# probability AUGMENTATION_PROB. Built once and reused by the branches below.
augmented_train_transform = transforms.Compose(
    [
        transforms.RandomApply([transforms.RandomHorizontalFlip(1)], config["AUGMENTATION_PROB"]),
        transforms.RandomApply([transforms.RandomCrop(32, padding=4)], config["AUGMENTATION_PROB"]),
        normalization_transform,
    ]
)

if use_data_preloading and config["AUGMENTATION_PROB"] > 1e-5:
    # Preloading + non-zero augmentation => only ToTensor offline,
    # augmentation and normalization online (per client, per epoch)
    offline_train_transform = []
    online_train_transform = augmented_train_transform
elif use_data_preloading:
    # Preloading + zero augmentation => everything is deterministic, apply offline
    offline_train_transform = [normalization_transform]
    online_train_transform = None
else:
    # No preloading => the dataset transform runs on every access, so even the
    # random transformations can safely be applied at dataset level
    offline_train_transform = [augmented_train_transform]
    online_train_transform = None


trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), *offline_train_transform]),
)

testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalization_transform]),
)

# Both sets are later split into NUMBER_OF_CLIENTS equal shards (the test set
# always, the train set in the "iid" setting). Fail loudly, and even under -O,
# if that is not possible.
for _dataset_name, _dataset in (("training", trainset), ("test", testset)):
    if len(_dataset) % config["NUMBER_OF_CLIENTS"] != 0:
        raise ValueError(
            f"The {_dataset_name} set size ({len(_dataset)}) is not divisible by "
            f"NUMBER_OF_CLIENTS ({config['NUMBER_OF_CLIENTS']})"
        )

In [None]:
def parse_splits_csv(filename: str) -> Tuple[DefaultDict[int, List[int]], Dict[int, int]]:
    """Read a CIFAR-10 splits file

    Each data row is expected to be ``user_id,image_id,label`` (all integers).
    Lines that do not start with a digit (e.g. the header) are skipped.

    Parameters
    ----------
    filename : str
        path of the .csv file containing the splits

    Returns
    -------
    Tuple[DefaultDict[int, List[int]], Dict[int, int]]
        the dictionary containing the splits as user_id:[image_id], with
        users in order of first appearance and images in file order,
        and the labels_mapping as image_id:label
    """
    splits: DefaultDict[int, List[int]] = defaultdict(list)
    labels_mapping: Dict[int, int] = {}

    with open(filename, encoding="utf-8") as f:
        for line in f:
            # Skip the header (and any other non-data line); slicing avoids an
            # IndexError on an empty string
            if not line[:1].isdigit():
                continue

            user_id, image_id, label = (int(token) for token in line.split(","))
            splits[user_id].append(image_id)
            labels_mapping[image_id] = label

    return splits, labels_mapping

## LeNet model

In [None]:
# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html

class Net(nn.Module):
    """LeNet-style CNN for 3x32x32 inputs with 10 output classes.

    Parameters
    ----------
    norm_layer : str, optional
        normalization applied after each convolution: "bn" (BatchNorm2d),
        "gn" (GroupNorm with 4 groups) or "" / None for no normalization.
        When omitted, ``config["NORM_LAYER"]`` is used. The choice is fixed at
        construction time, so ``forward`` always matches the built layers.

    Raises
    ------
    ValueError
        if the normalization layer is not one of the values above.
    """

    def __init__(self, norm_layer: Optional[str] = None):
        super().__init__()
        if norm_layer is None:
            norm_layer = config["NORM_LAYER"]
        self.norm_layer = norm_layer

        self.conv1 = nn.Conv2d(3, 64, 5)
        self.norm1 = self._make_norm(norm_layer, 64)

        self.conv2 = nn.Conv2d(64, 64, 5)
        self.norm2 = self._make_norm(norm_layer, 64)

        # Spatial size: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5
        self.fc1 = nn.Linear(64 * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    @staticmethod
    def _make_norm(norm_layer: Optional[str], channels: int) -> nn.Module:
        """Build the normalization module for `channels` feature maps."""
        if norm_layer == "bn":
            return nn.BatchNorm2d(channels)
        if norm_layer == "gn":
            return nn.GroupNorm(4, channels)
        if not norm_layer:
            # Identity has no parameters/buffers, so the state_dict keys are the
            # same as a network built without any normalization attribute
            return nn.Identity()
        raise ValueError(f"Unsupported NORM_LAYER: {norm_layer!r} (expected '', 'bn' or 'gn')")

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.norm1(self.conv1(x))), kernel_size=2, stride=2)
        x = F.max_pool2d(F.relu(self.norm2(self.conv2(x))), kernel_size=2, stride=2)

        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


## Client

In [None]:
import torch.optim as optim

class Client:
    """A FedDyn client: a local copy of the network, its data shards and the
    dynamic-regularization state (the client's previous gradient)."""

    def __init__(self, i: int, train_set, validation_set, use_data_preloading=False, train_transform=None,
        batch_size=config["BATCH_SIZE"], validation_batch_size=config["VALIDATION_BATCH_SIZE"],
        epochs=config["E"], lr=config["LR"], weight_decay=config["WEIGHT_DECAY"]):
        """Instantiate a new client

        Parameters
        ----------
        i : int
            identifier of the client
        train_set, validation_set : torch.utils.data.Dataset
            the client's training and validation shards
        use_data_preloading : bool
            whether the client should preload all its training and validation data
            on `device` before starting the training loop. This is preferred since it
            drastically speeds up the training process. Beware that, depending on the
            dataset, the memory usage may quickly become infeasible: 100 clients, each
            with its own data and network, require ~2.5 GB of GPU memory.
        train_transform : callable, optional
            pytorch transform applied to each training sample at training time
            (only used when data preloading is enabled)
        batch_size, validation_batch_size : int
            mini-batch sizes used for training and validation
        epochs : int
            number of local epochs run by each call to `update`
        lr, weight_decay : float
            SGD hyper-parameters
        """
        self.i = i
        self.net = Net().to(device)

        # Summed loss: callers divide by the number of samples themselves
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

        self.train_transform = train_transform
        self.batch_size = batch_size
        self.validation_batch_size = validation_batch_size
        self.epochs = epochs
        self.lr = lr
        self.weight_decay = weight_decay

        # Pinned host memory only helps host-to-GPU copies (and warns without CUDA)
        pin_memory = device == "cuda"

        self.use_data_preloading = use_data_preloading
        if self.use_data_preloading:
            # Read each shard as a single batch, move it to `device` once and cache it;
            # mini-batching is done later by iter_training_data / iter_validation_data
            self.train_loader = torch.utils.data.DataLoader(train_set,
                batch_size=len(train_set), shuffle=True,
                num_workers=0, pin_memory=pin_memory
            )
            self.validation_loader = torch.utils.data.DataLoader(validation_set,
                batch_size=len(validation_set), shuffle=False, num_workers=0
            )

            training_images, training_labels = next(iter(self.train_loader))
            self.training_images = training_images.to(device)
            self.training_labels = training_labels.to(device)

            validation_images, validation_labels = next(iter(self.validation_loader))
            self.validation_images = validation_images.to(device)
            self.validation_labels = validation_labels.to(device)
        else:
            # Preloading was not requested => dataloaders are used as usual
            self.train_loader = torch.utils.data.DataLoader(train_set,
                batch_size=self.batch_size, shuffle=True,
                num_workers=0, pin_memory=pin_memory
            )
            # Honour validation_batch_size, consistently with the preloaded path
            self.validation_loader = torch.utils.data.DataLoader(validation_set,
                batch_size=self.validation_batch_size, shuffle=False, num_workers=0
            )

        # The FedDyn linear term starts at zero: one entry per trainable scalar
        size = sum(p.numel() for p in self.net.parameters() if p.requires_grad)
        self.previous_gradient = torch.zeros((size,), device=device)

    def update(self, alpha: float, parameters: OrderedDict[str, torch.Tensor]) -> OrderedDict[str, torch.Tensor]:
        """Run a step of client training

        Minimizes, for `self.epochs` local epochs,
            mean CE loss - <previous_gradient, theta> + (alpha / 2) * ||theta - theta_server||^2
        and then updates `self.previous_gradient` with
            previous_gradient -= alpha * (theta - theta_server).

        Parameters
        ----------
        alpha : float
            FedDyn's alpha coefficient
        parameters : OrderedDict[str, torch.Tensor]
            Initial parameters of the client model (the current server state_dict)

        Returns
        -------
        OrderedDict[str, torch.Tensor]
            New parameters of the client at the end of the training step. These
            are the client network's own tensors; they change on the next call.
        """
        # Restore the parameters of the network from the server model
        self.net.load_state_dict(parameters)
        self.net.train(True)

        self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # Flatten the server parameters once: they are constant for this whole update.
        # Only trainable parameters are kept (state_dict buffers such as BatchNorm
        # running statistics are skipped), in the same order as self.net.parameters()
        # so it lines up element-wise with cur_flat below.
        par_flat = torch.cat([
            parameters[name].detach().reshape(-1) for name, _ in self.net.named_parameters()
        ])

        for _ in range(self.epochs):
            for images, labels in self.iter_training_data():
                self.optimizer.zero_grad()
                output = self.net(images)

                # Summed criterion -> mean loss over the mini-batch
                loss = self.criterion(output, labels) / labels.size(0)

                # Flatten the current parameters
                # NOTE: self.net.parameters() is used in place of state_dict since the tensors
                #       returned by state_dict are NOT differentiable (requires_grad == False)
                cur_flat = torch.cat([p.reshape(-1) for p in self.net.parameters()])

                # Linear penalty: prev_grad_flat · cur_flat
                linear_penalty = torch.dot(self.previous_gradient, cur_flat)
                # Quadratic penalty: (alpha / 2) * || cur_flat - par_flat || ^ 2
                # (sum of squares: same value as the squared 2-norm, without the sqrt)
                norm_penalty = (alpha / 2) * torch.sum((cur_flat - par_flat) ** 2)

                loss = loss - linear_penalty + norm_penalty

                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters=self.net.parameters(), max_norm=10)
                self.optimizer.step()

        # Preserve the current gradient for the next iteration
        with torch.no_grad():
            cur_flat = torch.cat([p.detach().reshape(-1) for p in self.net.parameters()])
            self.previous_gradient -= alpha * (cur_flat - par_flat)

        return self.net.state_dict()

    def iter_training_data(self) -> Generator[torch.Tensor, Any, Any]:
        """Iterate over the training data of the client

        Yields
        -------
        Generator[torch.Tensor]
            the training dataset, reshuffled and split in mini-batches of
            `self.batch_size`, already on `device`
        """
        if self.use_data_preloading:
            # Data already preloaded => shuffle and return from cache
            indices = torch.randperm(self.training_images.size(0))
            images, labels = self.training_images[indices], self.training_labels[indices]

            # Possibly apply the required training transformation
            if self.train_transform is not None:
                images = torch.stack([self.train_transform(im) for im in images])

            yield from zip(torch.split(images, self.batch_size), torch.split(labels, self.batch_size))
        else:
            # Data not already preloaded => yield from the dataloader
            for images, labels in self.train_loader:
                yield images.to(device), labels.to(device)

    def iter_validation_data(self) -> Generator[torch.Tensor, Any, Any]:
        """Iterate over the validation data of the client

        Yields
        -------
        Generator[torch.Tensor]
            the validation dataset, split in mini-batches of
            `self.validation_batch_size`, already on `device`
        """
        if self.use_data_preloading:
            # Data already preloaded => return from cache
            yield from zip(
                torch.split(self.validation_images, self.validation_batch_size),
                torch.split(self.validation_labels, self.validation_batch_size),
            )
        else:
            # Data not already preloaded => yield from the dataloader
            for images, labels in self.validation_loader:
                yield images.to(device), labels.to(device)

    def compute_accuracy(self, parameters: OrderedDict[str, torch.Tensor]) -> Tuple[float, float]:
        """Compute the accuracy of the client

        Loads `parameters` into the client network (overwriting its weights)
        and evaluates it in eval mode on the client's validation shard.

        Parameters
        ----------
        parameters : OrderedDict[str, torch.Tensor]
            parameters to evaluate

        Returns
        -------
        Tuple[float, float]
            average loss on the validation set, accuracy on the validation set
        """
        self.net.load_state_dict(parameters)
        self.net.eval()

        running_corrects = 0
        loss, n = 0.0, 0
        with torch.no_grad():
            for data, labels in self.iter_validation_data():
                outputs = self.net(data)
                # Summed loss; averaged over all samples at the end
                loss += self.criterion(outputs, labels).item()

                # Count the number of correct predictions
                preds = outputs.argmax(dim=1)
                running_corrects += (preds == labels).sum().item()
                n += labels.size(0)

        return loss / n, running_corrects / n

In [None]:
def listToString(l):
    """Render a value as a space-separated string for the JSON artifacts.

    Iterables are joined element by element, e.g. [0.0, 0.1] -> "0.0 0.1".
    Strings are returned unchanged. Any other value (e.g. a float) is
    converted with str(), e.g. 0.5 -> "0.5".

    (The previous `" ".join(str(l))` joined the *characters* of the repr,
    producing "[ 0 . 0 ,   0 . 1 ]".)
    """
    if isinstance(l, str):
        return l
    try:
        items = iter(l)
    except TypeError:
        # Scalars (floats, ints, 0-d arrays/tensors) are not iterable
        return str(l)
    return " ".join(str(item) for item in items)

def printJSON(alpha, acc, net, step=None):
    """Save the model weights (.pth) and a JSON summary (config, alpha, accuracies)

    Parameters
    ----------
    alpha: float
        Dirichlet alpha of the run (used in the file name)
    acc: list
        list of accuracies at different iterations
    net: torch.nn.Module
        the network whose state_dict is saved
    step: int, optional
        current round; when given, the artifacts go to `partials` and the file
        name gets a `_STEPS_<step>` suffix, otherwise they go to `final`
    """
    subdir = "final" if step is None else "partials"
    dirname = f"{artifacts_target_directory}/{subdir}"
    # The target may not exist (e.g. the directory was changed or removed after setup)
    os.makedirs(dirname, exist_ok=True)

    artifact_filename = f"FEDDYN_ALPHA_{alpha}_E_{config['E']}_K_{config['K']}"
    if step is not None:
        artifact_filename += f"_STEPS_{step}"
    if config["AUGMENTATION_PROB"] > 0:
        artifact_filename += "_T"
    if config["NORM_LAYER"]:
        artifact_filename += f"_{config['NORM_LAYER'].upper()}"

    # Save the parameters of the model on the local file system
    torch.save(net.state_dict(), f"{dirname}/{artifact_filename}.pth")

    # JSON cannot hold lists of floats mixed with numpy arrays here: stringify them
    config_copy = deepcopy(config)
    config_copy["DIRICHELET_ALPHA"] = listToString(config_copy["DIRICHELET_ALPHA"])
    config_copy["AVERAGE_ACCURACY"] = np.array2string(config_copy["AVERAGE_ACCURACY"])
    data = {"config": config_copy, "alpha": listToString(alpha), "accuracy": acc}

    with open(f"{dirname}/{artifact_filename}.json", "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)

## Training

In [None]:
# In the FedDyn paper the number of clients is identified with m
m = config["NUMBER_OF_CLIENTS"]
feddyn_alpha = config["ALPHA"]
K = config["K"]

# Collect the test accuracies (in %) over the different values of Dirichlet alpha
accuracies = defaultdict(list)

for alpha_i, alpha in enumerate(config["DIRICHELET_ALPHA"]):
    # Create a dummy model that will hold the server model parameters
    net = Net().to(device)

    # Prepare the training shards
    if config["DATA_DISTRIBUTION"] == "iid":
        lengths = [len(trainset) // m] * m
        trainsets = torch.utils.data.random_split(dataset=trainset, lengths=lengths)
    else:
        dirichelet_splits, _ = parse_splits_csv(f"cifar10/federated_train_alpha_{alpha:.2f}.csv")
        trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]

    # Prepare the validation shards (equal random slices of the test set)
    lengths = [len(testset) // m] * m
    testsets = torch.utils.data.random_split(dataset=testset, lengths=lengths)

    # Instantiate the clients (names chosen not to shadow `id`, `trainset`, `testset`)
    clients = [
        Client(client_id, client_trainset, client_testset, use_data_preloading, online_train_transform)
        for client_id, (client_trainset, client_testset) in enumerate(zip(trainsets, testsets))
    ]

    # Server state h, one zero tensor per state_dict entry.
    # NOTE(review): with NORM_LAYER == "bn" this also covers BatchNorm buffers
    # (running stats, num_batches_tracked), which then get the h correction too.
    h = {
        key: torch.zeros(params.shape, device=device)
        for key, params in net.state_dict().items()
    }

    best_model = {}
    best_accuracy = 0.0

    for step in tqdm(range(config["MAX_TIME"])):
        selected_clients = random.sample(clients, K)

        # Server parameters at the start of the round (theta^{t-1})
        server_state = net.state_dict()

        # Collect the updates from the clients
        thetas = [client.update(feddyn_alpha, server_state) for client in selected_clients]

        # h^t = h^{t-1} - (alpha / m) * sum_k (theta_k^t - theta^{t-1})
        h = {
            key: h[key] - feddyn_alpha * 1 / m * sum(theta[key] - server_state[key] for theta in thetas)
            for key in h
        }

        # theta^t = mean_k(theta_k^t) - (1 / alpha) * h^t
        new_parameters = {
            key: (1 / K) * sum(theta[key] for theta in thetas) - (1 / feddyn_alpha) * h[key]
            for key in server_state
        }

        net.load_state_dict(new_parameters)

        if step % config["LOG_FREQUENCY"] == 0:
            # Evaluate the current server parameters against every client's validation shard
            client_accuracies = [client.compute_accuracy(new_parameters)[1] for client in clients]

            # Average accuracy across the clients
            avg_client_accuracy = mean(client_accuracies)
            accuracies[alpha].append(avg_client_accuracy * 100)
            print(f"Average accuracy after {step} rounds is {avg_client_accuracy*100}")

            # Periodically save the computations done so far
            if step % config["SAVE_FREQUENCY"] == 0:
                printJSON(alpha, accuracies[alpha], net, step)

            # Keep the model with the best accuracy.
            # NOTE(review): the validation shards are slices of the test set, so this
            # selection and the final number below are measured on the same data.
            if avg_client_accuracy >= best_accuracy:
                best_accuracy = avg_client_accuracy
                best_model = deepcopy(new_parameters)

    # Report and save the best model, so the saved artifact matches the reported accuracy
    if best_model:
        net.load_state_dict(best_model)
    avg_accuracy = mean(float(client.compute_accuracy(best_model)[1]) for client in clients)

    config["AVERAGE_ACCURACY"][alpha_i] = avg_accuracy
    print(f"Average accuracy of the best model with alpha = {alpha} after {config['MAX_TIME']} rounds is {avg_accuracy*100}")
    printJSON(alpha, accuracies[alpha], net)

## Accuracy plots

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# One accuracy curve per Dirichlet alpha. Points are logged every LOG_FREQUENCY
# rounds starting from round 0, which is exactly what the x grid below reproduces.
fig, ax = plt.subplots()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

x = np.arange(0, config["MAX_TIME"], config["LOG_FREQUENCY"])
for dirichlet_alpha, accuracy_curve in accuracies.items():
    ax.plot(x, accuracy_curve, label=f"{dirichlet_alpha}")

ax.set_xlabel("Number of rounds")
ax.set_ylabel("Accuracy")
ax.set_title("Accuracy vs number of rounds")
ax.legend(title="Dirichlet alpha")
plt.savefig("FedDyn_non_iid.png", dpi=300)
plt.show()

## Artifacts

In [None]:
import shutil

from IPython.display import FileLink, display

# Name the archives after the last component of the artifacts directory
# (a trailing "/" is ignored so the name is never empty)
zip_name = "artifacts"
if "/" in artifacts_target_directory:
    zip_name = artifacts_target_directory.rstrip("/").split("/")[-1] or zip_name

for _archive_suffix in ("final", "partials"):
    shutil.make_archive(f"{zip_name}-{_archive_suffix}", "zip", f"{artifacts_target_directory}/{_archive_suffix}")

# Display links to the artifacts zips (useful on Kaggle and Colab). A notebook only
# renders a cell's last expression, so both links are displayed explicitly.
display(FileLink(f"{zip_name}-final.zip"))
display(FileLink(f"{zip_name}-partials.zip"))

In [None]:
# Keep a copy of both archives next to the raw artifacts (final first, then partials)
for archive_suffix in ("final", "partials"):
    shutil.copy(f"{zip_name}-{archive_suffix}.zip", artifacts_target_directory)