# Federated Learning from Scratch with PyTorch
**A PyTorch implementation of the Federated Learning algorithm FedAvg on the MNIST dataset**

## Introduction
Federated Learning is a machine learning setting where many clients (e.g. mobile devices or whole organizations) collaboratively train a model under the orchestration of a central server (e.g. service provider), while keeping the training data decentralized. Federated Learning represents a possible solution to the problem of data privacy and data security in machine learning.

In this project I implemented the Federated Learning algorithm **FedAvg**, formulated by McMahan et al. [1], which is a first approach to this challenge.

<img src="images/algo.png" alt="Image Description" width="400" height="400">

The algorithm is divided into two phases: the **Client Update** and the **Server Update**. In the **Client Update** phase, each client trains its local model on its local sample of data for the configured number of local epochs, using the local batch size. In the **Server Update** phase, the server averages the model weights of all the clients to create a new global model. This global model is then sent to all the clients to start a new round of training. In this way, the server never has access to the raw data, only to the model weights. The goal is for the global model to approximate the model that would have been obtained by training on the entire dataset.

## Implementation

I implemented the algorithm using PyTorch and tested it on the MNIST dataset for simplicity.

The project is divided into the following files:
- `fedsgd.py`: the main file that contains the implementation of the FedAvg algorithm
- `Client.py`: the file that contains the implementation of the `Client` class
- `Server.py`: the file that contains the implementation of the `Server` class
- `models.py`: the file that contains the implementation of the neural models

### Client
The `Client.py` file contains the implementation of the `Client` class. Each client has its own dataloader, model and optimizer. The learning rate and the number of local epochs can be set by the user. The method `train` implements the **Client Update** phase.

```python
def train(self, num_epochs, patience, params, progress_bar):
    # Start from the global weights sent by the server.
    # (`patience` is accepted but not used in this excerpt.)
    self.model.load_state_dict(params)
    self.model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0  # reset so the value covers the current epoch only
        for batch_idx, (x, y) in enumerate(self.data):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            y_pred = self.model(x)
            loss = self.loss_func(y_pred, y)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        progress_bar.update(1)  # one tick per local epoch
    # Report and return the mean batch loss of the last local epoch.
    progress_bar.set_postfix({"loss": running_loss / len(self.data)})
    return running_loss / len(self.data)
```
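
For context, the class around `train` could be set up roughly as follows. This is only a minimal sketch, not the code in `Client.py`: the constructor signature mirrors the `Client(...)` call shown later in this README, while the use of `CrossEntropyLoss`, plain SGD, and the body of `get_params` are assumptions.

```python
import copy

import torch
from torch import nn


class Client:
    """Hypothetical minimal client: local data, a model and an optimizer."""

    def __init__(self, client_id, data, model, lr=0.01, weight_decay=0.0, device="cpu"):
        self.id = client_id
        self.data = data  # iterable of (x, y) batches, e.g. a DataLoader
        self.device = device
        self.model = model.to(device)
        self.loss_func = nn.CrossEntropyLoss()  # assumption: classification loss
        self.optimizer = torch.optim.SGD(  # assumption: plain SGD
            self.model.parameters(), lr=lr, weight_decay=weight_decay
        )

    def get_params(self):
        # Return a copy so the caller cannot modify the local model by accident.
        return copy.deepcopy(self.model.state_dict())
```

Returning a copy from `get_params` keeps the server's averaging from aliasing the client's own tensors.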

### Server
The `Server.py` file contains the implementation of the `Server` class. The server has a model and a list of clients, and its `aggregate` method implements the **Server Update** phase. Note that the method does not return the new global model; it updates the model in place.

```python
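def aggregate(self):
    # Collect the state dict of every client.
    params = [client.get_params() for client in self.clients]
    avg_params = {}
    for key in params[0].keys():
        # Stack the clients' tensors for this key and take the element-wise mean.
        avg_params[key] = torch.stack(
            [params[i][key] for i in range(len(params))], 0
        ).mean(0)
    # Update the global model in place.
    self.model.load_state_dict(avg_params)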
```
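
The `aggregate` method above takes an unweighted mean, which matches the update in [1] when every client holds the same number of samples. In the paper, each client's weights are scaled by its share of the training samples. The sketch below shows what a sample-weighted variant might look like; `weighted_average` and its arguments are hypothetical names, not part of this repository, and the weights are normalised over the clients passed in.

```python
import torch


def weighted_average(state_dicts, num_samples):
    """Average state dicts, weighting client k by n_k / n (hypothetical helper)."""
    total = float(sum(num_samples))
    averaged = {}
    for key in state_dicts[0]:
        # Cast to float so integer buffers can be averaged as well.
        stacked = torch.stack([sd[key].float() for sd in state_dicts], dim=0)
        weights = torch.tensor(
            [n / total for n in num_samples],
            dtype=stacked.dtype,
            device=stacked.device,
        )
        # Reshape to (num_clients, 1, 1, ...) so the weights broadcast per tensor.
        weights = weights.view(-1, *([1] * (stacked.dim() - 1)))
        # Cast back to the original dtype (integer buffers are truncated).
        averaged[key] = (stacked * weights).sum(dim=0).to(state_dicts[0][key].dtype)
    return averaged


a = {"w": torch.tensor([0.0, 2.0])}
b = {"w": torch.tensor([4.0, 6.0])}
print(weighted_average([a, b], num_samples=[1, 3]))  # weights 0.25 and 0.75
```

With equal `num_samples` for every client, this reduces to the plain mean used in `aggregate`.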

### Algorithm Loop
Finally, the `fedsgd.py` file contains the main functions that run the experiment with the chosen configuration. The function `fedSgdSeq` runs the sequential version of the algorithm, while the function `fedSgdPar` runs the parallel version. The algorithm runs for `T` rounds. In each round, the server first reads the current global model state and sends it to the clients. Each client loads that state into its local model, trains it with its local optimizer for `E` local epochs, and then makes the updated model available to the server. The server aggregates the client models into a new global model, which is evaluated on the validation set and sent out again in the next round. At the end of training, the server evaluates the model on the test set.

```python
def fedSgdSeq(
    model=Cnn(),
    T=5,                # number of rounds
    K=10,               # number of clients
    C=1,                # fraction of clients
    E=10,               # number of local epochs
    B=128,              # local batch size
    num_samples=1000,   # number of training samples on each client
    lr=0.01,            # learning rate
    weight_decay=10e-6, # weight decay
    patience=5,         # patience for early stopping
):
    # ... code ...
    clients = []
    for i in range(num_clients):
        client = Client(
            i,
            trainloader[i],
            Cnn() if model.get_type() == "Cnn" else Net(),
            lr=lr,
            weight_decay=weight_decay,
            device=device,
        )
        clients.append(client)
    # ... code ...
    # FedAvg algorithm sequential version
    for r in range(T):
        params = server.get_params()
        progress_bar = tqdm.tqdm(
            total=E * num_clients, position=0, leave=False, desc="Round %d" % r
        )
        for client in clients:
            loss = client.train(E, patience, params, progress_bar)
        server.aggregate()
        val_loss, val_acc = server.test(valoader)
        print("Server - Val loss: %.3f, Val accuracy: %.3f" % (val_loss, val_acc))
```

The dataset is split into training, validation and test sets. The training set is divided among `K` clients: each client has a dataloader that samples `num_samples` samples from the training set. The validation set is used to evaluate the model at the end of each round, while the test set is used to evaluate the model at the end of training.
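
One simple way to give each client its own IID share of the training set is to shuffle the indices once and slice them. The sketch below is an illustration only: `iid_partition` is a hypothetical helper, and the assumption that client shards are disjoint is mine, not something the repository code confirms.

```python
import random


def iid_partition(num_items, num_clients, num_samples, seed=0):
    """Return one disjoint list of shuffled indices per client (hypothetical helper)."""
    if num_clients * num_samples > num_items:
        raise ValueError("not enough items for a disjoint split")
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    return [
        indices[k * num_samples : (k + 1) * num_samples]
        for k in range(num_clients)
    ]


shards = iid_partition(num_items=48000, num_clients=100, num_samples=480)
print(len(shards), len(shards[0]))  # 100 480
```

Each index list could then be wrapped in `torch.utils.data.Subset` and passed to a `DataLoader` with `batch_size=B`.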

The parallel version uses the `joblib` module to run each client in its own thread, which better simulates the real scenario where the clients are distributed across different devices.

```python
for r in range(T):
    params = server.get_params()
    progress_bar = tqdm.tqdm(
        total=E * num_clients, position=0, leave=False, desc="Round %d" % r
    )
    joblib.Parallel(n_jobs=num_clients, backend="threading")(
        joblib.delayed(client.train)(E, patience, params, progress_bar)
        for client in clients
    )
    server.aggregate()
    val_loss, val_acc = server.test(valoader)
    print("Server - Val loss: %.3f, Val accuracy: %.3f" % (val_loss, val_acc))
```
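
The same fan-out pattern can also be written with the standard library instead of `joblib`. The sketch below uses `concurrent.futures.ThreadPoolExecutor` with a hypothetical `DummyClient` stand-in, so it illustrates only the threading structure, not the repository's training code.

```python
from concurrent.futures import ThreadPoolExecutor


class DummyClient:
    """Stand-in for the real Client; train() just returns a fake loss."""

    def __init__(self, client_id):
        self.id = client_id

    def train(self, num_epochs, params):
        # A real client would load `params` and run local SGD here.
        return 1.0 / (self.id + 1 + num_epochs)


def train_round(clients, num_epochs, params):
    """Run every client's train() in its own thread; return losses in client order."""
    with ThreadPoolExecutor(max_workers=len(clients)) as pool:
        futures = [pool.submit(c.train, num_epochs, params) for c in clients]
        return [f.result() for f in futures]


clients = [DummyClient(i) for i in range(4)]
losses = train_round(clients, num_epochs=1, params={})
print(losses)
```

Because the threads share one Python process, any speedup depends on how much of the training time is spent in code that releases the global interpreter lock.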

This implementation makes it possible to test different configurations of the algorithm. In particular, setting `C = 1`, `E = 1` and `B = inf` (each client treats its whole local dataset as a single batch) gives the **FedSGD** algorithm, which is the baseline algorithm for Federated Learning presented in the paper [1].
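
As a rough illustration of how `C` and `B` could be interpreted, the sketch below samples `max(int(C * K), 1)` clients for a round, following the rule in [1], and maps `B = inf` to the full local dataset. Both helper names are hypothetical and do not come from `fedsgd.py`.

```python
import math
import random


def select_clients(num_total, fraction, rng):
    """Sample m = max(int(C * K), 1) distinct client indices for one round."""
    m = max(int(fraction * num_total), 1)
    return rng.sample(range(num_total), m)


def effective_batch_size(batch_size, local_dataset_size):
    """Treat B = inf as 'use the whole local dataset as a single batch'."""
    if math.isinf(batch_size):
        return local_dataset_size
    return int(batch_size)


rng = random.Random(0)
print(len(select_clients(num_total=100, fraction=0.1, rng=rng)))  # 10
print(effective_batch_size(math.inf, 480))  # 480
print(effective_batch_size(10, 480))  # 10
```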

## Experiments

```python
from src.fedsgd import fedSgdPar, fedSgdSeq
from src.models import Net, Cnn

result = fedSgdPar(model=Cnn(), T=20, K=100, C=0.1, E=5, B=10, num_samples=480, lr=0.1, patience=5)
```

```
Running the Parallel implementation FedSGD on MNIST dataset
- Parameters: T=20, K=100, C=0.1, E=5, B=10, num_samples=480, lr=0.1, weight_decay=0, patience=5
- Model: Cnn
- Data Split:  48000 12000 10000

Round 0: 100%|██████████| 50/50 [00:14<00:00,  4.56it/s, loss=0.12]
Server - Val loss: 0.377, Val accuracy: 0.926

Round 1: 100%|██████████| 50/50 [4:53:13<00:00,  7.71s/it, loss=0.00405]
Server - Val loss: 0.135, Val accuracy: 0.961

Round 2: 100%|██████████| 50/50 [00:12<00:00,  4.83it/s, loss=0.00427]
```

## Results

## Considerations

## References
[1] [H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, Communication-Efficient Learning of Deep Networks from Decentralized Data. 2023.](https://arxiv.org/abs/1602.05629)

## TODO 
- [x] Add implementation code description: 
	- [x] `Client` class
	- [x] `Server` class
	- [x] `models` class
	- [x] `fedSgdPar` and `fedSgd` algorithm (show the algorithm code)
- [x] Check parallel client implementation (can be improved?)
- [ ] Add LSTM model
- [ ] Run experiment (from the paper)
- [ ] Show results
- [x] Code refactoring: move the class in separate files and the main code in a separate file

## Experiment Roadmap
1. Replicate the paper's results
   1. Compare the different models
   2. Try several configurations
2. Ablation study on E
3. Experiments on non-IID data


Questions:
- Notebook
- Experiments
- Additions (early stopping, other models, other datasets, etc.)

Answers:
- Early stopping is not useful
- Compare the results with the paper and try to replicate them
- The interesting part is seeing how the behaviour changes as E, the number of local iterations, increases
- This is because the difference between FedSGD and FedAvg lies in E
- In practice, increasing E too much should break convergence
- E = 1 -> FedSGD
- E > 1 -> FedAvg works, but stops working when E is too large
- Comparing IID and non-IID data is interesting
- How the parallel clients are managed and implemented is important (the hardest part to implement)


