# Federated Learning

Federated learning is a machine learning paradigm that enables decentralized training of a shared model by multiple clients while preserving data privacy. The main idea behind this new paradigm is that each client trains a local model on its own data and then sends only the model updates to a central server, rather than sending the raw data. This allows the model to be trained on a large amount of data without compromising data privacy.

Federated learning was first proposed by Google in 2016 (McMahan et al., 2016) and has since been applied in various fields, such as healthcare (Hard et al., 2018), finance (Yoon et al., 2018), and natural language processing (Li et al., 2020). For example, federated learning could be used to train a model that makes personalized recommendations for each user without any user's raw data being shared with a central server. In this way, the model learns from the users' data while complying with data privacy requirements.

Federated learning offers several advantages over traditional centralized approaches. First, by leveraging distributed data stores, it can scale to much larger datasets. Second, it protects data privacy by never transmitting raw data to a central server. Finally, it enables collaboration among multiple clients, allowing them to jointly train a shared model without sharing their individual data. These benefits make federated learning a promising approach in fields where data privacy is of the utmost importance.

Federated learning is a process in which a central server distributes a machine learning model to multiple devices. Each device trains the model on its local data and sends the updated model back to the central server. The central server then aggregates the updates from each device to improve the global model. This process is repeated until the model converges and can generate accurate predictions on new data. The key concepts within this process are:


* Client: refers to a device or edge node that holds a local dataset and actively participates in the training of the federated model.
* Server: represents the central entity that coordinates the training of the federated model and receives model updates from the clients to aggregate into a new version of the global model.
* Federated dataset: the collection of decentralized datasets from different clients that are used to train the federated model through collaborative learning.
* Federated model: a machine learning model that is trained on the federated dataset using federated learning to make accurate predictions on new data while preserving the privacy of each client's data.
* Federated optimization: the process of training the federated model from the clients' decentralized data and model updates, so that the model learns from all the clients' data while preserving their privacy.
* Aggregation: the process of combining the model updates received from the clients into a new version of the global model. This can be done using various methods such as weighted averaging or other approaches.
* Rounds: a round is one complete cycle in which the server distributes the current global model to the clients, the clients train it on their local data, and the server aggregates their updates into a new global model. Rounds are repeated until the model converges and achieves a satisfactory level of accuracy.
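
To make the cycle above concrete, here is a minimal, framework-free sketch of federated averaging written in plain NumPy. It is an illustration under simplifying assumptions, not Flower code: the three clients, their synthetic linear-regression data, the learning rate, and the helpers `make_client_data`, `local_train` and `aggregate` are all hypothetical. In each round, every client trains a copy of the global weights on its own data and returns only the weights and its number of examples; the server then combines them with a weighted average.

```python
import numpy as np

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])  # ground truth, used only to generate synthetic data


def make_client_data(n: int):
    """Synthetic local dataset for one client (it never leaves the client)."""
    X = rng.normal(size=(n, 2))
    y = X @ true_w + 0.1 * rng.normal(size=n)
    return X, y


# Three hypothetical clients with different amounts of local data
clients = [make_client_data(n) for n in (50, 100, 200)]


def local_train(w, X, y, lr=0.1, epochs=5):
    """Client side: a few epochs of full-batch gradient descent on the local MSE."""
    w = w.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w, len(y)  # only the updated weights and the example count are sent


def aggregate(results):
    """Server side: average the client weights, weighted by their number of examples."""
    total = sum(n for _, n in results)
    return sum(w * (n / total) for w, n in results)


global_w = np.zeros(2)
for round_num in range(1, 11):  # 10 federated rounds
    results = [local_train(global_w, X, y) for X, y in clients]
    global_w = aggregate(results)

print(np.round(global_w, 2))
```

After ten rounds the global weights should be close to `true_w`, even though the server never saw any client's `X` or `y`.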


Federated learning is a relatively new approach, and as such, few libraries support it yet. The main actors in this space are TensorFlow Federated, PySyft (developed by OpenMined), and Flower. TensorFlow Federated is a notable option, although it is mainly a research and simulation tool: it simulates the federated setting but does not support deploying the solution to real devices. Flower, in contrast, supports real distributed deployments, although adapting existing code to it can be somewhat challenging. For this tutorial, we have chosen Flower because of its more user-friendly approach and its potential for future use.



# Introduction to Flower (FLWR)

Flower is a Python library that offers tools for implementing the communication and coordination aspects of federated learning. Its design emphasizes ease of use and scalability. It's important to note that Flower is not a learning framework in itself, and as such, it wraps other machine learning frameworks like TensorFlow, PyTorch, or Scikit-learn in the communication layer to enable federated learning.

To use Flower for federated learning, you will need to install the library:


In [None]:
try:
    import flwr as fl
except ImportError:
    !pip install flwr[simulation] 
    import flwr as fl  # Import again after installation

try:
    import tensorflow as tf
except ImportError:
    !pip install tensorflow
    import tensorflow as tf  # Import again after installation 

When setting up a simulation environment, it is essential to include the `[simulation]` extra in the `!pip install flwr[simulation]` command. This ensures that the additional dependencies needed to simulate a federated learning environment, such as the Ray library on which Flower's simulation engine runs, are also installed.

**Important Note:** In a distributed setup, install the library with `pip install flwr` on the server and on all client devices, preferably with the same version everywhere, to ensure consistency.

**Installation Guide:**

* For the most up-to-date and comprehensive installation instructions, refer to the [official Flower documentation](https://flower.ai/).

After installing the `flwr` package, import it into your Python code as follows:

In [None]:
import flwr as fl
import tensorflow as tf

FLWR provides a range of classes and functions that you can use to set up a federated learning environment, train and evaluate a model, and update it over successive rounds. For more information, refer to the FLWR [documentation](https://flower.dev/docs/quickstart-tensorflow.html). Before proceeding, note that the model must be serializable: what travels over the network are its parameters, so you must be able to extract them (for example, as a list of NumPy arrays) and load them back. Not all models are suitable for federated learning. For this example, we use an Artificial Neural Network (ANN) built with TensorFlow, specifically Keras, whose `get_weights()` and `set_weights()` methods return and accept exactly such a list of NumPy arrays. These models are generally lightweight and well suited to serialization compared to some other model types.
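
As a rough illustration of this requirement, the sketch below turns a list of NumPy arrays, the same structure that Keras' `get_weights()` returns, into bytes and back using NumPy's `.npy` format. The helpers `ndarrays_to_bytes` and `bytes_to_ndarrays` are hypothetical and only mimic the idea; Flower has its own utilities (such as `fl.common.ndarrays_to_parameters`) and its own wire format. The random arrays are placeholders shaped like the kernel and bias of the first hidden layer of the network defined below.

```python
import io
from typing import List

import numpy as np


def ndarrays_to_bytes(arrays: List[np.ndarray]) -> List[bytes]:
    """Serialize each array to bytes in NumPy's .npy format (no pickling)."""
    blobs = []
    for arr in arrays:
        buffer = io.BytesIO()
        np.save(buffer, arr, allow_pickle=False)
        blobs.append(buffer.getvalue())
    return blobs


def bytes_to_ndarrays(blobs: List[bytes]) -> List[np.ndarray]:
    """Inverse of ndarrays_to_bytes."""
    return [np.load(io.BytesIO(blob), allow_pickle=False) for blob in blobs]


# Placeholder for model.get_weights() on the first Dense(64) layer:
# a (3072, 64) kernel (32*32*3 flattened inputs) and a (64,) bias vector.
weights = [
    np.random.rand(32 * 32 * 3, 64).astype("float32"),
    np.zeros(64, dtype="float32"),
]
payload = ndarrays_to_bytes(weights)   # what could be sent over the network
received = bytes_to_ndarrays(payload)  # what the receiver would pass to set_weights()
print([w.shape for w in received], sum(len(b) for b in payload), "bytes")
```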

In [None]:
# Define a simple model using TensorFlow
def generate_ann():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(32, 32, 3)), #input data. 32x32 color images with 3 channels
            tf.keras.layers.Dense(64, activation="relu"), #hidden layer
            tf.keras.layers.Dense(64, activation="relu"),#hidden layer
            tf.keras.layers.Dense(10, activation="softmax"), #output layer. 10 classes
        ]
    )

    model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=tf.keras.optimizers.Adam(),
        metrics=["accuracy"],
    )
    return model

As we are using a deep learning model defined in TensorFlow, it would normally be recommended to load the data into a `tf.data.Dataset` so that the framework can leverage any available hardware acceleration (such as a GPU on the nodes). However, due to some limitations of the framework when serializing the data, we instead partition the data manually into NumPy arrays with the following code:

In [None]:
import numpy as np
import tensorflow as tf

NUM_CLIENTS = 5

def unison_shuffled_copies(a, b):
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]

def split_index(a, n):
    s = np.array_split(np.arange(len(a)), n)
    return s


# Code to load the dataset
def load_datasets(num_clients: int):
    # Distribute it to train and test set
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    # Normalize data
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    x_train, y_train = x_train[:10_000], y_train[:10_000]
    x_test, y_test = x_test[:1000], y_test[:1000]

    # Randomize the datasets
    x_train, y_train = unison_shuffled_copies(x_train, y_train)
    x_test, y_test = unison_shuffled_copies(x_test, y_test)

    # Split training set into NUM_CLIENTS partitions to simulate the individual dataset
    train_index = split_index(x_train, num_clients)
    test_index = split_index(x_test, num_clients)

    # Split each partition
    train_ds = []
    val_ds = []
    test_ds = []
    for cid in range(num_clients):
        val_size = len(train_index[cid]) // 10
        train_input_data, train_output_data = x_train[train_index[cid]], y_train[train_index[cid]]
        val_input_data, val_output_data = train_input_data[:val_size], train_output_data[:val_size]
        train_input_data, train_output_data = train_input_data[val_size:], train_output_data[val_size:]
        train_dataset = (train_input_data, train_output_data)
        val_dataset = (val_input_data, val_output_data)
        test_dataset = (x_test[test_index[cid]], y_test[test_index[cid]])
        train_ds.append(train_dataset)
        val_ds.append(val_dataset)
        test_ds.append(test_dataset)
    return train_ds, val_ds, test_ds


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
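
As a quick sanity check of the partitioning logic, the following sketch reproduces the index arithmetic of `load_datasets` with NumPy alone, without downloading CIFAR-10. With 10,000 training and 1,000 test images split across 5 clients, each client ends up with 1,800 training, 200 validation and 200 test examples.

```python
import numpy as np

NUM_CLIENTS = 5
# Same splitting logic as split_index/load_datasets, applied to index ranges only
train_index = np.array_split(np.arange(10_000), NUM_CLIENTS)
test_index = np.array_split(np.arange(1_000), NUM_CLIENTS)

sizes = []
for cid in range(NUM_CLIENTS):
    n = len(train_index[cid])
    val_size = n // 10  # the first 10% of each partition becomes validation data
    sizes.append((n - val_size, val_size, len(test_index[cid])))
    print(f"client {cid}: train={n - val_size}, val={val_size}, test={len(test_index[cid])}")
```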

Before proceeding, it's crucial to perform a basic test to validate the correctness of our approach. This involves training the model on the dataset of a single client. The purposes are:
1. **Checks Model Training Functionality**: It ensures that the model can be successfully trained on the data of an individual client.
2. **Identifies Potential Issues Early**: It helps to detect and address any errors or unexpected behavior in the data loading or model training processes before moving on to the more complex federated learning implementation.

In [None]:
model = generate_ann()

model.fit(trainloaders[0][0],trainloaders[0][1], epochs=1, batch_size=32, steps_per_epoch=3)
loss, accuracy = model.evaluate(valloaders[0][0], valloaders[0][1])
print(f"The resulting loss is {loss:.4f} for a validation accuracy of {accuracy*100:.2f}%")

**Important**: *The following example is meant to be executed in a terminal.*

Now that we know training works on a single client's data, the federated process distributes it across several machines, which requires defining a couple of additional elements. Let's introduce these two pieces of the puzzle: the `Client` and the `Server`.



Flower starts a `Server` to coordinate the client devices and perform the orchestration of the model. The server interacts with clients through an interface called `Client`. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code, which in this case involves training the neural network that we defined earlier.
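
The interaction just described can be sketched without Flower at all. In the example below, `ToyClient` and `run_round` are hypothetical stand-ins: the client exposes the same three methods as the `NumPyClient` interface introduced next, with the same return shapes (parameters, number of examples and a metrics dictionary for `fit`; loss, number of examples and metrics for `evaluate`), while `run_round` plays the role of a heavily simplified server that fits, aggregates by example count and then evaluates, in the same order as the server log shown later. The "model" is a single number that each client moves towards the mean of its local data.

```python
from typing import Dict, List, Tuple

import numpy as np


class ToyClient:
    """Hypothetical client exposing the same three methods as Flower's NumPyClient."""

    def __init__(self, data: np.ndarray):
        self.data = data      # local data, never sent to the server
        self.w = np.zeros(1)  # the "model": a single weight

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        return [self.w]

    def fit(self, parameters: List[np.ndarray], config: Dict):
        self.w = parameters[0].copy()
        # Toy "training": move the weight halfway towards the local mean
        self.w = self.w + 0.5 * (self.data.mean() - self.w)
        return [self.w], len(self.data), {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict):
        loss = float(np.mean((self.data - parameters[0]) ** 2))
        return loss, len(self.data), {}


def run_round(global_params: List[np.ndarray],
              clients: List[ToyClient]) -> Tuple[List[np.ndarray], float]:
    """Simplified server round: fit on all clients, average by example count, evaluate."""
    fit_results = [c.fit(global_params, {}) for c in clients]
    total = sum(n for _, n, _ in fit_results)
    new_params = [sum(params[0] * (n / total) for params, n, _ in fit_results)]
    eval_results = [c.evaluate(new_params, {}) for c in clients]
    loss = (sum(loss_i * n for loss_i, n, _ in eval_results)
            / sum(n for _, n, _ in eval_results))
    return new_params, loss


clients = [ToyClient(np.array([1.0, 2.0, 3.0])), ToyClient(np.array([8.0]))]
params = clients[0].get_parameters({})  # initial parameters requested from one client
losses = []
for _ in range(10):
    params, loss = run_round(params, clients)
    losses.append(loss)
print(params[0], losses[-1])
```

Over the rounds, the global parameter approaches 3.5, the mean of all four data points, and the weighted loss decreases, although the server never sees the data itself.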

Flower provides a convenient class called `NumPyClient`, which simplifies the implementation of the `Client` interface when the model parameters can be handled as NumPy arrays, as is the case with Keras. The [NumPyClient](https://flower.ai/docs/framework/ref-api/flwr.client.NumPyClient.html#flwr.client.NumPyClient) interface defines three methods, `get_parameters`, `fit`, and `evaluate`, which can be implemented as follows:

```python
#Create a class to contain the details of the client and be the interface
class MyClient(fl.client.NumPyClient):
    def __init__(self, net, train_dataset, test_dataset):
        self.model = net
        self.trainloader = train_dataset
        self.valloader = test_dataset
        
    def get_parameters(self, config):
        return self.model.get_weights()

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        self.model.fit(self.trainloader[0],self.trainloader[1], epochs=1, batch_size=32, steps_per_epoch=3)
        return self.model.get_weights(), len(self.trainloader[0]), {}

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.valloader[0], self.valloader[1])
        return loss, len(self.valloader[0]), {"accuracy": float(accuracy)}
```

In the preceding code, we defined the required functions for the client in this particular case. With these functions in place, we can now start a client using the following code:

```python 
# Start the client
model=generate_ann()

fl.client.start_client(server_address="localhost:8080", client=MyClient(model, trainloaders[0], valloaders[0]).to_client())
```

**Important**: The client needs a running server to connect to, so you must also start a server. Run each of them in a separate terminal.

The string `localhost:8080` specifies the server to which the client should connect. In this case, as the code is being run on the same machine as the server, this address is sufficient. In a truly federated workload, the only thing that needs to be changed is the `server_address`, so that it points the client to the correct server.

Note that the port must be free: if another process (for example, a Jupyter server configured to use port 8080) is already listening on it, **you will need to use another available port**.


The other essential piece of the puzzle is the server. Its code goes in a separate file, for example `server.py`, and its contents should look something like this:


```python
import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
```

##### **Important**
You can run the server on another port by passing the `server_address` parameter to the `start_server` function, for example `server_address="localhost:8080"` with the port number changed. Remember to point the clients to the same address.

In this particular case, we can run two clients and a server in separate terminals of the machine. Running two client instances is as simple as executing the `python client.py` command twice in separate terminals, while the server can be started with the `python server.py` command.

Upon starting the server, we should receive an output similar to:


```shell
INFO flwr 2023-03-01 14:58:16,353 | app.py:139 | Starting Flower server, config: ServerConfig(num_rounds=3, round_timeout=None)
INFO flwr 2023-03-01 14:58:16,362 | app.py:152 | Flower ECE: gRPC server running (3 rounds), SSL is disabled
INFO flwr 2023-03-01 14:58:16,362 | server.py:86 | Initializing global parameters
INFO flwr 2023-03-01 14:58:16,362 | server.py:270 | Requesting initial parameters from one random client
INFO flwr 2023-03-01 14:58:24,152 | server.py:274 | Received initial parameters from one random client
INFO flwr 2023-03-01 14:58:24,153 | server.py:88 | Evaluating initial parameters
INFO flwr 2023-03-01 14:58:24,153 | server.py:101 | FL starting
DEBUG flwr 2023-03-01 14:58:26,118 | server.py:215 | fit_round 1: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-03-01 14:58:27,041 | server.py:229 | fit_round 1 received 2 results and 0 failures
WARNING flwr 2023-03-01 14:58:27,076 | fedavg.py:242 | No fit_metrics_aggregation_fn provided
DEBUG flwr 2023-03-01 14:58:27,076 | server.py:165 | evaluate_round 1: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-03-01 14:58:27,565 | server.py:179 | evaluate_round 1 received 2 results and 0 failures
WARNING flwr 2023-03-01 14:58:27,565 | fedavg.py:273 | No evaluate_metrics_aggregation_fn provided
DEBUG flwr 2023-03-01 14:58:27,566 | server.py:215 | fit_round 2: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-03-01 14:58:28,015 | server.py:229 | fit_round 2 received 2 results and 0 failures
DEBUG flwr 2023-03-01 14:58:28,027 | server.py:165 | evaluate_round 2: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-03-01 14:58:28,364 | server.py:179 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2023-03-01 14:58:28,364 | server.py:215 | fit_round 3: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-03-01 14:58:28,755 | server.py:229 | fit_round 3 received 2 results and 0 failures
DEBUG flwr 2023-03-01 14:58:28,769 | server.py:165 | evaluate_round 3: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-03-01 14:58:29,184 | server.py:179 | evaluate_round 3 received 2 results and 0 failures
INFO flwr 2023-03-01 14:58:29,185 | server.py:144 | FL finished in 5.031599427999936
INFO flwr 2023-03-01 14:58:29,185 | app.py:202 | app_fit: losses_distributed [(1, 2.3956351280212402), (2, 2.426431179046631), (3, 2.3015435934066772)]
INFO flwr 2023-03-01 14:58:29,185 | app.py:203 | app_fit: metrics_distributed {}
INFO flwr 2023-03-01 14:58:29,186 | app.py:204 | app_fit: losses_centralized []

```

With that, the first federated learning approach is complete. As the log shows, the server runs three rounds: in each round, it samples both clients to fit the model, aggregates their results into a new global model, and then samples them again to evaluate it. At the end, it reports the distributed loss obtained in each round (`losses_distributed`).

### Exercise
Implement the client and server code in separate files (**All the required files must be submitted together with the Notebook**). Next, execute a server and two clients from terminals. Finally, compare the results with those presented here. Were your results similar?


**Note**:  The answer is expected to contain the corresponding terminal output resulting from the execution of your code.

`Answer here`:

# Updating parameters
The key element in this kind of approach is that the server sends the global model parameters to the clients, and each client loads the received parameters into its local model. It then trains the model on its local data, which changes the model parameters locally. After training, the client sends the updated model parameters back to the server or, alternatively, only the update (such as the gradients or the difference from the global parameters) instead of the full model parameters.
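
The two options are interchangeable when the server simply computes a weighted average, because the averaging weights sum to one. The NumPy sketch below, with made-up client weights and dataset sizes and a hypothetical `weighted_mean` helper, checks this: averaging the full weights gives the same new global model as averaging the updates (local minus global weights) and adding the result to the global model. Strictly speaking, such an update equals a gradient (up to sign and learning rate) only when the client takes a single plain gradient step; after several local steps it is an accumulated parameter change.

```python
from typing import List

import numpy as np


def weighted_mean(arrays: List[np.ndarray], sizes: List[int]) -> np.ndarray:
    """Average arrays, weighting each one by its client's number of examples."""
    total = sum(sizes)
    return sum(a * (n / total) for a, n in zip(arrays, sizes))


global_w = np.array([0.5, -0.2, 1.0])
# Hypothetical weights after local training on three clients, and their dataset sizes
local_ws = [
    global_w + np.array([0.1, 0.0, -0.2]),
    global_w + np.array([-0.3, 0.2, 0.1]),
    global_w + np.array([0.05, 0.05, 0.05]),
]
sizes = [100, 50, 250]

# Option 1: clients send their full updated weights and the server averages them
from_weights = weighted_mean(local_ws, sizes)

# Option 2: clients send only their update (local minus global weights); the server
# averages the updates and adds the result to the global weights it already holds
deltas = [w - global_w for w in local_ws]
from_deltas = global_w + weighted_mean(deltas, sizes)

print(from_weights, np.allclose(from_weights, from_deltas))
```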


In `flwr`, this exchange relies on two helper operations: `get_parameters`, which extracts the local model parameters as a list of NumPy arrays, and `set_parameters`, which loads the parameters received from the server into the local model. Writing them as helper functions is the usual pattern with frameworks such as **PyTorch** or **JAX** but, as demonstrated in the previous example, `flwr` can also be used with **TensorFlow** or even **scikit-learn**.

As a result, the basic structure for any client using this library has the same format:


In [None]:
from typing import List

# Utility functions for the most common operations
def get_parameters(net) -> List[np.ndarray]:
    return net.get_weights()


def set_parameters(net, parameters: List[np.ndarray]):
    net.set_weights(parameters)
    return net


def train(net, trainloader, epochs: int):
    net.fit(trainloader[0], trainloader[1],
            epochs=epochs, batch_size=32, steps_per_epoch=3)
    return net


def test(net, testloader):
    loss, accuracy = net.evaluate(testloader[0], testloader[1])
    return loss, accuracy

# Class to contain a Client
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        self.net = set_parameters(self.net, parameters)
        self.net = train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader[0]), {}  # number of local training examples

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        self.net = set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        print(f"[Client {self.cid}] loss:{loss}, Client {self.cid} accuracy:{accuracy}")
        return float(loss), len(self.valloader[0]), {"accuracy": float(accuracy)}

In Flower, clients can be created by extending either the `flwr.client.Client` or `flwr.client.NumPyClient` classes. In the previous example, we used `NumPyClient` because it is easier to implement and requires less code as a template. Along with the extended class, there are three main methods that need to be implemented:

* `get_parameters`: Returns the current local model parameters.
* `fit`: Receives model parameters from the server, trains the model parameters on the local data, and returns the (updated) model parameters to the server.
* `evaluate`: Receives model parameters from the server, evaluates the model parameters on the local data, and returns the evaluation result to the server.

The `MyClient` class implemented in the previous example follows this same structure. The only difference is that `FlowerClient` also stores the *id* of the client (`cid`), which identifies the data partition it was created with and is printed in its log messages.


#### Be aware: 
Sometimes, especially when we are simulating multiple clients on a single device, it is useful to create each client through a function, only when it is required. Rather than keeping every client in memory for the whole run, the simulation can then create a client when it is needed for training or evaluation and discard it afterwards. For example, the following code creates each client with its own partition of the data:


In [None]:
from flwr.common import Context
from flwr.client import ClientApp


def client_fn(context: Context):
    """Returns a FlowerClient containing its data partition."""
    
    # Create the model
    net = generate_ann()
    # Get the client identifier
    partition_id = int(context.node_config["partition-id"])
    # Take the corresponding partition of the dataset
    trainloader = trainloaders[partition_id]
    valloader = valloaders[partition_id]
    #Create and return the Client
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()

From version 1.13 onward, running a simulation requires the clients to be wrapped in a `ClientApp`. The `ClientApp` represents each of the nodes that will be used, and it is called later by the simulation function.

In [None]:
# Construct the ClientApp, passing the client generation function
client_app = ClientApp(client_fn=client_fn)

Note that `MyClient` cannot be used in the same way because of the state that it keeps internally through the model created with `generate_ann`. However, if this state is removed, it can be used in the same way as `FlowerClient`.

The clients are now set up to load data, fit, and evaluate. However, we still need to integrate the results from the different clients. In Flower, the component that does this is called a strategy; a well-known example is the *Federated Averaging (FedAvg)* strategy. As a first approach, we can use the built-in implementations of the framework, although custom strategies can also be used. Additionally, all these elements must be wrapped in a `ServerApp` in order to run them. A `ServerApp` is built from a `server_fn` function that has a similar signature to `client_fn` but, instead of returning a client object, returns all the components needed to run the server-side logic in Flower. Let's see an example:


In [None]:
from flwr.server import ServerApp, ServerAppComponents

num_rounds = 3

def server_fn(context: Context):
    # Instantiate the model
    model = generate_ann()
    params = get_parameters(model)  # The federated model's initial parameters
    del model  # Not required, but it frees a good amount of memory
    # Convert model parameters to flwr.common.Parameters
    global_model_init = fl.common.ndarrays_to_parameters(params)

    # Create FedAvg strategy
    strategy = fl.server.strategy.FedAvg(
            fraction_fit=1.0,  # Sample 100% of available clients for training
            fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
            min_fit_clients=NUM_CLIENTS,  # Never sample less than NUM_CLIENTS clients for training
            min_evaluate_clients=NUM_CLIENTS//2,  # Never sample less than NUM_CLIENTS//2 clients for evaluation
            min_available_clients=NUM_CLIENTS,  # Wait until all NUM_CLIENTS clients are available
            initial_parameters=global_model_init,  # Initial parameters
    )

    # Define ServerConfig with the number of rounds set above
    config = fl.server.ServerConfig(num_rounds=num_rounds)

    # Return the configuration and strategy for this server
    return ServerAppComponents(strategy=strategy, config=config)

# Create Server
server_app = ServerApp(server_fn=server_fn)

Now we have the two pieces of the puzzle. Therefore, the simulation engine can be called.

In [None]:
fl.simulation.run_simulation(
    server_app=server_app, client_app=client_app, num_supernodes=NUM_CLIENTS
)

This code runs the whole federation, the server and all the clients, on a single device, using the on-demand client creation described above to avoid overloading it. It creates NUM_CLIENTS virtual clients and, because `fraction_fit=1.0`, selects all of them to train the model in every round. After receiving the updates from the clients, the server aggregates them according to the strategy and sends the new global model back to the clients for the next round, repeating this for the number of rounds set in `ServerConfig`.

**Note**: *Be aware that simulation is a very resource-intensive task, and you may get errors and warnings related to it. You can apply different [configurations](https://flower.ai/docs/framework/how-to-run-simulations.html) to control the resources available when simulating the process.*

One point to highlight is that the framework automatically aggregates only the loss (`losses_distributed`), not any of the other metrics. Because custom metrics can be measured and combined in many different ways, the framework cannot know how to aggregate them correctly, so users need to tell it how to handle and aggregate these custom metrics.

To do so, the strategy accepts two optional functions, `fit_metrics_aggregation_fn` and `evaluate_metrics_aggregation_fn`, which it calls whenever it receives fit or evaluate metrics from the clients. For example, the following code computes a weighted average of the clients' accuracies, and the previous example can be adapted as follows:


In [None]:
from typing import Tuple, List
from flwr.common import Metrics

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}

def server_fn(context: Context):
    # Instantiate the model
    model = generate_ann()
    params = get_parameters(model)  # The federated model's initial parameters
    del model  # Not required, but it frees a good amount of memory
    # Convert model parameters to flwr.common.Parameters
    global_model_init = fl.common.ndarrays_to_parameters(params)

    # Create FedAvg strategy
    strategy = fl.server.strategy.FedAvg(
            fraction_fit=1.0,  # Sample 100% of available clients for training
            fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
            min_fit_clients=NUM_CLIENTS,  # Never sample less than NUM_CLIENTS clients for training
            min_evaluate_clients=NUM_CLIENTS//2,  # Never sample less than NUM_CLIENTS//2 clients for evaluation
            min_available_clients=NUM_CLIENTS,  # Wait until all NUM_CLIENTS clients are available
            initial_parameters=global_model_init,  # Initial parameters
            evaluate_metrics_aggregation_fn=weighted_average,  # Aggregation function for the evaluation metrics
    )

    # Define ServerConfig with the number of rounds set earlier
    config = fl.server.ServerConfig(num_rounds=num_rounds)

    # Return the configuration and strategy for this server
    return ServerAppComponents(strategy=strategy, config=config)

# Create Server
server_app = ServerApp(server_fn=server_fn)

fl.simulation.run_simulation(
    server_app=server_app, client_app=client_app, num_supernodes=NUM_CLIENTS
)
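
To see why the aggregation function matters, the following sketch compares the `weighted_average` defined above with a plain, unweighted mean on two hypothetical clients with very different validation set sizes. It uses a plain `Dict[str, float]` in place of `flwr.common.Metrics` (an assumption made so that it runs without Flower). Weighting by the number of examples each client reports from `evaluate` gives an accuracy of 0.58, the accuracy over all 250 pooled examples, whereas the unweighted mean reports 0.7 and lets the small client count as much as the large one.

```python
from typing import Dict, List, Tuple

# Plain-dict stand-in for flwr.common.Metrics (assumption for this sketch)
Metrics = Dict[str, float]


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Accuracy weighted by each client's number of evaluation examples."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


def simple_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Unweighted mean: every client counts the same, whatever its size."""
    return {"accuracy": sum(m["accuracy"] for _, m in metrics) / len(metrics)}


# Hypothetical evaluation results: (number of validation examples, metrics) per client
client_metrics = [(200, {"accuracy": 0.5}), (50, {"accuracy": 0.9})]
print(weighted_average(client_metrics))
print(simple_average(client_metrics))
```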


We will revisit the definition of custom strategies in the following unit to define our own strategy and attempt to minimize some of the challenges that federated learning must address.

### Exercise

Now it is your turn: try running your own architecture with this approach. Be aware of the high resource requirements of the simulated environment.

**Note**: The objective of this exercise is to bring together the various topics covered by this notebook. This cell should contain all the necessary code for autonomous execution using Flower simulation.

In [None]:
import numpy as np
import tensorflow as tf

NUM_CLIENTS = 5

def unison_shuffled_copies(a, b):
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]

def split_index(a, n):
    s = np.array_split(np.arange(len(a)), n)
    return s


def generate_ann():
     #TODO
        
    return model


# Code to load the dataset
def load_datasets(num_clients: int):
    # Distribute it to train and test set
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    # Normalize data
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    x_train, y_train = x_train[:10_000], y_train[:10_000]
    x_test, y_test = x_test[:1000], y_test[:1000]

    # Randomize the datasets
    x_train, y_train = unison_shuffled_copies(x_train, y_train)
    x_test, y_test = unison_shuffled_copies(x_test, y_test)

    # Split training set into num_clients partitions to simulate the individual dataset
    train_index = split_index(x_train, num_clients)
    test_index = split_index(x_test, num_clients)

    # Split each partition
    train_ds = []
    val_ds = []
    test_ds = []
    for cid in range(num_clients):
        val_size = len(train_index[cid]) // 10
        train_input_data, train_output_data = x_train[train_index[cid]], y_train[train_index[cid]]
        val_input_data, val_output_data = train_input_data[:val_size], train_output_data[:val_size]
        train_input_data, train_output_data = train_input_data[val_size:], train_output_data[val_size:]
        train_dataset = (train_input_data, train_output_data)
        val_dataset = (val_input_data, val_output_data)
        test_dataset = (x_test[test_index[cid]], y_test[test_index[cid]])
        train_ds.append(train_dataset)
        val_ds.append(val_dataset)
        test_ds.append(test_dataset)
    return train_ds, val_ds, test_ds


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)


#TODO Client, client_fn, Server and simulation

# Aggregation

To conclude this lesson, let's take a closer look at the key component of these strategies: the aggregation algorithm. These algorithms are responsible for combining the client updates into the global model and, as we have seen, they are defined in the strategies. Several types of aggregation can be used in federated learning (Reddi et al., 2020):

* Federated averaging (`flwr.server.strategy.FedAvg`): In this approach, each device computes an update to the model parameters based on its local data, and these updates are then averaged together to create the global model. This approach is simple and effective, but it can be sensitive to the size of the updates and the quality of the data on each device.

* Federated weighted averaging: This approach is similar to federated averaging, but each device's update is given a different weight based on the size of its data set or the quality of its data. This can help to give more influence to devices with larger or higher-quality data.

* Federated averaging with momentum (`flwr.server.strategy.FedAvgM`): This approach is similar to federated averaging, but it incorporates a momentum term in order to smooth out the updates and help the model converge more quickly.

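* Federated stochastic gradient descent (FedSGD): In this approach, each device computes a gradient of the loss on a batch of its local data, and the server averages these gradients to take a single optimization step per round (McMahan et al., 2016). Each round is cheap for the clients, but this approach usually needs many more communication rounds than federated averaging, which runs several local training steps per round.

* Federated Adagrad (`flwr.server.strategy.FedAdagrad`): The server treats the averaged client update as a pseudo-gradient and applies the Adagrad optimizer to it, adapting the step size of each parameter (Reddi et al., 2020).

* Federated ADAM (`flwr.server.strategy.FedAdam`): Like Federated Adagrad, this is a server-side adaptive method: it applies the ADAM optimization algorithm to the averaged client update, adaptively adjusting the learning rate based on first- and second-moment estimates (Reddi et al., 2020).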



Most of the aggregation methods mentioned above are available in the `flwr` framework as strategies. Note that `FedAvg` already weights each client's update by the number of examples it reports from `fit`, so weighting by dataset size is built in, while weighting by data quality requires a custom strategy. Additionally, there are other less common aggregation methods that can be employed. The choice of aggregation method will ultimately depend on the specific characteristics of the data and the requirements of the task at hand.
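
To compare these server-side rules, here is a compact NumPy sketch of federated averaging, server momentum in the style of FedAvgM, and FedAdam following the update rule described by Reddi et al. (2020). It is not Flower's implementation: the class names `ServerMomentum` and `ServerAdam`, the hyperparameter values and the two-client example are illustrative assumptions. It also assumes every client started the round from the server's current weights, so that the difference between the averaged client weights and the global weights can serve as the update (pseudo-gradient).

```python
from typing import List, Tuple

import numpy as np

Weights = List[np.ndarray]


def fedavg(results: List[Tuple[Weights, int]]) -> Weights:
    """Average each layer across clients, weighted by their number of examples."""
    total = sum(n for _, n in results)
    n_layers = len(results[0][0])
    return [sum(w[i] * (n / total) for w, n in results) for i in range(n_layers)]


class ServerMomentum:
    """FedAvgM-style update: v <- beta * v + delta;  x <- x + lr * v."""

    def __init__(self, init: Weights, lr: float = 1.0, momentum: float = 0.9):
        self.x = [w.copy() for w in init]
        self.v = [np.zeros_like(w) for w in init]
        self.lr, self.beta = lr, momentum

    def step(self, results: List[Tuple[Weights, int]]) -> Weights:
        delta = [a - x for a, x in zip(fedavg(results), self.x)]  # averaged client update
        self.v = [self.beta * v + d for v, d in zip(self.v, delta)]
        self.x = [x + self.lr * v for x, v in zip(self.x, self.v)]
        return self.x


class ServerAdam:
    """FedAdam-style update (Reddi et al., 2020): Adam applied to the averaged update."""

    def __init__(self, init: Weights, lr: float = 0.1, beta1: float = 0.9,
                 beta2: float = 0.99, tau: float = 1e-3):
        self.x = [w.copy() for w in init]
        self.m = [np.zeros_like(w) for w in init]
        self.v = [np.zeros_like(w) for w in init]
        self.lr, self.b1, self.b2, self.tau = lr, beta1, beta2, tau

    def step(self, results: List[Tuple[Weights, int]]) -> Weights:
        delta = [a - x for a, x in zip(fedavg(results), self.x)]  # pseudo-gradient
        self.m = [self.b1 * m + (1 - self.b1) * d for m, d in zip(self.m, delta)]
        self.v = [self.b2 * v + (1 - self.b2) * d ** 2 for v, d in zip(self.v, delta)]
        self.x = [x + self.lr * m / (np.sqrt(v) + self.tau)
                  for x, m, v in zip(self.x, self.m, self.v)]
        return self.x


# Two hypothetical clients: (locally trained weights, number of examples)
init = [np.zeros(2)]
results = [([np.array([1.0, 1.0])], 30), ([np.array([4.0, 4.0])], 10)]
print(fedavg(results))  # weighted mean of the client weights
print(ServerMomentum(init).step(results))
print(ServerAdam(init).step(results))
```

With `momentum=0.0` and `lr=1.0`, `ServerMomentum` reduces to plain federated averaging, which is a handy way to check an implementation like this one.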


#### References
* Hard, A., Konečný, J., McMahan, H. B., Richemond-Barakat, C., Sivek, J. S., & Talwar, K. (2018). Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1812.02903.
* Li, Y., Bonawitz, K., & Talwar, K. (2020). Fedprox: An optimizer for communication-efficient federated learning. arXiv preprint arXiv:2002.04283.
* McMahan, H. B., Moore, E., Ramage, D., Hampson, S., & Agüera y Arcas, B. (2016). Communication-efficient learning of deep networks from decentralized data. arXiv preprint arXiv:1602.05629.
* Reddi, S., Charles, Z., Zaheer, M., Garrett, Z., Rush, K., Konečný, J., Kumar, S., & McMahan, H. B. (2020). Adaptive federated optimization. arXiv preprint arXiv:2003.00295.
* Yoon, J., Hard, A., Konečný, J., McMahan, H. B., & Sohl-Dickstein, J. (2018). Federal regression: A simple and scalable method for heterogeneous federated learning. arXiv preprint arXiv:1812.03862.