# Baseline implementation

In [None]:
# download the Cifar10 non-iid splits, if not present

from os import path
import urllib.request
import zipfile

# Extraction goes into a scratch directory that is renamed to "cifar10" only
# after it completes, so an interrupted run is not later mistaken for a
# finished one (the existence check below would otherwise skip the download).
if not path.exists("cifar10"):
    import os
    import shutil
    import tempfile

    save_path = "cifar10.zip"
    # HTTPS so the archive cannot be tampered with in transit.
    urllib.request.urlretrieve(
        "https://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10_v1.1.zip",
        save_path,
    )

    scratch_dir = tempfile.mkdtemp(prefix="cifar10.", dir=".")
    try:
        with zipfile.ZipFile(save_path, "r") as zip_ref:
            zip_ref.extractall(scratch_dir)
        os.rename(scratch_dir, "cifar10")
    finally:
        # No-op after a successful rename; removes partial output otherwise.
        shutil.rmtree(scratch_dir, ignore_errors=True)

In [None]:
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Target device for every model and preloaded tensor: the GPU when PyTorch
# can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Dirichlet concentration parameters to sweep; the training loop runs once per value.
_dirichlet_alphas = [0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0]

config = {
    "E": 1,  # number of local epochs per round
    "K": 5,  # number of clients selected each round # [5, 10, 20]
    "NUMBER_OF_CLIENTS": 100,  # total number of clients
    "MAX_TIME": 5000,  # number of federated rounds
    "BATCH_SIZE": 50,  # local training batch size
    "VALIDATION_BATCH_SIZE": 500,  # batch size for (preloaded) validation data
    "LR": 0.01,  # client learning rate
    "WEIGHT_DECAY": 4e-4,  # client SGD weight decay
    "DATA_DISTRIBUTION": "non-iid",  # "iid" | "non-iid"
    "DIRICHELET_ALPHA": _dirichlet_alphas,  # key spelling kept: it is read elsewhere
    # One slot per alpha, so it stays in sync if the sweep above is edited.
    "AVERAGE_ACCURACY": np.zeros(len(_dirichlet_alphas)),
    "FED_AVG_M": False,  # enable server-side momentum (FedAvgM)
    "FED_AVG_M_BETA": 0.9,  # FedAvgM momentum coefficient
    "FED_AVG_M_GAMMA": 1,  # FedAvgM server learning rate
    "LR_DECAY": 0.99,  # NOTE(review): not read by the training loop as written
    "LOG_FREQUENCY": 25,  # rounds between accuracy evaluations
    "AUGMENTATION_PROB": 0.0,  # probability of applying the random augmentation
    "SAVE_FREQUENCY": 100,  # rounds between checkpoint/JSON dumps
    "NORM_LAYER": "",  # "bn" (BatchNorm) | "gn" (GroupNorm) | "" (no normalization)
}

# Load every client's data onto `device` once instead of iterating DataLoaders.
use_data_preloading = True

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """Small CIFAR-10 CNN: two 5x5 conv layers followed by three linear layers.

    Adapted from
    https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html

    Args:
        norm_layer: ``"bn"`` for BatchNorm2d, ``"gn"`` for GroupNorm with 4
            groups; any other value means no normalization. Defaults to
            ``config["NORM_LAYER"]``, read once at construction time.
    """

    def __init__(self, norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = config["NORM_LAYER"]

        # Registration order (conv1, norm1, conv2, norm2, fc*) is kept so the
        # state_dict key order matches previously saved checkpoints.
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.norm1 = self._make_norm(norm_layer, 64)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.norm2 = self._make_norm(norm_layer, 64)

        # 32x32 input -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5
        self.fc1 = nn.Linear(64 * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    @staticmethod
    def _make_norm(kind, channels):
        """Return the normalization module for ``kind`` (Identity when none)."""
        if kind == "bn":
            return nn.BatchNorm2d(channels)
        if kind == "gn":
            return nn.GroupNorm(4, channels)
        # Identity has no parameters/buffers, so the state_dict is unchanged.
        return nn.Identity()

    def forward(self, x):
        # The layer structure is fixed in __init__; forward no longer consults
        # the global config, so it cannot disagree with the built modules.
        x = F.max_pool2d(F.relu(self.norm1(self.conv1(x))), 2)
        x = F.max_pool2d(F.relu(self.norm2(self.conv2(x))), 2)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
        
#print(net)

In [None]:
import torch.optim as optim

class Client:
    """One federated-learning participant with a private model and data.

    Each client owns its own ``Net`` (moved to ``device``) and an SGD
    optimizer configured from ``config["LR"]`` and ``config["WEIGHT_DECAY"]``.

    Args:
        i: client identifier, stored as ``self.i``.
        train_set: dataset used for local training.
        validation_set: dataset evaluated by ``compute_accuracy``.
        input_size: not used by this class; accepted for call compatibility.
        use_data_preloading: when True, the whole train and validation sets
            are loaded once and kept on ``device``; when False, the
            DataLoaders are iterated on demand and every batch is moved to
            ``device`` as it is produced.
        train_transform: optional per-image callable applied to the preloaded
            training images at the start of every local epoch (ignored when
            ``use_data_preloading`` is False).
    """

    def __init__(self, i, train_set, validation_set, *, input_size=32, use_data_preloading, train_transform):
        self.i = i
        self.net = Net()
        self.net = self.net.to(device)

        # A single batch spanning the whole validation set.
        self.validation_loader = torch.utils.data.DataLoader(validation_set,
            batch_size=len(validation_set),
            shuffle=False, num_workers=0
        )

        self.optimizer = optim.SGD(self.net.parameters(), lr=config["LR"], weight_decay=config["WEIGHT_DECAY"])
        # Summed (not averaged) so callers can divide by the true sample count.
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

        self.train_transform = train_transform

        self.use_data_preloading = use_data_preloading
        if self.use_data_preloading:
            # One batch containing the entire (shuffled) training set, so the
            # data can be pulled onto `device` in a single step.
            self.train_loader = torch.utils.data.DataLoader(train_set,
                batch_size=len(train_set), shuffle=True,
                num_workers=0, pin_memory=True
            )

            training_images, training_labels = next(iter(self.train_loader))
            self.training_images = training_images.to(device)
            self.training_labels = training_labels.to(device)

            validation_images, validation_labels = next(iter(self.validation_loader))
            self.validation_images = validation_images.to(device)
            self.validation_labels = validation_labels.to(device)
        else:
            self.train_loader = torch.utils.data.DataLoader(train_set,
                batch_size=config["BATCH_SIZE"], shuffle=True,
                num_workers=0, pin_memory=True
            )

    def clientUpdate(self, lr, parameters):
        """Train locally starting from ``parameters``.

        Args:
            lr: learning rate for this round (overrides the optimizer's lr).
            parameters: server state dict to start from.

        Returns:
            ``(epoch_loss, delta)``: ``epoch_loss`` is the mean per-sample
            training loss of the last local epoch (a detached tensor, or
            ``0.0`` when no epoch ran or no samples were seen); ``delta`` maps
            every state-dict key to ``parameters[key] - local_state[key]``.
        """
        self.net.load_state_dict(parameters)
        self.net.train()

        for group in self.optimizer.param_groups:
            group["lr"] = lr

        epoch_loss = 0.0  # defined even when config["E"] == 0
        for _ in range(config["E"]):
            loss_sum, n = 0.0, 0
            for images, labels in self.iter_training_data():
                self.optimizer.zero_grad()
                outputs = self.net(images)
                loss = self.criterion(outputs, labels)  # summed over the batch
                # Detach before accumulating: adding the live tensor would keep
                # every batch's autograd graph alive until the epoch ends.
                loss_sum += loss.detach()
                n += labels.size(0)

                (loss / labels.size(0)).backward()
                self.optimizer.step()
            epoch_loss = loss_sum / n if n > 0 else 0.0

        # Match by key rather than by position in the two state dicts.
        local_state = self.net.state_dict()
        delta = {k: v - local_state[k] for k, v in parameters.items()}
        return epoch_loss, delta

    def iter_training_data(self):
        """Yield ``(images, labels)`` training batches located on ``device``.

        In preloading mode the cached tensors are reshuffled, optionally
        transformed image by image, and split into ``config["BATCH_SIZE"]``
        chunks; otherwise batches come from ``self.train_loader``.
        """
        batch_size = config["BATCH_SIZE"]

        if self.use_data_preloading:
            indices = torch.randperm(self.training_images.size(0))
            training_images = self.training_images[indices]
            training_labels = self.training_labels[indices]

            if self.train_transform is not None:
                training_images = torch.stack([self.train_transform(im) for im in training_images])

            yield from zip(torch.split(training_images, batch_size), torch.split(training_labels, batch_size))
        else:
            # The loader yields CPU tensors while the model lives on `device`.
            for images, labels in self.train_loader:
                yield images.to(device), labels.to(device)

    def iter_validation_data(self):
        """Yield ``(images, labels)`` validation batches located on ``device``."""
        batch_size = config["VALIDATION_BATCH_SIZE"]

        if self.use_data_preloading:
            yield from zip(torch.split(self.validation_images, batch_size), torch.split(self.validation_labels, batch_size))
        else:
            for images, labels in self.validation_loader:
                yield images.to(device), labels.to(device)

    def compute_accuracy(self, parameters):
        """Evaluate ``parameters`` on this client's validation data.

        Returns:
            ``(mean_loss, accuracy)``: the summed cross-entropy divided by the
            number of validation samples, and the fraction predicted correctly.
        """
        self.net.load_state_dict(parameters)
        self.net.eval()

        running_corrects = 0
        loss, n = 0, 0
        with torch.no_grad():
            for data, labels in self.iter_validation_data():
                outputs = self.net(data)
                loss += self.criterion(outputs, labels).item()

                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels).item()
                n += len(preds)

        return loss / n, running_corrects / n

In [None]:
from collections import defaultdict

def parse_csv(filename):
    """Read a federated-split CSV file.

    Every line whose first character is a digit is parsed as three integers
    ``user_id,image_id,label``; all other lines (such as a header) are skipped.

    Returns:
        ``(splits, labels_mapping)``: ``splits`` is a defaultdict mapping each
        user id to its image ids in file order, and ``labels_mapping`` maps
        each image id to its label.
    """
    splits = defaultdict(list)
    labels_mapping = {}

    with open(filename) as csv_file:
        for row in csv_file:
            if row[0].isdigit():
                user_id, image_id, label = map(int, row.split(","))
                splits[user_id].append(image_id)
                labels_mapping[image_id] = label

    return splits, labels_mapping

In [None]:
import time
import json
import numpy
from copy import deepcopy

def listToString(l):
    """Return the items of ``l`` joined by single spaces.

    Lists, tuples and numpy arrays are joined element by element; any other
    value (e.g. the single Dirichlet alpha passed by ``printJSON``) is simply
    converted with ``str``. The previous ``" ".join(str(l))`` spaced out the
    *characters* of the repr instead ("[ 0 . 0 , ...").
    """
    if isinstance(l, (list, tuple, numpy.ndarray)):
        return " ".join(str(item) for item in l)
    return str(l)


def printJSON(alpha, acc, net, step = None):
    """Save ``net``'s weights and a JSON run summary under ``artifacts/``.

    Writes ``<name>.pth`` (the model state dict) and ``<name>.json`` (a copy of
    ``config`` with its list/array entries stringified, the ``alpha`` and the
    ``acc`` history), where ``<name>`` encodes alpha, E, K, the optional
    ``step``, whether augmentation is on, and the norm layer.
    """
    import os

    artifacts_dir = "artifacts"
    os.makedirs(artifacts_dir, exist_ok=True)

    artifact_filename = f"ALPHA_{alpha}_E_{config['E']}_K_{config['K']}"
    if step is not None:
        artifact_filename += f"_STEPS_{step}"
    if config["AUGMENTATION_PROB"] > 0:
        artifact_filename += "_T"
    if config["NORM_LAYER"]:
        artifact_filename += f"_{config['NORM_LAYER'].upper()}"

    # parameters of the trained model, saved on the local file system
    torch.save(net.state_dict(), f"{artifacts_dir}/{artifact_filename}.pth")

    # numpy arrays/lists are stringified so the config is JSON-serializable
    config_copy = deepcopy(config)
    config_copy["DIRICHELET_ALPHA"] = listToString(config_copy["DIRICHELET_ALPHA"])
    config_copy["AVERAGE_ACCURACY"] = numpy.array2string(config_copy["AVERAGE_ACCURACY"])
    data = {
        "config": config_copy,
        "alpha": listToString(alpha),
        "accuracy": acc,
    }

    with open(f"{artifacts_dir}/{artifact_filename}.json", "w") as f:
        json.dump(data, f, indent=4)


In [None]:
def selectClients(k, population=None):
  """Return ``k`` distinct clients sampled uniformly without replacement.

  ``population`` defaults to the module-level ``clients`` list.
  """
  if population is None:
    population = clients
  return random.sample(population, k=k)

def aggregateClient(deltaThetas, weights=None):
  """Combine per-client updates into one update (weighted average).

  Args:
    deltaThetas: one dict per participating client, all with the same keys,
      mapping state-dict names to update tensors.
    weights: optional per-client weights (e.g. local sample counts, as in
      FedAvg); normalized to sum to 1. Defaults to uniform weights
      ``1 / len(deltaThetas)`` -- the previous ``len(trainsets[i]) /
      (len(trainsets[i]) * K)`` always reduced to ``1 / K`` but indexed the
      global ``trainsets`` by position in the *selected* list and crashed on
      an empty split.

  Returns:
    A new dict with the same keys holding the weighted sum of the updates.

  Raises:
    ValueError: if ``deltaThetas`` is empty, ``weights`` has a different
      length, or the weights do not sum to a positive value.
  """
  deltaThetas = list(deltaThetas)
  if not deltaThetas:
    raise ValueError("aggregateClient() needs at least one client update")

  if weights is None:
    ratios = [1.0 / len(deltaThetas)] * len(deltaThetas)
  else:
    weights = [float(w) for w in weights]
    if len(weights) != len(deltaThetas):
      raise ValueError("weights must have one entry per client update")
    total = sum(weights)
    if total <= 0:
      raise ValueError("weights must sum to a positive value")
    ratios = [w / total for w in weights]

  # The first product allocates fresh tensors, so the in-place += below never
  # modifies the callers' update tensors.
  aggregated = {k: ratios[0] * v for k, v in deltaThetas[0].items()}
  for ratio, delta in zip(ratios[1:], deltaThetas[1:]):
    for k, v in delta.items():
      aggregated[k] += ratio * v

  return aggregated

In [None]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
from statistics import mean

from tqdm.notebook import tqdm

import os

random.seed(42)

# Random transformations that provide data augmentation. The flip uses p=1:
# the whole pipeline is gated by RandomApply(AUGMENTATION_PROB) below.
random_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(1),
        transforms.ColorJitter(0.9, 0.9),
    ]
)
# Normalization values for the CIFAR10 dataset
normalization_transform = transforms.Normalize(
    mean=[0.491, 0.482, 0.447],
    std=[0.247, 0.243, 0.262]
)

# Transformation strategy:
#  1. apply as many transformations as possible offline (at trainset level);
#  2. apply the remaining ones online (at client level, before iterating
#     over the data).
# With preloading, each client runs the dataset-level transform only once,
# so random transformations must then be applied online.

if use_data_preloading:
    if config["AUGMENTATION_PROB"] > 1e-5:
        # non-zero augmentation probability => apply transformations online
        offline_train_transform = None
        online_train_transform = transforms.Compose([
            transforms.RandomApply([random_transform], config["AUGMENTATION_PROB"]),
            normalization_transform
        ])
    else:
        # zero augmentation probability => everything can be applied offline
        offline_train_transform = normalization_transform
        online_train_transform = None
else:
    # Without preloading the DataLoader re-applies the dataset transform to
    # every sample on every epoch, so random augmentation can stay offline.
    offline_train_transform = transforms.Compose([
        transforms.RandomApply([random_transform], config["AUGMENTATION_PROB"]),
        normalization_transform
    ])
    online_train_transform = None

# ToTensor always runs at dataset level. The offline transform may be None
# (online-augmentation case) and must not be passed to Compose, which would
# try to call it and raise TypeError when the data is loaded.
train_steps = [transforms.ToTensor()]
if offline_train_transform is not None:
    train_steps.append(offline_train_transform)

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose(train_steps)
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalization_transform])
)

os.makedirs("artifacts", exist_ok=True)

In [None]:
# Verify that the labels in the .csv splits agree with the actual CIFAR-10
# labels (see https://github.com/google-research/google-research/issues/924).
# An explicit raise is used instead of `assert`, which `python -O` strips;
# image ids missing from the CSV count as mismatches instead of a bare KeyError.
_, labels_mapping = parse_csv(f"cifar10/federated_train_alpha_{0.0:.2f}.csv")
mismatched_ids = [
    idx for idx, label in enumerate(trainset.targets)
    if labels_mapping.get(idx) != label
]
if mismatched_ids:
    raise ValueError(
        f"{len(mismatched_ids)} CIFAR-10 labels disagree with the federated split "
        f"(first mismatching image id: {mismatched_ids[0]})"
    )

In [None]:
for alpha_i, alpha in enumerate(config["DIRICHELET_ALPHA"]):
  # Fresh server model for every alpha. The server optimizer's lr is what is
  # handed to the clients each round; the optimizer itself is never stepped.
  net = Net()
  net = net.to(device)

  optimizer = optim.SGD(net.parameters(), lr=config["LR"], momentum=0.9, weight_decay=1e-3)
  # NOTE: scheduler.step() is currently not called anywhere in this loop.
  scheduler = ReduceLROnPlateau(optimizer, "min", factor=0.5, min_lr=1e-6, verbose=True)

  # Local pools are split instead of rebinding the global `trainset` /
  # `testset`, which would otherwise gain one more Subset layer per alpha.
  if config["DATA_DISTRIBUTION"] == "iid":
    # drop the remainder so every client gets the same number of images
    trainset_len = (len(trainset) // config["NUMBER_OF_CLIENTS"]) * config["NUMBER_OF_CLIENTS"]
    iid_pool = torch.utils.data.Subset(trainset, list(range(trainset_len)))

    lengths = len(iid_pool) // config["NUMBER_OF_CLIENTS"] * np.ones(config["NUMBER_OF_CLIENTS"], dtype=int)
    trainsets = torch.utils.data.random_split(dataset=iid_pool, lengths=lengths)
  else:
    dirichelet_splits, _ = parse_csv(f"cifar10/federated_train_alpha_{alpha:.2f}.csv")
    trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]

  # split the validation set evenly among the clients
  testset_len = (len(testset) // config["NUMBER_OF_CLIENTS"]) * config["NUMBER_OF_CLIENTS"]
  test_pool = torch.utils.data.Subset(testset, list(range(testset_len)))

  lengths = len(test_pool) // config["NUMBER_OF_CLIENTS"] * np.ones(config["NUMBER_OF_CLIENTS"], dtype=int)
  testsets = torch.utils.data.random_split(dataset=test_pool, lengths=lengths)

  clients = [
    Client(c, trainsets[c], testsets[c], use_data_preloading=use_data_preloading, train_transform=online_train_transform)
    for c in range(config["NUMBER_OF_CLIENTS"])
  ]

  # FedAvgM momentum buffer, keyed like the server state dict
  old_parameters = {}

  # average client accuracy (in %) at every logging step
  accuracies = list()

  # best model seen at a logging step
  best_model = {}
  best_accuracy = 0.0

  for step in tqdm(range(config["MAX_TIME"])):
    selected_clients = selectClients(config["K"])

    deltaThetas = list()
    losses = list()
    for c in selected_clients:
      loss, parameters = c.clientUpdate(optimizer.param_groups[0]['lr'], net.state_dict())
      deltaThetas.append(parameters)
      losses.append(loss)

    # average of the clients' (server - local) weight differences
    g = aggregateClient(deltaThetas)

    parameters = {}
    for k, v in net.state_dict().items():
      update = g[k]
      if config["FED_AVG_M"]:
        # m_t = beta * m_{t-1} + update (m_1 = update); w_t = w_{t-1} - gamma * m_t
        if k in old_parameters:
          momentum = config["FED_AVG_M_BETA"] * old_parameters[k] + update
        else:
          momentum = update
        old_parameters[k] = momentum
        parameters[k] = v - config["FED_AVG_M_GAMMA"] * momentum
      else:
        parameters[k] = v - update  # todo: add server learning rate gamma

    # compute the average client accuracy of the new server parameters
    if step % config["LOG_FREQUENCY"] == 0:
      client_accuracies = [client.compute_accuracy(parameters)[1] for client in clients]

      avg_client_accuracy = mean(client_accuracies)
      accuracies.append(avg_client_accuracy * 100)

      if avg_client_accuracy >= best_accuracy:
        best_accuracy = avg_client_accuracy
        best_model = deepcopy(parameters)

      print(f"Average accuracy after {step} rounds is {avg_client_accuracy*100:.2f}")

    net.load_state_dict(parameters)

    if step % config["SAVE_FREQUENCY"] == 0:
      printJSON(alpha, accuracies, net, step)

  # Fall back to the final server model when no evaluation ran (MAX_TIME == 0).
  if not best_model:
    best_model = deepcopy(net.state_dict())

  avg_accuracy = mean(float(client.compute_accuracy(best_model)[1]) for client in clients)

  config["AVERAGE_ACCURACY"][alpha_i] = avg_accuracy
  print(f"Average accuracy with alpha = {alpha} after {config['MAX_TIME']} rounds is {avg_accuracy*100:.2f}")
  printJSON(alpha, accuracies, net)


In [None]:
import shutil
# Bundle everything under artifacts/ into artifacts.zip for download.
shutil.make_archive(base_name='artifacts', format='zip', root_dir='artifacts')