# Federated Learning

Federated learning is a machine learning paradigm in which multiple clients collaboratively train a shared model without centralizing their data. Each client trains a local copy of the model on its own data and sends only the resulting model updates to a central server, never the raw data. This makes it possible to learn from a large, distributed pool of data while the data itself stays with the devices or organizations that own it.

Federated learning was first proposed by Google in 2016 (McMahan et al., 2016) and has since been applied in various fields, such as healthcare (Hard et al., 2018), finance (Yoon et al., 2018), and natural language processing (Li et al., 2020). For example, federated learning could be used to train a model that makes personalized recommendations for each user without the users' raw data ever being shared with a central server: the model learns from the data, while the data stays where privacy requirements say it must.

Compared with traditional centralized training, federated learning offers several advantages. First, because training runs where the data is stored, it can make use of datasets far larger than what could practically be gathered in one place. Second, it prioritizes data privacy by never transmitting raw data to a central server. Finally, it enables collaboration among multiple clients, allowing them to jointly train a shared model without exposing their individual data to one another. These benefits make federated learning a promising approach in fields where data privacy is of the utmost importance.

In practice, a central server distributes a machine learning model to multiple devices. Each device trains the model on its local data and sends the updated model back to the server, which aggregates the updates from all devices into an improved global model. This cycle is repeated until the model converges and makes accurate predictions on new data. The key concepts within this process are:


* Client: a device or edge node that holds a local dataset and actively participates in training the federated model.
* Server: the central entity that coordinates training and aggregates the model updates received from the clients into a new version of the global model.
* Federated dataset: the collection of decentralized datasets held by the different clients; it is never gathered in a single place.
* Federated model: the machine learning model trained on the federated dataset through federated learning, with the goal of making accurate predictions on new data while preserving the privacy of each client's data.
* Federated optimization: the procedure that trains the federated model from the clients' local updates, for example several local gradient steps on each client followed by aggregation on the server.
* Aggregation: the process of combining the model updates received from the clients into a new version of the global model, for example by weighted averaging.
* Round: one iteration of the cycle in which the server sends the current global model to the clients, the clients train it locally, and the server aggregates their updates. Rounds are repeated until the model converges or reaches a satisfactory level of accuracy.
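The cycle described above can be sketched in a few lines of NumPy. This is a toy illustration, not Flower code: the clients fit a linear-regression model, and `local_train`, `fedavg`, the learning rate, and the number of local steps are hypothetical choices made for clarity. The aggregation step is the example-weighted average used by FedAvg (McMahan et al., 2016).

```python
import numpy as np


def local_train(w, X, y, lr=0.1, epochs=5):
    """Client side: a few full-batch gradient-descent steps on the local least-squares loss."""
    w = w.copy()
    for _ in range(epochs):
        grad = 2.0 / len(X) * X.T @ (X @ w - y)  # gradient of the mean squared error
        w -= lr * grad
    return w


def fedavg(client_weights, client_sizes):
    """Server side: average the client models, weighting each by its number of examples."""
    return np.average(np.stack(client_weights), axis=0, weights=np.asarray(client_sizes, dtype=float))


rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])

# Three clients with private datasets of different sizes; their raw data never leaves them
clients = []
for n in (50, 100, 200):
    X = rng.normal(size=(n, 2))
    y = X @ true_w + 0.01 * rng.normal(size=n)
    clients.append((X, y))

global_w = np.zeros(2)
for round_num in range(20):
    # The server sends global_w to every client, and each client trains on its own data
    updates = [local_train(global_w, X, y) for X, y in clients]
    # Only the updated parameters reach the server, which aggregates them
    global_w = fedavg(updates, [len(X) for X, _ in clients])

print(global_w)  # close to the true weights [2.0, -1.0]
```

Only parameter vectors cross the client/server boundary; `X` and `y` stay inside each client's `local_train` call.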


Federated learning is a relatively new approach, so only a few libraries support it so far. The main actors in this space are TensorFlow Federated, PySyft (developed by the OpenMined community), and Flower. TensorFlow Federated is a notable option, but at the time of writing it is aimed mainly at simulating federated computations rather than deploying them on real, distributed devices. Flower, in contrast, supports actual distributed deployments, although adapting existing training code to it can take some effort. For this tutorial, we have chosen Flower because of its more user-friendly approach and because the same code can later be used beyond simulation.



# Introduction to Flower (FLWR)

Flower is a Python library that provides the communication and coordination layer of federated learning, with a design that emphasizes ease of use and scalability. Flower is not a learning framework in itself: it wraps existing machine learning frameworks such as TensorFlow, PyTorch, or scikit-learn and adds the communication layer needed for federated learning.

To use Flower for federated learning, you will need to install the library:


In [None]:
try:
    import flwr as fl
except ImportError:
    # Install Flower together with the extra dependencies used by its simulation engine
    !pip install "flwr[simulation]"

try:
    import tensorflow as tf
except ImportError:
    !pip install tensorflow

When working in a simulation environment, install the `simulation` extra (`flwr[simulation]`), which pulls in the additional dependencies (such as Ray) that Flower's simulation engine needs. If you plan to use Flower in a real distributed setup instead, install the base package with `pip install flwr` on both the server and the client devices. Once `flwr` is installed, you can import it into your Python code:

In [None]:
import flwr as fl
import tensorflow as tf

Flower provides a range of classes and functions for setting up a federated learning environment, training and evaluating a model, and updating it round after round; see the Flower [documentation](https://flower.dev/docs/quickstart-tensorflow.html) for more information. Before proceeding, note that the model's parameters must be serializable, because they are sent over the network: with `NumPyClient`, they are exchanged as a list of NumPy arrays. Not all models are suitable for federated learning. For this example, we will use an artificial neural network (ANN) built with Keras, TensorFlow's high-level API.

In [None]:
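# Define a simple fully connected model using TensorFlow/Keras
def generate_ann():
    model = tf.keras.Sequential([
        # Declaring the input shape builds the weights immediately,
        # so get_weights() works before any training has happened
        tf.keras.Input(shape=(28, 28)),  # MNIST images are 28x28 grayscale
        tf.keras.layers.Flatten(),       # flatten each image into a 784-element vector
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')  # one output per digit class
    ])

    model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=tf.keras.optimizers.Adam(),
        metrics=['accuracy']
    )
    return model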
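As noted above, what actually travels over the network is not the Keras model object but its parameters, as a list of NumPy arrays (for a Keras model, the output of `model.get_weights()`: a weight matrix and a bias vector per `Dense` layer). The sketch below shows a byte round trip for such a list using NumPy's `.npy` format. The helper names and the array shapes are hypothetical, and Flower performs this conversion itself, so treat it as an illustration of the idea rather than Flower's implementation.

```python
import io
from typing import List

import numpy as np


def ndarrays_to_bytes(arrays: List[np.ndarray]) -> List[bytes]:
    """Serialize each parameter array with NumPy's .npy format (no pickling)."""
    blobs = []
    for arr in arrays:
        buffer = io.BytesIO()
        np.save(buffer, arr, allow_pickle=False)
        blobs.append(buffer.getvalue())
    return blobs


def bytes_to_ndarrays(blobs: List[bytes]) -> List[np.ndarray]:
    """Rebuild the arrays; the .npy header restores shape and dtype."""
    return [np.load(io.BytesIO(blob), allow_pickle=False) for blob in blobs]


# Hypothetical parameters shaped like a Dense(784 -> 64) layer: kernel and bias
weights = [np.ones((784, 64), dtype=np.float32), np.zeros(64, dtype=np.float32)]
restored = bytes_to_ndarrays(ndarrays_to_bytes(weights))
print([(a.shape, a.dtype) for a in restored])
```

Any model whose state can be expressed this way can be exchanged; a model whose state cannot be reduced to fixed-shape arrays needs extra work before it fits this scheme.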

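As we are using a deep learning model defined in TensorFlow, it is recommended to load the data into a `tf.data.Dataset`, which provides an efficient input pipeline and lets the framework make use of any available hardware acceleration (such as a GPU on the nodes). The following code loads the MNIST dataset, scales the pixel values to the [0, 1] range, and creates training and test datasets with batches of 32, which should be manageable on most modern machines. In this simple example, every client loads the full MNIST dataset.

In [None]:
# Load the MNIST data available on this device
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Scale the pixel values from [0, 255] to [0, 1]
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(10000).batch(32)
# Batch the test data as well, so that model.evaluate() receives batches
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)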


Now, let's introduce the two pieces of the puzzle: the `Client` and the `Server`. Flower starts a `Server` to coordinate the client devices and perform the orchestration of the model. The server interacts with clients through an interface called `Client`. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code, which in this case involves training the neural network that we defined earlier.

Flower provides a convenience class called `NumPyClient`, which makes it easier to implement the `Client` interface when your workload uses Keras. `NumPyClient` defines three methods, which we can implement as follows:

In [None]:
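# Create a class that holds the client's model and data and implements the Flower client interface
class MyClient(fl.client.NumPyClient):
    def __init__(self, net, train_dataset, test_dataset):
        self.model = net
        self.trainloader = train_dataset
        self.valloader = test_dataset

    def get_parameters(self, config):
        # Return the current local weights as a list of NumPy arrays
        return self.model.get_weights()

    def fit(self, parameters, config):
        # Load the global weights, train briefly on local data, and return the new weights
        self.model.set_weights(parameters)
        # The dataset is already batched, so batch_size is not passed;
        # steps_per_epoch=3 keeps each round short for this demo
        self.model.fit(self.trainloader, epochs=1, steps_per_epoch=3)
        return self.model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        # Load the global weights and evaluate them on the local test data
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.valloader)
        return float(loss), len(x_test), {"accuracy": float(accuracy)}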

In the preceding code, we implemented the methods the client needs for this particular case. With these methods in place, we can start a client with the following code:

In [None]:
# Start the client (this code goes in a file such as client.py)
fl.client.start_numpy_client(
    server_address="[::]:8080",
    client=MyClient(generate_ann(), train_dataset, test_dataset),
)

The string `[::]:8080` specifies the address of the server the client should connect to, here port 8080 on the local machine. Because this example runs the clients and the server on the same machine, this address is sufficient. In a truly federated deployment, the only change needed is to set `server_address` to the address of the actual server.

The other essential piece of the puzzle is the server. Its code goes in a separate file, for example `server.py`, and can be as short as this:


In [None]:
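import flwr as fl

# Start a Flower server on port 8080 that runs 3 rounds of federated learning
# (no strategy is passed, so Flower uses FedAvg by default)
fl.server.start_server(
    server_address="[::]:8080",
    config=fl.server.ServerConfig(num_rounds=3),
)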

In this particular case, we can run the server and two clients in separate terminals on the same machine. First start the server with `python server.py`, then run `python client.py` in two further terminals. By default, the server waits until at least two clients are connected before starting the first round.

Upon starting the server, we should receive an output similar to:


```shell
INFO flower 2022-11-28 11:15:46,741 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2022-11-28 11:15:46,742 | server.py:72 | Getting initial parameters
INFO flower 2022-11-28 11:16:01,770 | server.py:74 | Evaluating initial parameters
INFO flower 2022-11-28 11:16:01,770 | server.py:87 | [TIME] FL starting
DEBUG flower 2022-11-28 11:16:12,341 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-11-28 11:21:17,235 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2022-11-28 11:21:17,512 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2022-11-28 11:21:29,628 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2022-11-28 11:21:29,696 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-11-28 11:25:59,917 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2022-11-28 11:26:00,227 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2022-11-28 11:26:11,457 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2022-11-28 11:26:11,530 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-11-28 11:30:43,389 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2022-11-28 11:30:43,630 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2022-11-28 11:30:53,384 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2022-11-28 11:30:53,384 | server.py:122 | [TIME] FL finished in 891.6143046000007
INFO flower 2022-11-28 11:30:53,385 | app.py:109 | app_fit: losses_distributed [(1, 2.3196680545806885), (2, 2.3202896118164062), (3, 2.1818180084228516)]
INFO flower 2022-11-28 11:30:53,385 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2022-11-28 11:30:53,385 | app.py:111 | app_fit: losses_centralized []
INFO flower 2022-11-28 11:30:53,385 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2022-11-28 11:30:53,442 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2022-11-28 11:31:02,848 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2022-11-28 11:31:02,848 | app.py:121 | app_evaluate: federated loss: 2.1818180084228516
INFO flower 2022-11-28 11:31:02,848 | app.py:125 | app_evaluate: results [('ipv4:127.0.0.1:31539', EvaluateRes(loss=2.1818180084228516, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.21610000729560852})), ('ipv4:127.0.0.1:31540', EvaluateRes(loss=2.1818180084228516, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.21610000729560852}))]
INFO flower 2022-11-28 11:31:02,848 | app.py:127 | app_evaluate: failures []
```

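With that, our first federated learning setup is complete. As the log shows, the server runs three rounds; in each one it samples both clients for training (`fit_round`), aggregates their updated parameters into a new global model, and then sends that model to both clients for evaluation. After the last round, it reports the distributed loss of each round (`losses_distributed`) and the results of a final evaluation. The accuracy is still low (about 22%), which is not surprising given that each client trains on only three batches per round.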

### Exercise
Implement the client and server code in two separate files and compare the results with those presented here. Were your results similar?

`Answer here`:

# Updating parameters
The key element of this approach is the exchange of parameters. The server sends the global model parameters to the client, and the client loads them into its local model. It then trains the model on its local data, which changes the parameters locally. After training, the client sends the updated parameters back to the server or, alternatively, only the update, that is, the difference between its new parameters and the ones it received.
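Sending the update instead of the full parameters does not change the result of the aggregation. If the server holds $w_t$, client $k$ returns $w_k$, and the aggregation weights $p_k$ sum to one, then $w_t + \sum_k p_k (w_k - w_t) = \sum_k p_k w_k$. A quick NumPy check with made-up numbers (two hypothetical clients with 100 and 300 examples):

```python
import numpy as np

global_w = np.array([1.0, 2.0, 3.0])
# Parameters returned by two hypothetical clients after local training
client_w = [np.array([1.5, 1.0, 3.0]), np.array([0.5, 3.0, 4.0])]
sizes = np.array([100.0, 300.0])
p = sizes / sizes.sum()  # aggregation weights, summing to one

# Option A: clients send full parameters and the server averages them
avg_params = sum(pi * w for pi, w in zip(p, client_w))

# Option B: clients send only their update (new minus received parameters);
# the server averages the updates and applies them to the model it already holds
deltas = [w - global_w for w in client_w]
avg_from_deltas = global_w + sum(pi * d for pi, d in zip(p, deltas))

print(np.allclose(avg_params, avg_from_deltas))  # True
```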


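In a Flower client, this exchange is usually implemented with two helper functions that load and retrieve the local parameters: `set_parameters` and `get_parameters`. They convert between the framework's own representation of the parameters and the list of NumPy arrays that `NumPyClient` exchanges with the server, and they have to be written by hand for frameworks such as **PyTorch** or **JAX**. As demonstrated in the previous example, `flwr` can also be used with **TensorFlow**, where `model.get_weights()` and `model.set_weights()` already provide this conversion, or even with **scikit-learn**.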
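For a framework whose model state is a collection of named tensors, such as a PyTorch `state_dict`, the two helpers essentially flatten that state into an ordered list of NumPy arrays and rebuild it from such a list. The sketch below uses a hypothetical `ToyNet` that stores NumPy arrays, so it runs without PyTorch; in a real PyTorch client, the same pattern would typically convert tensors with `.cpu().numpy()` and load them back with `load_state_dict`. The key design point is that the list carries no names, so both sides must agree on the order of the arrays.

```python
from collections import OrderedDict
from typing import List

import numpy as np


class ToyNet:
    """Hypothetical stand-in for a PyTorch-style model: its state is an ordered dict of named arrays."""

    def __init__(self):
        self.state = OrderedDict(
            fc1_weight=np.zeros((4, 3)), fc1_bias=np.zeros(4),
            fc2_weight=np.zeros((2, 4)), fc2_bias=np.zeros(2),
        )


def get_parameters(net: ToyNet) -> List[np.ndarray]:
    # Flatten the named state into a list, always in the same (insertion) order
    return [value.copy() for value in net.state.values()]


def set_parameters(net: ToyNet, parameters: List[np.ndarray]) -> None:
    # The list carries no names, so arrays are matched to layers by position
    if len(parameters) != len(net.state):
        raise ValueError(f"expected {len(net.state)} arrays, got {len(parameters)}")
    new_state = OrderedDict()
    for (name, old), new in zip(net.state.items(), parameters):
        if old.shape != np.shape(new):
            raise ValueError(f"shape mismatch for {name}: {old.shape} vs {np.shape(new)}")
        new_state[name] = np.array(new, copy=True)
    net.state = new_state


# Usage: simulate receiving global parameters from the server
net = ToyNet()
received = [np.full_like(p, 0.5) for p in get_parameters(net)]
set_parameters(net, received)
print([p.shape for p in get_parameters(net)])  # [(4, 3), (4,), (2, 4), (2,)]
```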

As a result, every client built with this library follows the same basic template:


In [None]:
# DO NOT RUN: this is only a template
# It is written for PyTorch; further below you will adapt it to TensorFlow
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader  # Training dataset
        self.valloader = valloader  # Validation dataset

    def get_parameters(self, config):
        return get_parameters(self.net)  # Framework-specific helper, to be implemented

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)  # Framework-specific helper, to be implemented
        train(self.net, self.trainloader, epochs=1)  # Framework-specific training loop
        # Note: for a PyTorch DataLoader, len() counts batches; len(loader.dataset) counts examples
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)  # Framework-specific evaluation loop
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

In Flower, clients are created by extending either the `flwr.client.Client` or the `flwr.client.NumPyClient` class. In the previous example, we used `NumPyClient` because it is easier to implement and requires less boilerplate. Whichever class is extended, three main methods need to be implemented:

* `get_parameters`: Returns the current local model parameters.
* `fit`: Receives model parameters from the server, trains the model parameters on the local data, and returns the (updated) model parameters to the server.
* `evaluate`: Receives model parameters from the server, evaluates the model parameters on the local data, and returns the evaluation result to the server.

As you can see, the `MyClient` class implemented in the previous example follows this same structure.
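To make this contract concrete, the sketch below drives a hypothetical `DummyClient`, whose methods have the same signatures and return tuples as those of `NumPyClient`, with an in-process `run_server` loop that imitates what the server does over the network: it requests the initial parameters from one client, calls `fit` on every client, averages the returned parameters weighted by their example counts, and then calls `evaluate` and averages the losses the same way. Neither function nor class is part of Flower; the "model" is a single number, and "training" simply moves it halfway towards each client's local mean.

```python
from typing import Dict, List, Tuple

import numpy as np


class DummyClient:
    """Hypothetical client with NumPyClient-style methods; its 'model' is a single number."""

    def __init__(self, local_data: np.ndarray):
        self.data = local_data
        self.w = np.zeros(1)

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        return [self.w]

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        self.w = parameters[0].copy()
        self.w += 0.5 * (self.data.mean() - self.w)  # "training": move halfway to the local mean
        return [self.w], len(self.data), {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict) -> Tuple[float, int, Dict]:
        loss = float(np.mean((self.data - parameters[0][0]) ** 2))  # local mean squared error
        return loss, len(self.data), {}


def run_server(clients: List[DummyClient], num_rounds: int) -> List[float]:
    """Hypothetical in-process stand-in for the server loop: fit, aggregate, evaluate."""
    params = clients[0].get_parameters(config={})  # initial parameters from one client
    losses = []
    for _ in range(num_rounds):
        fit_results = [c.fit(params, config={}) for c in clients]
        total = sum(n for _, n, _ in fit_results)
        params = [sum(p[0] * n for p, n, _ in fit_results) / total]  # example-weighted average
        eval_results = [c.evaluate(params, config={}) for c in clients]
        # Distributed loss: example-weighted average of the clients' losses
        losses.append(sum(loss * n for loss, n, _ in eval_results) / sum(n for _, n, _ in eval_results))
    return losses


clients = [DummyClient(np.array([1.0, 1.0])), DummyClient(np.array([4.0, 4.0, 4.0, 4.0]))]
losses = run_server(clients, num_rounds=3)
print(losses)  # [4.25, 2.5625, 2.140625]
```

The loss falls towards 2.0, the error of the best single value (the example-weighted mean, 3.0) on the combined data, even though no client ever sees the other client's data.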


#### Be aware: 
Sometimes, especially when simulating many clients on a single device, it is useful to create each client through a function, only when it is needed. In simulation, Flower calls such a function (`client_fn`) whenever a client is selected for training or evaluation, and the client instance is usually discarded afterwards. Clients should therefore not rely on in-memory state kept between rounds, and only the clients currently in use occupy memory. For example, the following PyTorch code gives each client its own data partition:


In [None]:
# DO NOT RUN
# As above, this is the PyTorch version of the code
def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""

    # Load model
    net = Net().to(DEVICE)  # This should be adapted to the framework being used

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]

    # Create a single Flower client representing a single organization
    return FlowerClient(net, trainloader, valloader)

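Note that `MyClient`, as used earlier, was started once per process with a model built by `generate_ann` and kept that model for the whole session, so it cannot be used in exactly the same way. It works in simulation as long as no state is kept between calls, that is, if `client_fn` builds a fresh model with `generate_ann()` and passes it, together with that client's data partition, to a new `MyClient` instance.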
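The on-demand pattern can be sketched without Flower: a hypothetical `simulate` loop calls a client factory every time a client is sampled and drops the instance after use, so any attribute set on the client (here `rounds_seen`) never survives to the next round. `LocalMeanClient`, `toy_client_fn`, and `simulate` are illustrative names only; Flower's real simulation engine is considerably more involved.

```python
import random
from typing import Callable, List, Tuple


class LocalMeanClient:
    """Hypothetical client: 'training' moves the global value halfway towards its local mean."""

    def __init__(self, data: List[float]):
        self.data = data
        self.rounds_seen = 0  # in-memory state, lost as soon as the instance is discarded

    def fit(self, global_value: float) -> Tuple[float, int]:
        self.rounds_seen += 1
        local_mean = sum(self.data) / len(self.data)
        return 0.5 * (global_value + local_mean), len(self.data)


# One data partition per client id, indexed like trainloaders[int(cid)] in client_fn above
partitions = [[float(i), float(i) + 1.0] for i in range(10)]


def toy_client_fn(cid: str) -> LocalMeanClient:
    return LocalMeanClient(partitions[int(cid)])


def simulate(client_factory: Callable[[str], LocalMeanClient], num_clients: int,
             num_rounds: int, fraction_fit: float, seed: int = 0) -> Tuple[float, int, int]:
    """Hypothetical loop; returns (global value, clients created, max rounds seen by one instance)."""
    rng = random.Random(seed)
    global_value, created, max_rounds_seen = 0.0, 0, 0
    for _ in range(num_rounds):
        sampled = rng.sample(range(num_clients), max(1, int(fraction_fit * num_clients)))
        results = []
        for cid in sampled:
            client = client_factory(str(cid))  # a fresh instance every time the client is selected
            created += 1
            results.append(client.fit(global_value))
            max_rounds_seen = max(max_rounds_seen, client.rounds_seen)
        total = sum(n for _, n in results)
        global_value = sum(value * n for value, n in results) / total
    return global_value, created, max_rounds_seen


value, created, max_rounds_seen = simulate(toy_client_fn, num_clients=10, num_rounds=5, fraction_fit=1.0)
print(created, max_rounds_seen)  # 50 1
```

Fifty instances are created over five rounds, and none of them ever sees more than one round, which is why anything a client needs (model, data partition) must be rebuilt inside the factory.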

The clients are now set up to load parameters, fit, and evaluate. However, we still need to combine the results from the different clients. In Flower terminology, the component that does this is a strategy, such as *Federated Averaging (FedAvg)*. To begin with, we can use one of the framework's built-in strategies, although custom strategies can also be written. Let's see an example:


In [None]:
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # Sample 100% of available clients for training
        fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
        min_fit_clients=10,  # Never sample fewer than 10 clients for training
        min_evaluate_clients=5,  # Never sample fewer than 5 clients for evaluation
        min_available_clients=10,  # Wait until all 10 clients are available
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)

This code runs the whole federation, server and clients, in a single process: `start_simulation` simulates 10 clients on one device and, as described above, creates each client through `client_fn` only when it is needed, to avoid overloading the device. In every round, the strategy selects all 10 clients for training (`fraction_fit=1.0`) and half of them, that is 5, for evaluation (`fraction_evaluate=0.5`). After receiving the clients' updates, the server aggregates them with FedAvg and sends the new global model out for the next round, for a total of 5 rounds.
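How many clients each phase uses follows from combining the fraction with the corresponding minimum. The sketch below assumes the rule `max(int(fraction * available), minimum)`, which is our reading of Flower's `FedAvg` source; `num_sampled` is a hypothetical helper, so check the rule against the Flower version you use. `min_available_clients`, in turn, only controls how many clients must be connected before a round starts.

```python
def num_sampled(num_available: int, fraction: float, min_clients: int) -> int:
    """Clients used in one phase of a round, ASSUMING the rule max(int(fraction * available), minimum)."""
    return max(int(num_available * fraction), min_clients)


# The configuration above, with all 10 simulated clients available
fit_clients = num_sampled(10, fraction=1.0, min_clients=10)
evaluate_clients = num_sampled(10, fraction=0.5, min_clients=5)
print(fit_clients, evaluate_clients)  # 10 5

# With 40 clients and fraction 0.1, the minimum of 10 takes over from int(4.0) = 4
print(num_sampled(40, fraction=0.1, min_clients=10))  # 10
```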

One point to highlight is that the framework aggregates only the loss automatically (`losses_distributed` in the log above); it does not aggregate any other metric, which is why `accuracies_distributed` was empty. Because custom metrics can mean very different things, the framework cannot know how they should be combined, so users need to tell it how to aggregate them.

To do so, the strategy accepts aggregation functions, which it calls whenever it receives fit or evaluate metrics from the clients: `fit_metrics_aggregation_fn` and `evaluate_metrics_aggregation_fn`. For example, the following code computes an accuracy weighted by the number of examples on each client and adapts the previous example to use it:


In [None]:
from typing import List, Tuple

from flwr.common import Metrics

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply the accuracy of each client by the number of examples it used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return the custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}

NUM_CLIENTS = 10

# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=10,
        min_evaluate_clients=5,
        min_available_clients=10,
        evaluate_metrics_aggregation_fn=weighted_average,  # aggregate the metrics returned by evaluate()
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)
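To see why the weighting matters, here is the same aggregation logic applied to made-up evaluation results from two hypothetical clients. To keep the snippet runnable without Flower, `Metrics` is replaced by a plain `Dict[str, float]`:

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    # Same logic as above: weight each client's accuracy by its number of examples
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# Hypothetical evaluate results as (num_examples, metrics) pairs
results = [(100, {"accuracy": 0.75}), (300, {"accuracy": 0.5})]
print(weighted_average(results))  # {'accuracy': 0.5625}

# A plain mean would give both clients the same influence
unweighted = sum(m["accuracy"] for _, m in results) / len(results)
print(unweighted)  # 0.625
```

The client with three times as many examples pulls the result towards its own accuracy, which is usually what you want when the examples, rather than the clients, are the unit being measured.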


We will revisit custom strategies in the following unit, where we will define our own strategy to mitigate some of the challenges that federated learning must address.

### Exercise

To test our implementation, run the simulation in a simulated environment with the CIFAR-10 dataset split among the clients. Complete the code below:

In [None]:
import numpy as np
import tensorflow as tf

# Load CIFAR-10 and split it into one partition per simulated client
def load_datasets(n_clients, batch_size=32, seed=42):
    # Download CIFAR-10 (train and test splits)
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Normalize data
    # TODO

    # Shuffle once with a fixed seed, so that the partitions are disjoint and reproducible
    rng = np.random.default_rng(seed)
    train_idx = rng.permutation(len(x_train))
    test_idx = rng.permutation(len(x_test))
    x_train, y_train = x_train[train_idx], y_train[train_idx]
    x_test, y_test = x_test[test_idx], y_test[test_idx]

    # Split the training and test sets into n_clients partitions of equal size
    train_partition_size = len(x_train) // n_clients
    test_partition_size = len(x_test) // n_clients

    train_ds = []
    test_ds = []
    for i in range(n_clients):
        tr = slice(i * train_partition_size, (i + 1) * train_partition_size)
        te = slice(i * test_partition_size, (i + 1) * test_partition_size)
        train_ds.append(
            tf.data.Dataset.from_tensor_slices((x_train[tr], y_train[tr])).batch(batch_size))
        test_ds.append(
            tf.data.Dataset.from_tensor_slices((x_test[te], y_test[te])).batch(batch_size))
    return train_ds, test_ds

NUM_CLIENTS = 10
train_ds, test_ds = load_datasets(NUM_CLIENTS)

#TODO Client, client_fn and simulation

# Aggregation

To conclude this lesson, let's take a closer look at the core of these strategies: the aggregation algorithm. Aggregation algorithms combine the updates from the clients into the global model and, as we have seen, they are defined in the strategies. Several types of aggregation can be used in federated learning (Reddi et al., 2020); the most common ones are the following:

* Federated averaging (`flwr.server.strategy.FedAvg`): each device computes an update to the model parameters based on its local data, and the server averages these updates, weighting each one by the number of examples the device trained on, to create the global model (McMahan et al., 2016). This approach is simple and effective, but it can be sensitive to the size of individual updates and to the quality of the data on each device.

* Federated weighted averaging: similar to federated averaging, but each device's update is weighted by a chosen criterion, such as the size of its dataset or an estimate of the quality of its data. This gives more influence to devices with larger or higher-quality data.

* Federated averaging with momentum (`flwr.server.strategy.FedAvgM`): similar to federated averaging, but the server adds a momentum term to the averaged update, which smooths the updates across rounds and can help the model converge more quickly.

* Federated Adagrad (`flwr.server.strategy.FedAdagrad`): the server treats the averaged client update as a pseudo-gradient and applies an Adagrad step to the global model, scaling the step for each parameter by the accumulated squared updates (Reddi et al., 2020). Adaptive server optimizers of this kind can improve convergence, particularly when the data is heterogeneous across clients.

* Federated Adam (`flwr.server.strategy.FedAdam`): a variant of the previous approach in which the server applies the Adam optimization algorithm, adapting the step size from running estimates of the first and second moments of the pseudo-gradient.
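The server-side update rules behind these strategies can be sketched in NumPy. The function and class names and the default hyperparameters are illustrative assumptions, not Flower's API; the adaptive variants follow the general form in Reddi et al. (2020), where the example-weighted average of the client updates acts as a pseudo-gradient. Bias correction and other details of Flower's own implementations are omitted.

```python
from typing import List

import numpy as np


def fedavg(client_params: List[np.ndarray], num_examples: List[int]) -> np.ndarray:
    """FedAvg: average of the client parameters, weighted by each client's number of examples."""
    return np.average(np.stack(client_params), axis=0, weights=np.asarray(num_examples, dtype=float))


class FedAvgMomentum:
    """Server-side momentum in the spirit of FedAvgM (simplified sketch)."""

    def __init__(self, server_lr: float = 1.0, momentum: float = 0.9):
        self.lr, self.momentum, self.velocity = server_lr, momentum, None

    def step(self, global_params, client_params, num_examples):
        # Averaged update: how far the FedAvg result moved away from the current global model
        delta = fedavg(client_params, num_examples) - global_params
        if self.velocity is None:
            self.velocity = delta
        else:
            self.velocity = self.momentum * self.velocity + delta
        return global_params + self.lr * self.velocity


class AdaptiveServerOpt:
    """FedAdagrad / FedAdam-style server update following Reddi et al. (2020), without bias correction."""

    def __init__(self, kind: str = "adam", server_lr: float = 0.1,
                 beta1: float = 0.9, beta2: float = 0.99, tau: float = 1e-3):
        if kind not in ("adagrad", "adam"):
            raise ValueError("kind must be 'adagrad' or 'adam'")
        self.kind, self.lr, self.beta1, self.beta2, self.tau = kind, server_lr, beta1, beta2, tau
        self.m = None
        self.v = None

    def step(self, global_params, client_params, num_examples):
        delta = fedavg(client_params, num_examples) - global_params  # pseudo-gradient
        if self.m is None:
            self.m, self.v = np.zeros_like(delta), np.zeros_like(delta)
        self.m = self.beta1 * self.m + (1 - self.beta1) * delta  # first-moment estimate
        if self.kind == "adagrad":
            self.v = self.v + delta ** 2  # Adagrad: running sum of squared updates
        else:
            self.v = self.beta2 * self.v + (1 - self.beta2) * delta ** 2  # Adam: moving average
        # tau keeps the denominator away from zero and controls how adaptive the step is
        return global_params + self.lr * self.m / (np.sqrt(self.v) + self.tau)


# Usage with made-up parameters from two clients holding 1 and 3 examples
global_params = np.zeros(2)
client_params = [np.array([1.0, -2.0]), np.array([3.0, 2.0])]
num_examples = [1, 3]

print(fedavg(client_params, num_examples))
print(FedAvgMomentum(momentum=0.0).step(global_params, client_params, num_examples))
print(AdaptiveServerOpt("adam").step(global_params, client_params, num_examples))
```

With zero momentum and a server learning rate of 1, the momentum variant reduces to plain FedAvg; the adaptive variants instead take a step whose size is set mostly by `server_lr` rather than by the raw magnitude of the averaged update.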



All of the aggregation methods listed above are available as strategies in the `flwr` framework, with one caveat: Flower's `FedAvg` already weights each client's update by its number of examples, whereas weighting by other criteria, such as data quality, requires a custom strategy. There are also other, less common aggregation methods. The choice of aggregation method ultimately depends on the characteristics of the data and the requirements of the task at hand.


#### References
* Hard, A., Konečný, J., McMahan, H. B., Richemond-Barakat, C., Sivek, J. S., & Talwar, K. (2018). Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1812.02903.
* Li, Y., Bonawitz, K., & Talwar, K. (2020). Fedprox: An optimizer for communication-efficient federated learning. arXiv preprint arXiv:2002.04283.
* McMahan, H. B., Moore, E., Ramage, D., Hampson, S., & y Arcas, B. A. (2016). Communication-efficient learning of deep networks from decentralized data. arXiv preprint arXiv:1602.05629.
* Reddi, S., Charles, Z., Zaheer, M., Garrett, Z., Rush, K., Konečný, J., Kumar, S., & McMahan, H. B. (2020). Adaptive federated optimization. arXiv preprint arXiv:2003.00295.
* Yoon, J., Hard, A., Konečný, J., McMahan, H. B., & Sohl-Dickstein, J. (2018). Federal regression: A simple and scalable method for heterogeneous federated learning. arXiv preprint arXiv:1812.03862.