## Importing Libraries

In [None]:
# Standard library imports
import os
import sys
import json
import random
from collections import OrderedDict

# Third-party library imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Google Colab specific imports
from google.colab import drive

# Base directory for the notebook's files. The process working directory is
# switched to it so that relative paths used later resolve against it.
# NOTE(review): '/content/' is presumably the Colab runtime directory (the file
# imports google.colab) - confirm before running elsewhere.
DIR_DATA = '/content/'
os.chdir(DIR_DATA)


## Centralized training functions

In [None]:
def train_model(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each one.

    Args:
        model: Module to train; moved to CUDA when available, otherwise CPU.
        train_loader: Iterable of ``(inputs, targets)`` batches that supports ``len()``.
        test_loader: Loader handed to ``evaluate_model`` after every epoch.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, or ``None``.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(train_losses, test_losses, test_accuracies)`` of per-epoch lists.
        ``train_losses`` holds the mean of the per-batch loss values.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    train_losses, test_losses, test_accuracies = [], [], []

    for epoch in range(epochs):
        # evaluate_model switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        epoch_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Step the scheduler once per epoch
        if scheduler is not None:
            scheduler.step()

        # Report the same per-batch average that is stored in the history, so
        # the printed "Train Loss" is comparable with the printed "Test Loss"
        # (the running sum grows with the number of batches).
        avg_train_loss = epoch_loss / len(train_loader)

        # Evaluate on test set
        test_loss, test_accuracy = evaluate_model(model, test_loader, criterion, device)
        train_losses.append(avg_train_loss)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, "
              f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    return train_losses, test_losses, test_accuracies

def evaluate_model(model, test_loader, criterion, device):
    """Score ``model`` on every batch of ``test_loader`` without tracking gradients.

    The model is put in eval mode and left there. Returns a tuple
    ``(mean_batch_loss, accuracy)``: the first is the sum of the per-batch
    loss values divided by the number of batches, the second is the fraction
    of samples whose highest-scoring output index equals the target.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            scores = model(batch_inputs)
            loss_sum += criterion(scores, batch_targets).item()

            # Index of the largest score along dim 1 is the predicted class.
            predicted = torch.max(scores, 1).indices
            n_correct += predicted.eq(batch_targets).sum().item()
            n_seen += batch_targets.size(0)

    mean_batch_loss = loss_sum / len(test_loader)
    accuracy = n_correct / n_seen
    return mean_batch_loss, accuracy


## Federated Learning classes

In [None]:
def generate_skewed_probabilities(num_clients, gamma):
    """Draw one probability vector of length ``num_clients`` from a symmetric Dirichlet.

    Every concentration parameter is set to ``gamma``; the sample comes from
    NumPy's global random state.
    """
    concentration = [gamma for _ in range(num_clients)]
    return np.random.dirichlet(concentration)


class Client:
    """A federated-learning participant that trains a model on its own data.

    The constructor stores its arguments unchanged:
      * ``model``: module trained by ``train``.
      * ``client_id``: identifier used in progress output.
      * ``data``: dataset wrapped in a ``DataLoader`` during training.
      * ``optimizer_params``: mapping providing the ``'lr'``, ``'momentum'``
        and ``'weight_decay'`` values for SGD.
    """

    def __init__(self, model, client_id, data, optimizer_params):
        self.client_id = client_id
        self.data = data
        self.model = model
        self.optimizer_params = optimizer_params

    def train(self, global_weights, epochs, batch_size):
        """Run local SGD starting from ``global_weights`` and return the result.

        Args:
            global_weights: State dict loaded into the local model before training.
            epochs: Number of passes over the client's data.
            batch_size: Batch size of the local ``DataLoader``.

        Returns:
            An ``OrderedDict`` holding a detached copy of every state-dict
            entry after training. The copy does not share storage with the
            model, so training or loading weights on that model afterwards
            cannot alter the returned values.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(device)
        self.model.load_state_dict(global_weights)
        # Make the mode explicit: an evaluation pass on this module would have
        # left it in eval mode, which changes how layers such as dropout and
        # batch normalization behave.
        self.model.train()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.optimizer_params['lr'],
            momentum=self.optimizer_params['momentum'],
            weight_decay=self.optimizer_params['weight_decay'],
        )
        trainloader = DataLoader(self.data, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            print(f"Client {self.client_id}, Epoch {epoch+1}/{epochs}")
            for inputs, targets in trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

        # state_dict() hands back tensors backed by the live parameters; clone
        # them so the caller receives a stable snapshot.
        return OrderedDict(
            (key, value.detach().clone())
            for key, value in self.model.state_dict().items()
        )



class Server:
    """Coordinator that runs federated training rounds over a pool of clients.

    Holds the global ``model``, the list of ``clients`` and the ``test_data``
    used to evaluate the global model after every round. The evaluation
    history is kept in ``round_losses`` and ``round_accuracies``.
    """

    def __init__(self, model, clients, test_data):
        self.model = model
        self.clients = clients
        self.test_data = test_data
        # One entry is appended per completed round by federated_averaging.
        self.round_losses = []
        self.round_accuracies = []

    def federated_averaging(self, epochs, batch_size, num_rounds, fraction_fit, skewness=None,
                            fedOptimizer=None, server_lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
        """Run ``num_rounds`` federated rounds, then plot the evaluation history.

        Each round samples clients, trains every sampled client from the
        current global weights, combines the clients' weight deltas weighted
        by ``len(client.data)``, and applies the combined delta to the global
        model with the selected server-side rule.

        Args:
            epochs: Local epochs each selected client trains for.
            batch_size: Batch size for client training and for evaluation.
            num_rounds: Number of federated rounds.
            fraction_fit: Fraction of clients sampled per round (at least one
                client is always sampled).
            skewness: If not ``None``, Dirichlet concentration used to draw
                per-round client selection probabilities; otherwise clients
                are sampled uniformly.
            fedOptimizer: ``"FedAdaGrad"``, ``"FedYogi"`` or ``"FedAdam"``;
                any other value applies the averaged delta directly (FedAvg).
            server_lr: Server learning rate of the adaptive rules.
            beta1: First-moment decay used by ``"FedAdam"``.
            beta2: Second-moment decay used by ``"FedYogi"`` and ``"FedAdam"``.
            eps: Constant added to the denominator of the adaptive rules.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(device)

        # Moment buffers are only needed by the adaptive server optimizers.
        adaptive = fedOptimizer in {"FedAdaGrad", "FedYogi", "FedAdam"}
        if adaptive:
            reference = self.model.state_dict()
            optimizer_state = {
                "m": {key: torch.zeros_like(value, dtype=torch.float32) for key, value in reference.items()},
                "v": {key: torch.zeros_like(value, dtype=torch.float32) for key, value in reference.items()},
            }

        # Round-invariant objects are built once instead of once per round.
        num_selected = max(1, int(fraction_fit * len(self.clients)))
        criterion = nn.CrossEntropyLoss()
        # Evaluation gains nothing from shuffling; a fixed order keeps the
        # per-batch loss average reproducible from round to round.
        test_loader = DataLoader(self.test_data, batch_size=batch_size, shuffle=False)

        for round_idx in range(num_rounds):
            print(f"Round {round_idx + 1}/{num_rounds}")

            selected_clients = self._select_clients(num_selected, skewness)

            # Snapshot the global weights. state_dict() returns tensors backed
            # by the live parameters, so without a copy any training that
            # touches this model would silently move the reference point the
            # client deltas are measured against.
            global_weights = OrderedDict(
                (key, value.detach().clone()) for key, value in self.model.state_dict().items()
            )

            # Simulate parallel client training. Each client's delta is folded
            # in right after it trains, so the result does not depend on the
            # returned tensors staying unchanged or on client ids being unique.
            total_data_size = sum(len(client.data) for client in selected_clients)
            aggregated_updates = {
                key: torch.zeros_like(value, dtype=torch.float32) for key, value in global_weights.items()
            }
            for client in selected_clients:
                local_weights = client.train(global_weights, epochs, batch_size)
                scaling_factor = len(client.data) / total_data_size
                for key, accumulated in aggregated_updates.items():
                    accumulated += scaling_factor * (local_weights[key] - global_weights[key])

            # Apply the selected server update rule entry by entry.
            updated_weights = OrderedDict()
            for key, old_value in global_weights.items():
                delta = aggregated_updates[key]

                if not (adaptive and old_value.is_floating_point()):
                    # FedAvg, and any non-float entry (integer buffers such as
                    # counters): apply the averaged delta as is.
                    step = delta
                elif fedOptimizer == "FedAdaGrad":
                    second_moment = optimizer_state["v"][key]
                    second_moment += delta ** 2
                    step = server_lr * delta / (torch.sqrt(second_moment) + eps)
                elif fedOptimizer == "FedYogi":
                    second_moment = optimizer_state["v"][key]
                    squared = delta ** 2
                    second_moment -= (1 - beta2) * squared * torch.sign(second_moment - squared)
                    step = server_lr * delta / (torch.sqrt(second_moment) + eps)
                else:  # FedAdam
                    first_moment = beta1 * optimizer_state["m"][key] + (1 - beta1) * delta
                    second_moment = beta2 * optimizer_state["v"][key] + (1 - beta2) * delta ** 2
                    optimizer_state["m"][key] = first_moment
                    optimizer_state["v"][key] = second_moment
                    m_hat = first_moment / (1 - beta1 ** (round_idx + 1))
                    v_hat = second_moment / (1 - beta2 ** (round_idx + 1))
                    step = server_lr * m_hat / (torch.sqrt(v_hat) + eps)

                # Build a new tensor and cast back to the entry's dtype: adding
                # a float delta in place to an integer tensor raises an error.
                new_value = old_value + step
                if not old_value.is_floating_point():
                    new_value = torch.round(new_value)
                updated_weights[key] = new_value.to(old_value.dtype)

            # Update global model weights
            self.model.load_state_dict(updated_weights)

            # Evaluate global model
            loss, accuracy = evaluate_model(self.model, test_loader, criterion, device)
            self.round_losses.append(loss)
            self.round_accuracies.append(accuracy)
            print(f"Round {round_idx + 1}/{num_rounds} - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        self._plot_history()

    def _select_clients(self, num_selected, skewness):
        """Sample ``num_selected`` distinct clients, uniformly or with Dirichlet-drawn probabilities."""
        if skewness is None:
            return np.random.choice(self.clients, size=num_selected, replace=False)
        probabilities = generate_skewed_probabilities(len(self.clients), skewness)
        return np.random.choice(self.clients, size=num_selected, replace=False, p=probabilities)

    def _plot_history(self):
        """Plot the recorded per-round test loss and test accuracy side by side."""
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(self.round_losses, label='Test Loss')
        plt.xlabel('Round')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(self.round_accuracies, label='Test Accuracy')
        plt.xlabel('Round')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.show()