**Packages, Libraries and others**

In [None]:
# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Torchvision
import torchvision
from torchvision import transforms as T
from torchvision import datasets

# Import Numpy
import numpy as np

# Import Pyplot
import matplotlib.pyplot as plt

# Import train_test_split
from sklearn.model_selection import train_test_split

# Standard-library utilities
import copy
import random
from collections import OrderedDict
from math import pi

**Client and Servers implementations**

In [None]:
# CLIENT
class Client:
  """Federated-learning client: a local dataset plus a local copy of the model.

  Attributes:
    client_id: identifier used by the servers as dictionary key.
    dataset: the local dataset (must support ``len`` and indexing).
    model: the local ``nn.Module``; the server overwrites its parameters
      through ``load_state_dict`` and reads them back via ``generate_update``.
    criterion: cross-entropy loss (default ``reduction='mean'``).
    dataloader: shuffled ``DataLoader`` over ``dataset``.
    num_samples: number of samples in ``dataset``.
  """

  def __init__(self, client_id, dataset, batch_size, model):
    self.client_id = client_id
    self.dataset = dataset
    self.model = model
    self.criterion = nn.CrossEntropyLoss()
    self.dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True)
    # Real dataset size. ``len(dataloader) * batch_size`` over-counts whenever
    # the last batch is only partially filled.
    self.num_samples = len(dataset)

  def run_epoch(self, optimizer):
    """Run one optimisation pass over the local data.

    :param optimizer: optimizer bound to ``self.model`` parameters
    :return: sum of the per-batch mean losses divided by ``self.num_samples``
    """
    cumulative_loss = 0
    for images, labels in self.dataloader:
      optimizer.zero_grad()
      # Call the module (not ``.forward``) so registered hooks are honoured.
      labels_hat = self.model(images)
      loss = self.criterion(labels_hat, labels)
      loss.backward()
      optimizer.step()
      cumulative_loss += loss.item()
    return cumulative_loss / self.num_samples

  def train(self, num_epochs):
    """Train the local model for ``num_epochs`` epochs with a fresh Adam optimizer.

    :param num_epochs: number of local epochs
    :return: 1-D tensor of length ``num_epochs`` with the value returned by
      ``run_epoch`` for each epoch
    """
    optimizer = torch.optim.Adam(self.model.parameters(),
                                 lr = 1e-3, weight_decay = 1e-4)
    self.model.train()
    loss_track = torch.empty(num_epochs)
    for epoch in range(num_epochs):
      loss_track[epoch] = self.run_epoch(optimizer)
    return loss_track

  def test(self):
    """Evaluate the local model on the client's own dataloader.

    :return: tuple ``(loss, accuracy)`` where ``loss`` is the sum of the
      per-batch mean losses divided by the number of samples seen and
      ``accuracy`` is the fraction of correctly classified samples
    """
    cumulative_loss = 0
    cumulative_accuracy = 0
    num_samples = 0
    self.model.eval()
    with torch.no_grad():
      for images, labels in self.dataloader:
        labels_hat = self.model(images)
        labels_pred = labels_hat.argmax(dim = 1)
        cumulative_loss += self.criterion(labels_hat, labels).item()
        cumulative_accuracy += labels_pred.eq(labels).sum().item()
        num_samples += images.shape[0]
    return cumulative_loss / num_samples, cumulative_accuracy / num_samples

  def generate_update(self):
    """Return a deep copy of the local model ``state_dict``."""
    return copy.deepcopy(self.model.state_dict())

In [None]:
# SERVER AVG
class Server_fedavg:
  """Federated server that aggregates client models with a plain average.

  At each round a subset of training clients is sampled uniformly without
  replacement, trained locally, and their ``state_dict`` tensors are averaged
  with equal weight. The averaged parameters are then loaded on the server
  model and on every client.
  """

  def __init__(self, train_clients, test_clients, model,
               num_rounds, epochs_per_round, num_clients_per_round, random_state = 1):
    self.train_clients = train_clients
    self.test_clients = test_clients
    self.model = model
    self.model_params_dict = self.model.state_dict()
    self.num_rounds = num_rounds
    self.epochs_per_round = epochs_per_round
    self.num_clients_per_round = num_clients_per_round
    # Uniform per-client weights; not used by the averaging in this class.
    self.weights = {client.client_id: 1 / len(self.train_clients) \
                    for client in self.train_clients}
    # (client_id, state_dict) pairs collected during the current round.
    self.updates = []
    self.prng = np.random.default_rng(random_state)

  def load_model_on_clients(self):
    """Copy the server parameters into every train and test client model."""
    for client in self.train_clients + self.test_clients:
      client.model.load_state_dict(self.model_params_dict, strict = False)

  def select_clients(self):
    """Sample ``num_clients_per_round`` training clients without replacement."""
    return self.prng.choice(
        self.train_clients,
        size = self.num_clients_per_round,
        replace = False,
        )

  def train_round(self, train_clients):
    """Train the given clients locally and store their updates.

    :param train_clients: iterable of clients to train in this round
    :return: dict client_id -> tensor of per-epoch losses
    """
    client_loss = dict()
    for client in train_clients:
      client_loss[client.client_id] = client.train(self.epochs_per_round)
      self.updates.append((client.client_id, client.generate_update()))
    return client_loss

  def update_model(self):
    """Average the collected updates and broadcast the result.

    Every tensor is cast to a CPU float tensor before averaging. The average
    is taken over the updates actually collected, which equals
    ``num_clients_per_round`` in the normal training flow.
    """
    num_updates = len(self.updates)
    base = OrderedDict()
    for _, client_model in self.updates:
      for key, value in client_model.items():
        contribution = value.type(torch.FloatTensor) / num_updates
        if key in base:
          base[key] += contribution
        else:
          base[key] = contribution
    # Write the averaged tensors back once, after all clients were accumulated.
    for key, value in base.items():
      self.model_params_dict[key] = value
    self.model.load_state_dict(self.model_params_dict, strict = False)
    self.load_model_on_clients()

  def _run_round(self):
    """Run one complete round: select, train, aggregate, reset the updates."""
    selected = self.select_clients()
    client_loss = self.train_round(selected)
    self.update_model()
    self.updates = []
    return client_loss

  def train(self):
    """Run ``num_rounds`` rounds.

    :return: list with one dict (client_id -> per-epoch losses) per round
    """
    return [self._run_round() for _ in range(self.num_rounds)]

  def train_evaluation(self):
    """Run ``num_rounds`` rounds, evaluating after each one.

    :return: dict with keys ``'Train'`` and ``'Test'``; each maps to a list
      holding, per round, the mean accuracy over all train / test clients
    """
    results = {
        'Train': [],
        'Test': []
    }
    for round_idx in range(self.num_rounds):
      print(f'Round {round_idx + 1}')
      self._run_round()
      # Mean accuracy over all the training clients
      results['Train'].append(
          self._mean_accuracy(self.eval_train(), len(self.train_clients)))
      # Mean accuracy over all the test clients
      results['Test'].append(
          self._mean_accuracy(self.test(), len(self.test_clients)))
    return results

  @staticmethod
  def _mean_accuracy(stats, num_clients):
    """Average the ``'Accuracy'`` entries of ``stats`` over ``num_clients``."""
    return sum(res['Accuracy'] / num_clients for res in stats.values())

  def _evaluate(self, clients):
    """Load the server model on the clients and evaluate the given ones.

    :return: dict client_id -> {"Loss": ..., "Accuracy": ...}
    """
    self.load_model_on_clients()
    eval_statistics = {}
    for c in clients:
      loss, accuracy = c.test()
      eval_statistics[c.client_id] = {"Loss": loss, "Accuracy": accuracy}
    return eval_statistics

  def eval_train(self):
    """Evaluate the current server model on every training client."""
    return self._evaluate(self.train_clients)

  def test(self):
    """Evaluate the current server model on every test client."""
    return self._evaluate(self.test_clients)

In [None]:
# SERVER GW
class Server:
    """Federated server with loss-based, Gaussian-reward client weighting.

    Each client carries a weight. After every round the weights of the
    selected clients are replaced by rewards computed from how far each
    client's per-epoch loss lies from the weighted mean loss. The weights are
    used both to aggregate the client models and, once ``use_prior`` is
    active, as sampling probabilities for client selection.
    """

    def __init__(
        self,
        num_clients_per_round,
        num_rounds,
        epochs_per_round,
        train_clients: "list[Client]",
        test_clients: "list[Client]",
        model: torch.nn.Module,
        use_prior,
        n_rounds_no_prior,
        random_state=300890,
    ):
        """
        :param num_clients_per_round: clients sampled at each round
        :param num_rounds: number of federated rounds
        :param epochs_per_round: local epochs run by each selected client
        :param train_clients: clients used for training
        :param test_clients: clients used for testing
        :param model: the centralized model
        :param use_prior: if true, clients are sampled with probability equal
            to their current weight; otherwise uniformly
        :param n_rounds_no_prior: fraction of ``num_rounds`` after which
            ``train`` switches ``use_prior`` on
        :param random_state: seed of the client-sampling generator
        """
        self.clients_per_round = num_clients_per_round
        self.num_rounds = num_rounds
        self.train_clients = train_clients
        self.test_clients = test_clients
        self.model = model
        self.epochs_per_round = epochs_per_round
        self.use_prior = use_prior
        self.n_rounds_no_prior = int(n_rounds_no_prior * num_rounds)
        self.weights = OrderedDict(
            (client.client_id, 1 / len(self.train_clients))
            for client in self.train_clients
        )
        # History of the weights: initial value plus one snapshot per round.
        self.weights_track = [self.weights.copy()]
        self.epochs_stds = None
        self.model_params_dict = copy.deepcopy(self.model.state_dict())
        self.updates = []
        self.prng = np.random.default_rng(random_state)

    def select_clients(self):
        """
        This method selects a random subset of `self.clients_per_round` clients
        from the given training clients, without replacement. When
        `self.use_prior` is set the current weights are the sampling
        probabilities, otherwise the sampling is uniform.
        :return: array of clients
        """
        if self.use_prior:
            return self.prng.choice(
                self.train_clients,
                size=self.clients_per_round,
                replace=False,
                p=list(self.weights.values()),
            )
        return self.prng.choice(
            self.train_clients, size=self.clients_per_round, replace=False
        )

    def load_model_on_clients(self):
        """
        This function loads the centralized model to the clients at
        the beginning of each training / testing round.
        """
        for c in self.test_clients + self.train_clients:
            c.model.load_state_dict(self.model_params_dict, strict=False)

    def train_round(self, clients: "list[Client]"):
        """
        This method trains the model with the dataset of the clients.
        It handles the training at single round level.
        The client updates are saved in the object-level list,
        they will be aggregated.
        :param clients: list of all the clients to train
        :return: ordered dict client_id -> tensor of per-epoch losses
        """
        client_loss = OrderedDict()
        for client in clients:
            client_loss[client.client_id] = client.train(self.epochs_per_round)
            self.updates.append((client.client_id, client.generate_update()))
        return client_loss

    def update_weights(self, client_loss: OrderedDict):
        """
        Updates the weights saved at instance level.

        The weights of the selected clients are replaced by normalized
        Gaussian rewards, rescaled so that their sum stays equal to the sum
        they had before the update.
        :param client_loss: a dictionary client_id -> list of round losses
        :return: the sum of the weights of the selected clients
        """
        selected = list(client_loss.keys())
        n_selected = len(selected)

        w_sum = 0
        for client_id in selected:
            w_sum += self.weights[client_id]

        # With fewer than two clients the sample variance is undefined
        # (division by n - 1 = 0). A single client would get a normalized
        # reward of 1, i.e. its weight would not change, so leave the
        # weights untouched instead of producing NaNs.
        if n_selected < 2:
            return w_sum

        # Normalize selected clients' weights
        for client_id in selected:
            self.weights[client_id] = self.weights[client_id] / w_sum

        # Compute the mean process
        weights_sum = 0
        weights2_sum = 0
        loss_tensor = torch.empty(n_selected, self.epochs_per_round)
        mean_loss = torch.zeros(self.epochs_per_round)
        for idx, client_id in enumerate(selected):
            loss_tensor[idx] = client_loss[client_id]
            mean_loss += self.weights[client_id] * client_loss[client_id]
            weights_sum += self.weights[client_id] / n_selected
            weights2_sum += (self.weights[client_id] ** 2) / n_selected

        # Compute the variance of the weighted mean for each epoch
        squared_dev = (loss_tensor - mean_loss) ** 2
        sigma2 = squared_dev.sum(dim=0) / (n_selected - 1)
        std_loss = sigma2 * (weights2_sum / weights_sum ** 2) * (1 / n_selected)

        self.epochs_stds = torch.sqrt(std_loss)

        # Compute the rewards for each epoch and each client.
        # When every client reports the same loss the variance is exactly 0
        # and 0 / 0 would give NaN; clamping to the smallest normal float
        # turns that case into a zero exponent (reward 1) and leaves any
        # realistic variance unchanged.
        safe_var = std_loss.clamp_min(torch.finfo(std_loss.dtype).tiny)
        exp_args_tensor = 0.5 * squared_dev / safe_var
        reward_tensor = torch.exp(-exp_args_tensor).mean(dim=1)
        reward_tensor = reward_tensor / reward_tensor.sum()

        # Assign rewards to weights
        for idx, client_id in enumerate(selected):
            self.weights[client_id] = reward_tensor[idx].item() * w_sum
        return w_sum

    def aggregate(self, inv_scale_factor):
        """
        This method handles the weighted FedAvg aggregation: every client
        tensor is multiplied by the client weight divided by
        `inv_scale_factor` and the results are summed.
        :param inv_scale_factor: scales the weights by the inverse of this factor
        """
        # Accumulate the weighted sum on CPU float tensors
        base = OrderedDict()
        for client_id, client_model in self.updates:
            scale = (1 / inv_scale_factor) * self.weights[client_id]
            for key, value in client_model.items():
                contribution = scale * value.type(torch.FloatTensor)
                if key in base:
                    base[key] += contribution
                else:
                    base[key] = contribution

        # Store each aggregated tensor on the device of the parameter it
        # replaces rather than on a hard-coded "cuda", so that CPU-only
        # hosts work as well.
        for key, value in base.items():
            current = self.model_params_dict.get(key)
            device = current.device if current is not None else value.device
            self.model_params_dict[key] = value.to(device)

        self.model.load_state_dict(self.model_params_dict, strict=False)
        self.updates = []

    def train(self, path=None):
        """
        This method orchestrates the training and the tests at rounds level
        :param path: optional checkpoint path prefix; when given, a checkpoint
            is saved after every round
        :return: Train / test statistics at each round. "Train as it happens" is the typical epoch loss returned from each client.
        """
        orchestra_statistics = {"Test": [], "Train as it happens": []}

        for r in range(self.num_rounds):
            self.load_model_on_clients()
            # Switch to weight-driven sampling after the warm-up rounds
            if r >= self.n_rounds_no_prior and not self.use_prior:
                self.use_prior = True
            print(f"Round {r + 1}")
            clients = self.select_clients()
            clients_loss = self.train_round(clients)
            constant_w_sum = self.update_weights(clients_loss)
            for key in clients_loss.keys():
                clients_loss[key] = clients_loss[key].cpu().tolist()
            orchestra_statistics["Train as it happens"].append(clients_loss)
            self.aggregate(constant_w_sum)

            # normalize all weights (this normalization pass is extra)
            weights_sum = sum(self.weights.values())
            for client_id in self.weights:
                self.weights[client_id] = self.weights[client_id] / weights_sum
            self.weights_track.append(self.weights.copy())

            # per-client loss / accuracy on the test clients
            stats = self.test()
            orchestra_statistics["Test"].append(stats)

            if path is not None:
                self.save_checkpoint(path + f"_{r}.json", r)
        return orchestra_statistics

    def _evaluate(self, clients):
        """
        Loads the centralized model on the clients and evaluates the given ones.
        :return: dict (one key per client) of dicts (Loss, Accuracy) of scalars
        """
        self.load_model_on_clients()
        eval_statistics = {}
        for c in clients:
            loss, accuracy = c.test()
            eval_statistics[c.client_id] = {"Loss": loss, "Accuracy": accuracy}
        return eval_statistics

    def eval_train(self):
        """
        This method handles the evaluation on the train clients
        :return: dict (one key per client) of dicts (Loss, Accuracy) of scalars
        """
        return self._evaluate(self.train_clients)

    def test(self):
        """
        This method handles the test on the test clients
        :return: dict (one key per client) of dicts (Loss, Accuracy) of scalars
        """
        return self._evaluate(self.test_clients)

    def save_checkpoint(self, path, round):
        """
        Saves the round index and the centralized model parameters with
        `torch.save`.
        :param path: destination file
        :param round: index of the round just completed
        """
        torch.save(
            {
                "round": round,
                "model_state_dict": self.model.state_dict()
            }, path
        )

**LeNet-5**

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style classifier for single-channel images.

    The convolutional stack is followed by three fully connected layers. The
    first linear layer expects 400 input features, which is what the
    convolutional stack produces for 28x28 inputs (16 maps of 5x5).
    The network outputs raw logits (no softmax).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutions with tanh, average pooling and batch normalization
        conv_layers = [
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.BatchNorm2d(6),
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.BatchNorm2d(16),
        ]
        self.feature_extractor = nn.Sequential(*conv_layers)

        # Fully connected head: 400 -> 120 -> 84 -> num_classes
        dense_layers = [
            nn.Linear(400, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, num_classes),
        ]
        self.classifier = nn.Sequential(*dense_layers)

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        features = self.feature_extractor(x)
        # Keep the batch dimension, flatten everything else
        flat = features.flatten(start_dim=1)
        return self.classifier(flat)

# Global model from which the clients' and the server's copies are made.
# Use the GPU when one is available; otherwise stay on the CPU instead of
# crashing on an unconditional `.cuda()` call.
model = LeNet5(num_classes = 10)
if torch.cuda.is_available():
    model = model.cuda()

**Fashion MNIST**

In [None]:
# Images standardizer
def images_scaler(X):
  """Standardize every image of ``X`` with its own mean and standard deviation.

  :param X: sized iterable of ``(image, label)`` pairs; each image is indexed
    with ``[0]`` to obtain the 28x28 plane that is stored
  :return: tuple ``(images, labels)``: a ``(len(X), 28, 28)`` tensor with the
    standardized images and a float tensor with the ``len(X)`` labels
  """
  n_images = len(X)
  scaled = torch.empty(n_images, 28, 28)
  targets = torch.empty(n_images)
  # Denominator of the unbiased variance of a 28x28 image
  dof = 28 ** 2 - 1
  for position, (image, target) in enumerate(X):
    centre = image.mean()
    spread = torch.sqrt(torch.sum((image - centre) ** 2) / dof)
    scaled[position] = (image[0] - centre) / spread
    targets[position] = target
  return scaled, targets

In [None]:
# Load the Fashion-MNIST train / test splits (downloaded into 'data' if missing)
mnist_train = datasets.FashionMNIST(root='data', train=True, download=True, transform=T.ToTensor())
mnist_test = datasets.FashionMNIST(root='data', train=False, download=True, transform=T.ToTensor())

# Standardization: scale each image, then add the channel axis -> (N, 1, 28, 28)
X_dev, y_dev = images_scaler(mnist_train)
X_dev = X_dev.unsqueeze(1)
X_eval, y_eval = images_scaler(mnist_test)
X_eval = X_eval.unsqueeze(1)

In [None]:
# Generate clients datasets
# Every sample is placed on the same device as the global model.
device = next(model.parameters()).device

# Clean data
dataset = [(x.to(device), y.type(torch.LongTensor).to(device)) for x, y in zip(X_dev, y_dev)]
# Data with additive uniform noise in [0, 1) and [0, 2)
dataset_n1 = [((x + 1 * torch.rand(1, 28, 28)).to(device), y.type(torch.LongTensor).to(device)) for x, y in zip(X_dev, y_dev)]
dataset_n2 = [((x + 2 * torch.rand(1, 28, 28)).to(device), y.type(torch.LongTensor).to(device)) for x, y in zip(X_dev, y_dev)]
# Clean data restricted to the samples whose label is 2
dataset_fl = [(x.to(device), y.type(torch.LongTensor).to(device)) for x, y in zip(X_dev, y_dev) if y == 2]

# Each client receives a random 30% subset of its source dataset:
# 18 clean clients, 2 with noise level 1, 3 with noise level 2 and
# 1 client that only owns samples of label 2.
train_clients_df = [train_test_split(dataset, y_dev, train_size = .3, shuffle = True, random_state = i)[0] \
    for i in range(18)] + [
    train_test_split(dataset_n1, y_dev, train_size = .3, shuffle = True, random_state = 100)[0],
    train_test_split(dataset_n1, y_dev, train_size = .3, shuffle = True, random_state = 101)[0],
    train_test_split(dataset_n2, y_dev, train_size = .3, shuffle = True, random_state = 102)[0],
    train_test_split(dataset_n2, y_dev, train_size = .3, shuffle = True, random_state = 103)[0],
    train_test_split(dataset_n2, y_dev, train_size = .3, shuffle = True, random_state = 104)[0],
    train_test_split(dataset_fl, y_dev[y_dev == 2], train_size = .3, shuffle = True, random_state = 5)[0]
]

# NOTE(review): the test clients are sampled from the same `dataset` (built
# from X_dev) as the training clients; X_eval is not used in the visible
# code. Confirm that this is intended.
test_clients_df = [
    train_test_split(dataset, y_dev, train_size = .3, shuffle = True, random_state = 301)[0],
    train_test_split(dataset, y_dev, train_size = .3, shuffle = True, random_state = 302)[0],
    train_test_split(dataset, y_dev, train_size = .3, shuffle = True, random_state = 303)[0],
]

# Initialize clients (each one owns an independent copy of the model).
# `copy.deepcopy` is used because only the `copy` module is imported; a bare
# `deepcopy` would raise a NameError.
train_clients = [
    Client(client_id = f'c_{i}', dataset = client_df,
           batch_size = 128, model = copy.deepcopy(model))
    for i, client_df in enumerate(train_clients_df)
]

# Test clients
test_clients = [
    Client(client_id = f'c_test_{i}', dataset = client_df,
           batch_size = 256, model = copy.deepcopy(model))
    for i, client_df in enumerate(test_clients_df)
]

In [None]:
# Number of federated rounds
num_rounds = 20

# Aggregation strategy: 'FedGW' uses Server, 'FedAvg' uses Server_fedavg
strategy = 'FedGW'

# Training
if strategy == 'FedAvg':
  server = Server_fedavg(num_rounds = num_rounds, num_clients_per_round = len(train_clients),
                  epochs_per_round = 6,
                  train_clients = train_clients, test_clients = test_clients,
                  model = copy.deepcopy(model))
  results = server.train_evaluation()
elif strategy == 'FedGW':
  server = Server(num_rounds = num_rounds, num_clients_per_round = len(train_clients),
                  epochs_per_round = 6, use_prior = False, n_rounds_no_prior = .6,
                  train_clients = train_clients, test_clients = test_clients,
                  model = copy.deepcopy(model))
  results = server.train()
else:
  # Fail loudly: otherwise `results` would stay undefined (or keep the value
  # of a previous run of this cell) when the strategy name is misspelled.
  raise ValueError(f"Unknown strategy {strategy!r}: expected 'FedAvg' or 'FedGW'")