## Importing Libraries

In [5]:
# Importing necessary libraries for data processing, ML, and visualization
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import nltk
from collections import Counter, defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
import json
import random
import re
import kagglehub
from copy import deepcopy
import warnings
import string
import itertools

from torch.utils.data import Dataset, random_split, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts

# Suppressing warnings for a cleaner output
warnings.filterwarnings("ignore")

# Download Shakespeare Dataset
path = kagglehub.dataset_download("kewagbln/shakespeareonline")
DATA_PATH = os.path.join(path, "t8.shakespeare.txt")
DATA_DIR = "data/"

TRAIN_FRACTION = 0.9
SEQ_LEN = 80
N_VOCAB = 90



## Checkpoints

In [6]:


# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create a Folder in the root directory
!mkdir -p "/content/drive/My Drive/My Folder/checkpoints_shakespeare"

CHECKPOINT_DIR = '/content/drive/My Drive/My Folder/checkpoints_shakespeare'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, hyperparameters, subfolder="", data_to_save=None):
    """Salva il checkpoint del modello e rimuove quello precedente."""
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)
    os.makedirs(subfolder_path, exist_ok=True)

    # File corrente e precedente
    filename = f"model_epoch_{epoch}_params_{hyperparameters}.pth"
    filepath = os.path.join(subfolder_path, filename)
    filename_json = f"model_epoch_{epoch}_params_{hyperparameters}.json"
    filepath_json = os.path.join(subfolder_path, filename_json)


    previous_filename = f"model_epoch_{epoch -1}_params_{hyperparameters}.pth"
    previous_filepath = os.path.join(subfolder_path, previous_filename)
    previous_filename_json = f"model_epoch_{epoch -1}_params_{hyperparameters}.json"
    previous_filepath_json = os.path.join(subfolder_path, previous_filename_json)

    # Rimuove il checkpoint precedente
    if epoch > 1 and os.path.exists(previous_filepath) and os.path.exists(previous_filepath_json):
        os.remove(previous_filepath)
        os.remove(previous_filepath_json)

    # Salva il nuovo checkpoint
    if optimizer is not None:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),  # Salvataggio dello stato dell'ottimizzatore
            'epoch': epoch
        }, filepath)
    else:
        torch.save({
            'model_state_dict': model.state_dict(),
            'epoch': epoch
        }, filepath)
    print(f"Checkpoint salvato: {filepath}")

    with open(filepath_json, 'w') as json_file:
      json.dump(data_to_save, json_file, indent=4)


def load_checkpoint(model, optimizer, hyperparameters, subfolder=""):
    """Carica l'ultimo checkpoint disponibile basato sugli iperparametri."""
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)
    if not os.path.exists(subfolder_path):
        print("No checkpoint found, Starting now...")
        return 1, None  # Le epoche iniziano da 1

    # Cerca i file con gli iperparametri specificati
    files = [f for f in os.listdir(subfolder_path) if f"params_{hyperparameters}" in f and f.endswith('.pth')]
    if files:
        # Trova il file con l'epoca più alta
        latest_file = max(files, key=lambda x: int(x.split('_')[2]))
        filepath = os.path.join(subfolder_path, latest_file)
        checkpoint = torch.load(filepath)

        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Trova e carica il file JSON associato
        json_filename = latest_file.replace('.pth', '.json')
        json_filepath = os.path.join(subfolder_path, json_filename)
        json_data = None
        if os.path.exists(json_filepath):
            with open(json_filepath, 'r') as json_file:
                json_data = json.load(json_file)
            print(f"JSON data loaded: {json_filepath}")
        else:
            print(f"No JSON file found for: {latest_file}")

        print(f"Checkpoint found: Resume epoch {checkpoint['epoch'] + 1}")
        return checkpoint['epoch'] + 1, json_data

    print("No checkpoint found, Starting now...")
    return 1, None  # Le epoche iniziano da 1



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Text Processing

In [7]:
TRAIN_FRACTION = 0.9

CHARACTER_RE = re.compile(r'^  ([a-zA-Z][a-zA-Z ]*)\. (.*)')  # Matches character lines
CONT_RE = re.compile(r'^    (.*)')  # Matches continuation lines
COE_CHARACTER_RE = re.compile(r'^([a-zA-Z][a-zA-Z ]*)\. (.*)')  # Special regex for Comedy of Errors
COE_CONT_RE = re.compile(r'^(.*)')  # Continuation for Comedy of Errors

def parse_shakespeare_file(filepath):
    """
    Reads and splits Shakespeare's text into plays, characters, and their dialogues.
    Returns training and test datasets based on the specified fraction.
    """
    with open(filepath, "r") as f:
        content = f.read()
    plays, _ = _split_into_plays(content)  # Split the text into plays
    _, train_examples, test_examples = _get_train_test_by_character(
        plays, test_fraction=1 - TRAIN_FRACTION
    )
    return train_examples, test_examples

def _split_into_plays(shakespeare_full):
    """
    Splits the full Shakespeare text into individual plays and characters' dialogues.
    Handles special parsing for "The Comedy of Errors".
    """
    plays = []
    slines = shakespeare_full.splitlines(True)[1:]  # Skip the first line (title/header)
    current_character = None
    comedy_of_errors = False

    for i, line in enumerate(slines):
        # Detect play titles and initialize character dictionary
        if "by William Shakespeare" in line:
            current_character = None
            characters = defaultdict(list)
            title = slines[i - 2].strip() if slines[i - 2].strip() else slines[i - 3].strip()
            comedy_of_errors = title == "THE COMEDY OF ERRORS"
            plays.append((title, characters))
            continue

        # Match character lines or continuation lines
        match = _match_character_regex(line, comedy_of_errors)
        if match:
            character, snippet = match.group(1).upper(), match.group(2)
            if not (comedy_of_errors and character.startswith("ACT ")):
                characters[character].append(snippet)
                current_character = character
        elif current_character:
            match = _match_continuation_regex(line, comedy_of_errors)
            if match:
                characters[current_character].append(match.group(1))

    # Filter out plays with insufficient dialogue data
    return [play for play in plays if len(play[1]) > 1], []

def _match_character_regex(line, comedy_of_errors=False):
    """Matches character dialogues, with special handling for 'The Comedy of Errors'."""
    return COE_CHARACTER_RE.match(line) if comedy_of_errors else CHARACTER_RE.match(line)

def _match_continuation_regex(line, comedy_of_errors=False):
    """Matches continuation lines of dialogues."""
    return COE_CONT_RE.match(line) if comedy_of_errors else CONT_RE.match(line)

def _get_train_test_by_character(plays, test_fraction=0.2):
    """
    Splits dialogues by characters into training and testing datasets.
    Ensures each character has at least one example in the training set.
    """
    all_train_examples = defaultdict(list)
    all_test_examples = defaultdict(list)

    def add_examples(example_dict, example_tuple_list):
        """Adds examples to the respective dataset dictionary."""
        for play, character, sound_bite in example_tuple_list:
            example_dict[f"{play}_{character}".replace(" ", "_")].append(sound_bite)

    for play, characters in plays:
        for character, sound_bites in characters.items():
            examples = [(play, character, sound_bite) for sound_bite in sound_bites]
            if len(examples) <= 2:
                continue

            # Calculate the number of test samples
            num_test = max(1, int(len(examples) * test_fraction))
            num_test = min(num_test, len(examples) - 1)  # Ensure at least one training example

            # Split into train and test sets
            train_examples = examples[:-num_test]
            test_examples = examples[-num_test:]

            add_examples(all_train_examples, train_examples)
            add_examples(all_test_examples, test_examples)

    return {}, all_train_examples, all_test_examples


def letter_to_vec(c, n_vocab=128):
    """Converts a single character to a vector index based on the vocabulary size."""
    return ord(c) % n_vocab

def word_to_indices(word, n_vocab=128):
    """
    Converts a word or list of words into a list of indices.
    Each character is mapped to an index based on the vocabulary size.
    """
    if isinstance(word, list):  # If input is a list of words
        res = []
        for stringa in word:
            res.extend([ord(c) % n_vocab for c in stringa])  # Convert each word to indices
        return res
    else:  # If input is a single word
        return [ord(c) % n_vocab for c in word]

def process_x(raw_x_batch, seq_len, n_vocab):
    """
    Processes raw input data into padded sequences of indices.
    Ensures all sequences are of uniform length.
    """
    x_batch = [word_to_indices(word, n_vocab) for word in raw_x_batch]
    x_batch = [x[:seq_len] + [0] * (seq_len - len(x)) for x in x_batch]
    return torch.tensor(x_batch, dtype=torch.long)


def process_y(raw_y_batch, seq_len, n_vocab):
    """
    Processes raw target data into padded sequences of indices.
    Shifts the sequence by one character to the right.
    y[1:seq_len + 1] takes the input data, right shift of an
    element and uses the next element of the sequence to fill
    and at the end (with [0]) final padding (zeros) are (eventually)
    added to reach the desired sequence length.
    """
    y_batch = [word_to_indices(word, n_vocab) for word in raw_y_batch]
    y_batch = [y[1:seq_len + 1] + [0] * (seq_len - len(y[1:seq_len + 1])) for y in y_batch]  # Shifting and final padding
    return torch.tensor(y_batch, dtype=torch.long)

def create_batches(data, batch_size, seq_len, n_vocab):
    """
    Creates batches of input and target data from dialogues.
    Each batch contains sequences of uniform length.
    """
    x_batches = []
    y_batches = []
    dialogues = list(data.values())
    random.shuffle(dialogues)  # Shuffle to ensure randomness in batches

    batch = []
    for dialogue in dialogues:
        batch.append(dialogue)
        if len(batch) == batch_size:
            x_batch = process_x(batch, seq_len, n_vocab)
            y_batch = process_y(batch, seq_len, n_vocab)
            x_batches.append(x_batch)
            y_batches.append(y_batch)
            batch = []

    # Add the last batch if it's not full
    if batch:
        x_batch = process_x(batch, seq_len, n_vocab)
        y_batch = process_y(batch, seq_len, n_vocab)
        x_batches.append(x_batch)
        y_batches.append(y_batch)

    return x_batches, y_batches

## Shakespeare Dataset

In [9]:
class ShakespeareDataset(Dataset):
    def __init__(self, data, seq_len, n_vocab):
        self.data = list(data.values())  # Convert the dictionary values to a list
        self.seq_len = seq_len  # Sequence length for the model
        self.n_vocab = n_vocab  # Vocabulary size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        dialogue = self.data[idx]
        x = process_x([dialogue], self.seq_len, self.n_vocab)[0]
        y = process_y([dialogue], self.seq_len, self.n_vocab)[0]
        return x, y

data_train, data_test = parse_shakespeare_file(DATA_PATH)
dataset = ShakespeareDataset(data_train, seq_len=80, n_vocab=90)
len(dataset)

1164

## Shakespeare Model Architecture

In [8]:
class ShakespeareLSTM(nn.Module):
    def __init__(self, vocab_size=90, embed_dim=8, lstm_hidden_dim=256, seq_len=80, batch_size=32):
        super(ShakespeareLSTM, self).__init__()

        self.seq_len = seq_len
        self.batch_size = batch_size
        self.lstm_hidden_dim = lstm_hidden_dim
        self.vocab_size = vocab_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Embedding layer
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

        # First LSTM layer
        self.lstm1 = nn.LSTM(input_size=embed_dim, hidden_size=lstm_hidden_dim, batch_first=True)

        # Second LSTM layer
        self.lstm2 = nn.LSTM(input_size=lstm_hidden_dim, hidden_size=lstm_hidden_dim, batch_first=True)

        # Dense output layer
        self.dense = nn.Linear(lstm_hidden_dim, vocab_size)

    def init_hidden(self, batch_size):
        """
        Inizializza lo stato nascosto e la cella della LSTM come tensori di zeri.
        """
        h0 = torch.zeros(2, batch_size, self.lstm_hidden_dim).to(self.device)
        c0 = torch.zeros(2, batch_size, self.lstm_hidden_dim).to(self.device)
        return (h0, c0)

    def forward(self, x, hidden=None):
        # x: (batch_size, seq_len)
        batch_size = x.size(0)

        x = self.embedding(x)  # (batch_size, seq_len, embed_dim)

        x, hidden = self.lstm1(x, hidden)  # (batch_size, seq_len, lstm_hidden_dim)

        x, hidden = self.lstm2(x, hidden)  # (batch_size, seq_len, lstm_hidden_dim)

        x = self.dense(x)  # (batch_size, seq_len, vocab_size)

        return x, hidden

## Centralized training functions

In [17]:
def train_model(model, train_loader, validation_loader, test_loader, optimizer, scheduler, criterion, epochs, hyperparameters):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    train_losses, validation_losses, validation_accuracies = [], [], []

    # Carica checkpoint se esiste
    start_epoch, json_data = load_checkpoint(model, optimizer, hyperparameters, "Centralized/")
    if json_data is not None:
        validation_losses = json_data.get('validation_losses', [])
        validation_accuracies = json_data.get('validation_accuracies', [])
        train_losses = json_data.get('train_losses', [])

    if start_epoch >= epochs:
        print(f"Checkpoint trovato, configurazione già completata. Valutazione solo sul validation set.")
        validation_loss, validation_accuracy = evaluate_model(model, validation_loader, criterion, device)
        validation_losses.append(validation_loss)
        validation_accuracies.append(validation_accuracy)
        return train_losses, validation_losses, validation_accuracies

    import time

    for epoch in range(start_epoch, epochs + 1):

        model.train()
        epoch_loss = 0
        x_batches, y_batches = create_batches(train_data, BATCH_SIZE, SEQ_LEN, N_VOCAB)

        for x_batch, y_batch in zip(x_batches, y_batches):

            inputs, targets = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()

            outputs, _ = model(inputs)
            outputs = outputs.view(-1, 90)
            targets = targets.view(-1)
            loss = criterion(outputs, targets)  # Calcola la loss sull'ultima previsione

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        scheduler.step()

        # Valutazione sul validation set
        validation_loss, validation_accuracy = evaluate_model(model, validation_loader, criterion, device)
        train_losses.append(epoch_loss / len(train_loader))
        validation_losses.append(validation_loss)
        validation_accuracies.append(validation_accuracy)

        # Salva checkpoint
        save_checkpoint(
            model, optimizer, epoch, hyperparameters, "Centralized/",
            data_to_save={
                'validation_losses': validation_losses,
                'validation_accuracies': validation_accuracies,
                'train_losses': train_losses
            }
        )

        print(f"Epoch {epoch}/{epochs}, Train Loss: {epoch_loss:.4f}, "
              f"Validation Loss: {validation_loss:.4f}, Validation Accuracy: {validation_accuracy:.4f}, ")

    # Valutazione sul test set
    test_loss, test_accuracy = evaluate_model(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    return train_losses, validation_losses, validation_accuracies




def evaluate_model(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs, _ = model(inputs)
            outputs = outputs.view(-1, model.vocab_size)
            targets = targets.view(-1)

            loss = criterion(outputs, targets)
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)

    return total_loss / len(test_loader), correct / total



## Centralized training

In [112]:
# Always run before creating new datasets

if os.path.exists("/content/leaf/"):
  # Use shutil.rmtree to remove the folder and its contents
  shutil.rmtree("/content/leaf")
  print(f"Successfully deleted folder leaf")

os.chdir("/content/")
!git clone https://github.com/maxfra01/leaf.git

# -----------------------------------------

preprocess_params = {
        'sharding': 'iid',
        'sf': 1.0,
        'iu': 0.1,
        't': 'sample',
        'tf': 0.8,
    } # Get the full-size dataset

train_dataset_big = CentralizedShakespeareDataset(root="/content/leaf/data/shakespeare", split="train", preprocess_params=preprocess_params)
test_dataset = CentralizedShakespeareDataset(root="/content/leaf/data/shakespeare", split="test", preprocess_params=preprocess_params)


Successfully deleted folder leaf
Cloning into 'leaf'...
remote: Enumerating objects: 772, done.[K
remote: Counting objects: 100% (6/6), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 772 (delta 0), reused 0 (delta 0), pack-reused 766 (from 1)[K
Receiving objects: 100% (772/772), 6.78 MiB | 28.94 MiB/s, done.
Resolving deltas: 100% (363/363), done.
Running command: bash preprocess.sh -s iid --iu 0.1 --sf 1.0 -t sample --tf 0.8
Absolute folder path: /content/leaf/data/shakespeare/data/train


NameError: name 'text_to_indexes' is not defined

In [None]:
# Hyperparameters
BATCH_SIZE = 4
LEARNING_RATE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY=1e-4
EPOCHS = 20

SEQ_LEN = 80
N_VOCAB = 90

hyperparameters = f"BS{BATCH_SIZE}_LR{LEARNING_RATE}_WD{WEIGHT_DECAY}_M{MOMENTUM}"


model_shakespeare = ShakespeareLSTM()

train_data, test_data = parse_shakespeare_file(DATA_PATH)

train_dataset = ShakespeareDataset(train_data, seq_len=SEQ_LEN, n_vocab=N_VOCAB)
test_dataset = ShakespeareDataset(test_data, seq_len=SEQ_LEN, n_vocab=N_VOCAB)

# Split the train dataset into train and validation:
train_size = int(TRAIN_FRACTION * len(train_dataset))  # 90%
valid_size = len(train_dataset) - train_size  # 10%
#random split:
train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])

# Creation of the DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

optimizer = optim.SGD(
    model_shakespeare.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
)

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
criterion = nn.CrossEntropyLoss()

# Train the model
train_losses, val_losses, val_accuracies = train_model(
    model=model_shakespeare,
    train_loader=train_dataloader,
    validation_loader= val_dataloader,
    test_loader=test_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    epochs=EPOCHS,
    hyperparameters=hyperparameters
)

# Evaluation on test split


plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(val_losses, label='Shakespeare Val Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(val_accuracies, label='Shakespare Val Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.show()


No checkpoint found, Starting now...
Checkpoint salvato: /content/drive/My Drive/My Folder/checkpoints_shakespeare/Centralized/model_epoch_1_params_BS4_LR0.1_WD0.0001_M0.9.pth
Epoch 1/20, Train Loss: 934.4789, Validation Loss: 3.0023, Validation Accuracy: 0.2064, 
Checkpoint salvato: /content/drive/My Drive/My Folder/checkpoints_shakespeare/Centralized/model_epoch_2_params_BS4_LR0.1_WD0.0001_M0.9.pth
Epoch 2/20, Train Loss: 814.6838, Validation Loss: 2.5952, Validation Accuracy: 0.2946, 
Checkpoint salvato: /content/drive/My Drive/My Folder/checkpoints_shakespeare/Centralized/model_epoch_3_params_BS4_LR0.1_WD0.0001_M0.9.pth
Epoch 3/20, Train Loss: 731.7919, Validation Loss: 2.4180, Validation Accuracy: 0.3390, 
Checkpoint salvato: /content/drive/My Drive/My Folder/checkpoints_shakespeare/Centralized/model_epoch_4_params_BS4_LR0.1_WD0.0001_M0.9.pth
Epoch 4/20, Train Loss: 688.5215, Validation Loss: 2.2971, Validation Accuracy: 0.3757, 
Checkpoint salvato: /content/drive/My Drive/My Fold

## Federate Learning classes

In [None]:
def generate_skewed_probabilities(num_clients, gamma):
    """It generates skewed probabilities for clients using a Dirichlet distribution."""
    probabilities = np.random.dirichlet([gamma] * num_clients)
    return probabilities

def plot_selected_clients_distribution(selected_clients_per_round, num_clients, hyperparameters):
    """Plotta la distribuzione dei client selezionati alla fine del processo."""
    counts = np.zeros(num_clients)

    # Conta quante volte ogni client è stato selezionato in tutti i round
    for selected_clients in selected_clients_per_round:
        for client in selected_clients:
            counts[client] += 1

    plt.figure(figsize=(10, 6))
    plt.bar(range(num_clients), counts, color='skyblue', edgecolor='black')
    plt.title("Distribuzione dei Client Selezionati Durante il Federated Averaging")
    plt.xlabel("Client ID")
    plt.ylabel("Frequenza di Selezione")
    plt.grid(axis='y')
    plt.savefig(f"Shakespeare_Client_distribution_{hyperparameters}.png")
    plt.show()


class Client:

  def __init__(self, model, client_id, data, optimizer_params):
    self.client_id = client_id
    self.data = data
    self.model = model
    self.optimizer_params = optimizer_params

  def train(self, global_weights, local_steps, batch_size):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model.to(device)
    self.model.load_state_dict(global_weights)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        self.model.parameters(),
        lr=self.optimizer_params['lr'],
        momentum=self.optimizer_params['momentum'],
        weight_decay=self.optimizer_params['weight_decay']
        )
    trainloader = DataLoader(self.data, batch_size=batch_size, shuffle=True,  pin_memory=True)

    steps = 0
    while steps < local_steps:
      for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        steps += 1
        if steps >= local_steps:
          break
    return self.model.state_dict()

class Server:

  def __init__(self, model, clients, test_data, val_data):
    self.model = model
    self.clients = clients
    self.val_data = val_data
    self.test_data = test_data
    self.round_losses = []
    self.round_accuracies = []
    self.selected_clients_per_round = [] #clint selezionati per skewness

  def federated_averaging(self, local_steps, batch_size, num_rounds, fraction_fit, skewness = None, hyperparameters = None):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model.to(device)
     # Carica il checkpoint se esiste
    data_to_load = None
    if skewness is  None:
      start_epoch, data_to_load = load_checkpoint(self.model,optimizer=None,hyperparameters=hyperparameters, subfolder="Federated_Uniform/")
    else:
      start_epoch, data_to_load = load_checkpoint(self.model,optimizer=None,hyperparameters=hyperparameters, subfolder="Federated_Skewed/")

    if data_to_load is not None:
      self.round_losses = data_to_load['round_losses']
      self.round_accuracies = data_to_load['round_accuracies']
      self.selected_clients_per_round = data_to_load['selected_clients_per_round']


    for round in range(start_epoch, num_rounds+1):

      if skewness is not None:
        probabilities = generate_skewed_probabilities(len(self.clients), skewness)
        selected_clients = np.random.choice(self.clients, size=max(1, int(fraction_fit*len(self.clients))), replace=False, p=probabilities)

      else:
        selected_clients = np.random.choice(self.clients, size=max(1, int(fraction_fit*len(self.clients))), replace=False)

      self.selected_clients_per_round.append([client.client_id for client in selected_clients])


      global_weights = self.model.state_dict()

      # Simulating parallel clients training
      client_weights = {}
      for client in selected_clients:
        client_weights[client.client_id] = client.train(global_weights, local_steps, batch_size)

      new_global_weights = {key: torch.zeros_like(value).type(torch.float32) for key, value in global_weights.items()}

      total_data_size = sum([len(client.data) for client in selected_clients])
      for client in selected_clients:
        scaling_factor = len(client.data) / total_data_size
        for key in new_global_weights.keys():
          new_global_weights[key] += scaling_factor * client_weights[client.client_id][key]

      # Update global model weights
      self.model.load_state_dict(new_global_weights)

      # Evaluate global model every 10 rounds
      if round % 10 == 0:
        loss, accuracy = evaluate_model(self.model, DataLoader(self.val_data, batch_size=batch_size, shuffle=True, pin_memory=True), nn.CrossEntropyLoss(), device)
        self.round_losses.append(loss)
        self.round_accuracies.append(accuracy)
        print(f"Round {round}/{num_rounds} - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        data_to_save = {
          'round_losses': self.round_losses,
          'round_accuracies': self.round_accuracies,
          'selected_clients_per_round': [[client for client in round_clients] for round_clients in self.selected_clients_per_round]  # Serializziamo solo i client_id
      }

        if skewness is  None:
          save_checkpoint(self.model, None, round , hyperparameters, "Federated_Uniform/", data_to_save)
        else:
          save_checkpoint(self.model, None, round , hyperparameters, "Federated_Skewed/", data_to_save)



    print("Evaluation on test set...")
    loss, accuracy = evaluate_model(self.model, DataLoader(self.test_data, batch_size=batch_size, shuffle=True, pin_memory=True), nn.CrossEntropyLoss(), device)
    print(f"Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.4f}")

    plt.figure(figsize=(12,5))
    plt.subplot(1, 2, 1)
    plt.plot(self.round_losses, label='Shakespeare Validation Loss')
    plt.xlabel('Round (x10)')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(self.round_accuracies, label='Shakespeare Validation Accuracy')
    plt.xlabel('Round (x10)')
    plt.ylabel('Accuracy')
    plt.legend()
    if skewness is  None:
      plt.savefig(f"Shakespeare_fedavg_uniform{hyperparameters}.jpg")
    else:
      plt.savefig(f"Shakespeare_fedavg_skew{hyperparameters}.jpg")

    plt.show()

    plot_selected_clients_distribution(self.selected_clients_per_round, len(self.clients), hyperparameters)

