# Federated Learning based on Federated Averaging

In [None]:
from os import path
import urllib.request
import zipfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from collections import defaultdict

import json
from copy import deepcopy

import torchvision
import torchvision.transforms as transforms
import random
from statistics import mean

from tqdm.notebook import tqdm

import os
# Seed every RNG the notebook draws from so runs are reproducible:
# `random` (client sampling), `torch` (weight init, random_split, randperm)
# and `numpy`.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

## Configuration

In [None]:
# Dirichlet concentration values to sweep over: one full training run per value,
# each (in the non-iid setting) using the matching
# cifar10/federated_train_alpha_<alpha>.csv split.
_DIRICHLET_ALPHAS = [0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0]

config = {
    "E": 1, # number of local epochs
    "K": 5, # number of clients selected each round # [5, 10, 20]
    "NUMBER_OF_CLIENTS": 100, # total number of clients
    "MAX_TIME": 5000, # number of communication rounds
    "BATCH_SIZE": 50, # local mini-batch size used by the clients
    "VALIDATION_BATCH_SIZE": 500, # mini-batch size used when evaluating
    "LR": 0.01, # client learning rate
    "WEIGHT_DECAY": 4e-4, # client weight decay
    "DATA_DISTRIBUTION": "non-iid", # "iid" | "non-iid"
    # Key spelling ("DIRICHELET") kept as-is: other cells read it under this name
    "DIRICHELET_ALPHA": _DIRICHLET_ALPHAS,
    # One slot per alpha (filled with the final average accuracy); sized from the
    # alpha list so the two can never get out of sync
    "AVERAGE_ACCURACY": np.zeros(len(_DIRICHLET_ALPHAS)),
    "FED_AVG_M": False, # use server momentum (FedAvgM) instead of plain FedAvg
    "FED_AVG_M_BETA": 0.9, # server momentum coefficient
    "FED_AVG_M_GAMMA": 1, # server learning rate applied to the momentum update
    "LR_DECAY": 0.99, # NOTE: not read by any cell of this notebook
    "LOG_FREQUENCY": 25, # evaluate the server model every LOG_FREQUENCY rounds
    "AUGMENTATION_PROB": 0.0, # for data transformation, see below (Prepare dataset cells)
    "SAVE_FREQUENCY": 100, # save artifacts every SAVE_FREQUENCY rounds (only checked on logging rounds)
    "NORM_LAYER": "", # Normalization layer [None: "", Batch: "bn", Group: "gn"]
}

# Preload each client's data on the target device (faster, but uses more memory)
use_data_preloading = True

## CIFAR-10

In [None]:
# Download the CIFAR-10 federated (Dirichlet) splits, unless already extracted.
if not path.exists("cifar10"):
    save_path = "cifar10.zip"
    # HTTPS so the archive cannot be tampered with in transit
    urllib.request.urlretrieve(
        "https://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10_v1.1.zip",
        save_path,
    )

    partial_dir = "cifar10.partial"
    with zipfile.ZipFile(save_path, "r") as zip_ref:
        zip_ref.extractall(partial_dir)
    # Publish the folder only once extraction has completed, so an interrupted
    # run does not leave a half-populated "cifar10" that the check above skips
    os.replace(partial_dir, "cifar10")

# Run on the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

## Prepare dataset

In [None]:
# Random transformations used for data augmentation.
# NOTE: RandomHorizontalFlip(1) flips every time it runs; whether the whole
# pipeline runs at all is decided by the RandomApply wrapper below
# (probability config["AUGMENTATION_PROB"]).
random_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(1),
        transforms.ColorJitter(0.9, 0.9),
    ]
)
# Per-channel normalization values for the CIFAR10 dataset
normalization_transform = transforms.Normalize(
    mean=[0.491, 0.482, 0.447],
    std=[0.247, 0.243, 0.262],
)

# Transformation strategy:
#  1. apply as many transformations as possible offline (at trainset level)
#  2. apply the remaining transformations online (at client level, before
#     iterating over the data)
#
# Random transformations must be applied online.

if use_data_preloading:
    # Data preloading is enabled: each client reads its samples once at
    # construction time, so the dataset-level transform is effectively frozen.

    if config["AUGMENTATION_PROB"] > 1e-5:
        # Non-zero augmentation probability => random transformations go online
        offline_train_transform = None
        online_train_transform = transforms.Compose([
            transforms.RandomApply([random_transform], config["AUGMENTATION_PROB"]),
            normalization_transform,
        ])
    else:
        # Zero augmentation probability => only the deterministic normalization
        # is needed, and it can be applied offline
        offline_train_transform = normalization_transform
        online_train_transform = None
else:
    # Preloading disabled => the dataset transform is re-run every time a sample
    # is loaded, so (random) augmentation can be applied there directly
    offline_train_transform = transforms.Compose([
        transforms.RandomApply([random_transform], config["AUGMENTATION_PROB"]),
        normalization_transform,
    ])
    online_train_transform = None


trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    # ToTensor always comes first; the offline step is skipped when it is None
    # (transforms.Compose would otherwise try to call None on every sample)
    transform=transforms.Compose(
        [t for t in (transforms.ToTensor(), offline_train_transform) if t is not None]
    ),
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalization_transform]),
)

## LeNet-5-style CNN

In [None]:
# Adapted from: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html

class Net(nn.Module):
    """LeNet-5-style CNN for 3x32x32 inputs and 10 classes.

    Two 5x5 convolutions (64 channels each), each followed by an optional
    normalization layer, ReLU and 2x2 max-pooling, then three fully-connected
    layers (64*5*5 -> 384 -> 192 -> 10).

    Parameters
    ----------
    norm_layer : str, optional
        "" (no normalization), "bn" (BatchNorm2d) or "gn" (GroupNorm, 4 groups).
        Defaults to ``config["NORM_LAYER"]``, read once at construction time.

    Raises
    ------
    ValueError
        If `norm_layer` is not one of the values above.
    """

    # Factories for the normalization layer placed after each convolution.
    # nn.Identity has no parameters/buffers, so the state_dict keys of the
    # un-normalized model are exactly the conv/fc ones.
    _NORM_FACTORIES = {
        "": lambda channels: nn.Identity(),
        "bn": lambda channels: nn.BatchNorm2d(channels),
        "gn": lambda channels: nn.GroupNorm(4, channels),
    }

    def __init__(self, norm_layer=None):
        super().__init__()

        if norm_layer is None:
            norm_layer = config["NORM_LAYER"]
        if norm_layer not in self._NORM_FACTORIES:
            raise ValueError(
                f"Unknown NORM_LAYER {norm_layer!r}; expected one of {sorted(self._NORM_FACTORIES)}"
            )
        make_norm = self._NORM_FACTORIES[norm_layer]

        self.conv1 = nn.Conv2d(3, 64, 5)
        self.norm1 = make_norm(64)

        self.conv2 = nn.Conv2d(64, 64, 5)
        self.norm2 = make_norm(64)

        self.fc1 = nn.Linear(64 * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        # 32x32 -> conv5 -> 28x28 -> pool -> 14x14 -> conv5 -> 10x10 -> pool -> 5x5
        x = F.max_pool2d(F.relu(self.norm1(self.conv1(x))), 2)
        x = F.max_pool2d(F.relu(self.norm2(self.conv2(x))), 2)

        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

## Client

In [None]:
class Client:
    """A federated client: a private copy of `Net`, its own data split and optimizer."""

    def __init__(self, i, train_set, validation_set, *, input_size=32, use_data_preloading, train_transform):
        """Instantiate a new client

        The parameter `use_data_preloading` indicates whether the client should
        preload all its training and validation data on `device` before starting
        the training loop. This is preferred since it drastically speeds up the
        training process. Beware that, depending on the dataset, the memory usage
        may quickly become unfeasible: 100 clients, each with its own data and
        network, require ~2.5 GB of GPU memory.

        Parameters
        ----------
        i : int
            identifier of the client
        train_set : torch.utils.data.Dataset
            local training data
        validation_set : torch.utils.data.Dataset
            local validation data (loaded as a single batch)
        input_size : int
            unused; kept for backward compatibility
        use_data_preloading : bool
            see above
        train_transform : callable or None
            pytorch transform applied to each preloaded training sample at
            training time (only used when `use_data_preloading` is True)
        """
        self.i = i
        self.net = Net()
        self.net = self.net.to(device)

        # Create the dataloader for the validation data
        self.validation_loader = torch.utils.data.DataLoader(validation_set,
            batch_size=len(validation_set),
            shuffle=False, num_workers=0
        )

        # Create an optimizer for the model's parameters
        self.optimizer = optim.SGD(self.net.parameters(), lr=config["LR"], weight_decay=config["WEIGHT_DECAY"])
        # Summed (not averaged) loss: the callers normalize by the number of samples
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

        self.train_transform = train_transform

        self.use_data_preloading = use_data_preloading
        if self.use_data_preloading:
            # Preloading train (the batch size config is ignored here)
            self.train_loader = torch.utils.data.DataLoader(train_set,
                batch_size=len(train_set), shuffle=True,
                num_workers=0, pin_memory=True
            )

            # Read all the dataset at once...
            training_images, training_labels = next(iter(self.train_loader))
            # ...and transfer the data on the target device
            self.training_images = training_images.to(device)
            self.training_labels = training_labels.to(device)

            # Preload the validation images...
            validation_images, validation_labels = next(iter(self.validation_loader))
            # ...and transfer the data on the target device
            self.validation_images = validation_images.to(device)
            self.validation_labels = validation_labels.to(device)
        else:
            # Preloading was not requested => a dataloader is initialized as usual
            self.train_loader = torch.utils.data.DataLoader(train_set,
                batch_size=config["BATCH_SIZE"], shuffle=True,
                num_workers=0, pin_memory=True
            )

    def clientUpdate(self, parameters):
        """Run the local training of the client (config["E"] epochs)

        Parameters
        ----------
        parameters : OrderedDict[str, torch.Tensor]
            Initial parameters of the client model (the current server state)

        Returns
        -------
        torch.Tensor or float
            average per-sample training loss of the last local epoch, as a
            detached 0-dim tensor (0.0 if no epoch/sample was processed)
        dict[str, torch.Tensor]
            `initial - final` for every entry of `parameters`
        """
        # Restore the parameters of the network from the server model
        self.net.load_state_dict(parameters)
        self.net.train()

        epoch_loss = 0.0
        for _ in range(config["E"]):
            running_loss, n = 0.0, 0
            for images, labels in self.iter_training_data():
                # Set gradients to zero
                self.optimizer.zero_grad()
                outputs = self.net(images)

                # Summed loss of the mini-batch (criterion uses reduction="sum")
                loss = self.criterion(outputs, labels)
                # Detach: accumulating the live tensor would keep the autograd
                # graph of every batch in the epoch alive until the epoch ends
                running_loss += loss.detach()
                n += labels.size(0)

                # Back-propagate the per-sample mean loss
                (loss / labels.size(0)).backward()
                self.optimizer.step()

            epoch_loss = running_loss / n if n > 0 else 0.0

        # Key-based lookup so the delta never depends on dict iteration order
        final_state = self.net.state_dict()
        delta = {key: initial - final_state[key] for key, initial in parameters.items()}
        return epoch_loss, delta

    def iter_training_data(self):
        """Iterate over the training data of the client

        Yields
        -------
        Tuple[torch.Tensor, torch.Tensor]
            (images, labels) mini-batches of size config["BATCH_SIZE"],
            reshuffled at every call
        """
        batch_size = config["BATCH_SIZE"]

        if self.use_data_preloading:
            # Data already preloaded => shuffle and return from cache
            indices = torch.randperm(self.training_images.size(0))
            training_images, training_labels = self.training_images[indices], self.training_labels[indices]

            # Possibly apply the required training transformation, sample by sample
            if self.train_transform is not None:
                training_images = torch.stack([self.train_transform(im) for im in training_images])

            # Yield the training set
            yield from zip(torch.split(training_images, batch_size), torch.split(training_labels, batch_size))
        else:
            # Data not already preloaded => yield from the dataloader
            for images, labels in self.train_loader:
                images = images.to(device)
                labels = labels.to(device)
                yield images, labels

    def iter_validation_data(self):
        """Iterate over the validation data of the client

        Yields
        -------
        Tuple[torch.Tensor, torch.Tensor]
            (images, labels) mini-batches (config["VALIDATION_BATCH_SIZE"] when
            preloaded, otherwise the whole split at once)
        """
        batch_size = config["VALIDATION_BATCH_SIZE"]

        if self.use_data_preloading:
            # Data already preloaded => return from cache
            yield from zip(torch.split(self.validation_images, batch_size), torch.split(self.validation_labels, batch_size))
        else:
            # Data not already preloaded => yield from the dataloader
            for images, labels in self.validation_loader:
                images = images.to(device)
                labels = labels.to(device)
                yield images, labels

    def compute_accuracy(self, parameters):
        """Evaluate `parameters` on the validation split of the client

        Parameters
        ----------
        parameters : OrderedDict[str, torch.Tensor]
            parameters to load into the client model before evaluating

        Returns
        -------
        Tuple[float, float]
            average loss on the validation set, accuracy on the validation set

        Raises
        ------
        ValueError
            if the client has no validation samples
        """
        self.net.load_state_dict(parameters)
        # Set the model in evaluation mode
        self.net.eval()

        running_corrects = 0
        loss, n = 0.0, 0
        # The whole evaluation (forward, loss, predictions) needs no gradients
        with torch.no_grad():
            for data, labels in self.iter_validation_data():
                outputs = self.net(data)
                loss += self.criterion(outputs, labels).item()

                # Count the number of correct predictions
                _, preds = torch.max(outputs, 1)
                running_corrects += (preds == labels).sum().item()
                n += labels.size(0)

        if n == 0:
            raise ValueError(f"client {self.i} has no validation data")
        return loss / n, running_corrects / n

In [None]:
def parse_csv(filename):
    """Load a CIFAR-10 federated split file.

    Every data row has the form ``user_id,image_id,label``; any line whose
    first character is not a digit (such as the header) is ignored.

    Parameters
    ----------
    filename : str
        path of the .csv file containing the splits

    Returns
    -------
    Tuple[DefaultDict[int, List[int]], Dict[int, int]]
        user_id -> list of image ids (in file order), and image_id -> label
    """
    user_to_images = defaultdict(list)
    image_to_label = {}

    with open(filename) as handle:
        for row in handle:
            # Header (and any other non-data line) does not start with a digit
            if not row[0].isdigit():
                continue
            user_id, image_id, label = map(int, row.split(","))
            user_to_images[user_id].append(image_id)
            image_to_label[image_id] = label

    return user_to_images, image_to_label

In [None]:
def printJSON(alpha, acc, net, step=None):
    """Save the model weights (.pth) and a JSON log (.json) under ./artifacts

    The file name encodes alpha, E, K, the optional step, whether augmentation
    is enabled (``_T``) and the normalization layer.

    Parameters
    ----------
    alpha: float
        value of alpha
    acc: list
        list of accuracies at different iterations
    net: torch.nn.Module
        model whose state_dict is saved
    step: int, optional
        current value of iteration (omitted from the file name when None)
    """
    artifacts_dir = "artifacts"
    # Create the folder here so saving never depends on another cell having run
    os.makedirs(artifacts_dir, exist_ok=True)

    artifact_filename = f"ALPHA_{alpha}_E_{config['E']}_K_{config['K']}"
    if step is not None:
        artifact_filename += f"_STEPS_{step}"
    if config["AUGMENTATION_PROB"] > 0:
        artifact_filename += "_T"
    if config["NORM_LAYER"]:
        artifact_filename += f"_{config['NORM_LAYER'].upper()}"

    # Save the parameters of the model on the local file system
    torch.save(net.state_dict(), os.path.join(artifacts_dir, f"{artifact_filename}.pth"))

    # DIRICHELET_ALPHA and AVERAGE_ACCURACY are not logged (the latter is a numpy
    # array, which json cannot serialize); a filtered shallow copy is enough since
    # the dict is only serialized
    logged_config = {
        key: value
        for key, value in config.items()
        if key not in ("DIRICHELET_ALPHA", "AVERAGE_ACCURACY")
    }
    data = {
        "config": logged_config,
        "alpha": alpha,
        "accuracy": acc,
    }

    with open(os.path.join(artifacts_dir, f"{artifact_filename}.json"), "w") as f:
        json.dump(data, f, indent=4)

## Client selection and aggregation

In [None]:
def selectClients(k, population=None):
    """Sample `k` distinct clients uniformly at random (without replacement).

    Parameters
    ----------
    k : int
        number of clients to sample
    population : sequence, optional
        clients to sample from; defaults to the module-level `clients` list
        built by the training loop

    Returns
    -------
    list
        the `k` selected clients

    Raises
    ------
    ValueError
        if `k` is negative or larger than the population (from `random.sample`)
    """
    if population is None:
        population = clients
    return random.sample(population, k=k)

def aggregateClient(deltaThetas, weights=None):
    """Aggregate the client updates into a single server update.

    Parameters
    ----------
    deltaThetas : sequence of dict[str, torch.Tensor]
        per-client updates (initial minus final parameters, as returned by
        `Client.clientUpdate`), all sharing the same keys
    weights : sequence of float, optional
        relative weight of each client (e.g. its number of training samples for
        size-weighted FedAvg); normalized to sum to 1. Defaults to uniform
        weights 1/len(deltaThetas).

    Returns
    -------
    dict[str, torch.Tensor]
        weighted average of the updates, keyed like the first update

    Raises
    ------
    ValueError
        if no update is given, or `weights` has the wrong length or a
        non-positive sum
    """
    deltaThetas = list(deltaThetas)
    if not deltaThetas:
        raise ValueError("aggregateClient needs at least one client update")

    if weights is None:
        # Uniform average over the selected clients (the former
        # len(trainsets[i]) / (len(trainsets[i]) * K) ratio always reduced to 1/K)
        ratios = [1 / len(deltaThetas)] * len(deltaThetas)
    else:
        weights = list(weights)
        if len(weights) != len(deltaThetas):
            raise ValueError(f"got {len(weights)} weights for {len(deltaThetas)} client updates")
        total = sum(weights)
        if total <= 0:
            raise ValueError("client weights must have a positive sum")
        ratios = [w / total for w in weights]

    aggregated = {}
    for ratio, delta in zip(ratios, deltaThetas):
        for key, value in delta.items():
            if key in aggregated:
                aggregated[key] += ratio * value
            else:
                # `ratio * value` is a fresh tensor, so the clients' updates are
                # never modified by the in-place accumulation above
                aggregated[key] = ratio * value
    return aggregated

## Training

In [None]:
# Create the folder where printJSON stores the models and JSON logs
# (no-op when it already exists; no exists-then-create race)
os.makedirs("artifacts", exist_ok=True)

In [None]:
# Verify the labels specified in the .csv files are coherent with the actual CIFAR-10 labels
# see https://github.com/google-research/google-research/issues/924
# An explicit exception (instead of `assert`) keeps the check active under
# `python -O`, and reports which images disagree (or are missing from the file).
_, labels_mapping = parse_csv(f"cifar10/federated_train_alpha_{0.0:.2f}.csv")
mismatched_images = [
    idx for idx, label in enumerate(trainset.targets) if labels_mapping.get(idx) != label
]
if mismatched_images:
    raise RuntimeError(
        f"{len(mismatched_images)} CIFAR-10 labels disagree with the split file "
        f"(first image ids: {mismatched_images[:10]})"
    )

In [None]:
for alpha_i, alpha in enumerate(config["DIRICHELET_ALPHA"]):
    # Server model: holds the global parameters broadcast to the clients
    net = Net()
    net = net.to(device)

    # Prepare the training set (one split per client).
    # NOTE: `clients` and `trainsets` are module-level names that helper
    # functions may read as globals - keep these names.
    if config["DATA_DISTRIBUTION"] == "iid":
        # Truncate to a multiple of NUMBER_OF_CLIENTS and split uniformly at random.
        # A local name avoids re-wrapping `trainset` in a new Subset on every alpha.
        trainset_len = (len(trainset) // config["NUMBER_OF_CLIENTS"]) * config["NUMBER_OF_CLIENTS"]
        iid_trainset = torch.utils.data.Subset(trainset, list(range(trainset_len)))

        lengths = len(iid_trainset) // config["NUMBER_OF_CLIENTS"] * np.ones(config["NUMBER_OF_CLIENTS"], dtype=int)
        trainsets = torch.utils.data.random_split(dataset=iid_trainset, lengths=lengths)
    else:
        dirichelet_splits, _ = parse_csv(f"cifar10/federated_train_alpha_{alpha:.2f}.csv")
        trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]

    # Prepare the validation set: equal random shares of the (truncated) test set.
    # Again a local name, so `testset` is not wrapped once more on every alpha.
    testset_len = (len(testset) // config["NUMBER_OF_CLIENTS"]) * config["NUMBER_OF_CLIENTS"]
    validation_base = torch.utils.data.Subset(testset, list(range(testset_len)))

    lengths = len(validation_base) // config["NUMBER_OF_CLIENTS"] * np.ones(config["NUMBER_OF_CLIENTS"], dtype=int)
    testsets = torch.utils.data.random_split(dataset=validation_base, lengths=lengths)

    # Instantiate the clients
    clients = [
        Client(c, trainsets[c], testsets[c], use_data_preloading=use_data_preloading, train_transform=online_train_transform)
        for c in range(config["NUMBER_OF_CLIENTS"])
    ]

    # FedAvgM server momentum buffer (only used when FED_AVG_M is enabled)
    server_momentum = {}

    # Average client accuracy (in %) at every logging round
    accuracies = list()

    # Best server parameters seen so far (by average client accuracy)
    best_model = {}
    best_accuracy = 0.0

    for step in tqdm(range(config["MAX_TIME"])):
        selected_clients = selectClients(config["K"])

        # Collect the updates (initial - final parameters) from the clients
        deltaThetas = list()
        losses = list()
        for c in selected_clients:
            loss, client_delta = c.clientUpdate(net.state_dict())
            deltaThetas.append(client_delta)
            losses.append(loss)

        g = aggregateClient(deltaThetas)

        # Server update: w <- w - g (FedAvg) or, with FedAvgM,
        # v <- beta * v + g ; w <- w - gamma * v   (v = g on the first round)
        parameters = {}
        for key, current in net.state_dict().items():
            update = g[key]
            if config["FED_AVG_M"]:
                if key in server_momentum:
                    update = config["FED_AVG_M_BETA"] * server_momentum[key] + update
                server_momentum[key] = update
                update = config["FED_AVG_M_GAMMA"] * update
            parameters[key] = current - update

        # Apply the new global parameters *before* logging/saving, so the
        # periodic checkpoint written by printJSON matches the accuracy recorded
        net.load_state_dict(parameters)

        # Compute the average accuracy
        if step % config["LOG_FREQUENCY"] == 0:
            # Evaluate the current server parameters against the validation sets
            client_losses, client_accuracies = zip(*(client.compute_accuracy(parameters) for client in clients))
            # Average accuracy across the clients
            avg_client_accuracy = mean(client_accuracies)
            accuracies.append(avg_client_accuracy * 100)

            # Periodically save the computations done so far
            if step % config["SAVE_FREQUENCY"] == 0:
                printJSON(alpha, accuracies, net, step)

            # Keep the parameters with the best accuracy
            if avg_client_accuracy >= best_accuracy:
                best_accuracy = avg_client_accuracy
                best_model = deepcopy(parameters)

    # Report the average accuracy of the *best* parameters, then save the final
    # (last-round) server model together with the accuracy history
    avg_accuracy = mean(float(client.compute_accuracy(best_model)[1]) for client in clients)

    config["AVERAGE_ACCURACY"][alpha_i] = avg_accuracy
    print(f"Average accuracy with alpha = {alpha} after {step+1} rounds is {avg_accuracy*100:.2f}")
    printJSON(alpha, accuracies, net)


## Artifacts

In [None]:
import shutil

# Bundle the whole "artifacts" folder (saved models + JSON logs) into artifacts.zip
shutil.make_archive(base_name="artifacts", format="zip", root_dir="artifacts")