In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
# make packages in the parent directory importable (they take precedence over installed ones)
module_path = os.path.abspath(os.path.join('..'))
sys.path.insert(0, module_path)

# New federated algorithm with `fluke`

This tutorial will guide you through the steps required to implement a new federated learning algorithm that can be tested with `fluke`.

```{attention}
This tutorial does not go into the details of the implementation, but it provides a quick overview of the steps required to implement a new federated learning algorithm.
For a more in-depth guide on how to implement your own federated learning algorithm, please refer to [this section](../../add_algorithm.md).
```

Try this notebook: [![Open in Colab](https://img.shields.io/badge/Open_in_Colab-blue?style=flat-square&logo=google-colab&logoColor=yellow&labelColor=gray)](https://colab.research.google.com/github/makgyver/fluke/blob/main/tutorials/fluke_custom_alg.ipynb)

## Install `fluke` (if not already done)

In [None]:
!pip install fluke-fl

## Implementing the server-side logic

To keep it simple, we use a very easy (and not particularly smart :)) example of a new FL algorithm.
Let's say we want to define a new federated algorithm with these two characteristics:
- At each round, the server selects only two of the participating clients and merges (aggregates) only their models;
- When selected, a client performs local training for a number of epochs drawn uniformly at random between 1 and the maximum number of local epochs, which is a hyperparameter.
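Before touching `fluke`, the two random choices above can be sketched in plain NumPy. This is only an illustration: `simulate_round`, the client ids and `max_epochs` are hypothetical names, not part of `fluke`, and a local `Generator` is used so that the snippet is reproducible on its own.

```python
import numpy as np


def simulate_round(participants, max_epochs, rng):
    """Hypothetical helper (not part of fluke): the toy algorithm's random
    choices for a single round."""
    # Each participant trains for a number of epochs drawn uniformly in
    # [1, max_epochs]; the upper bound of integers() is exclusive, hence the +1.
    epochs = {c: int(rng.integers(1, max_epochs + 1)) for c in participants}
    # The server keeps only two participants (or all, if fewer than two)
    # for the aggregation step.
    n_merge = min(2, len(participants))
    merged = rng.choice(participants, size=n_merge, replace=False).tolist()
    return epochs, merged


rng = np.random.default_rng(42)
epochs, merged = simulate_round(["c0", "c1", "c2", "c3", "c4"], max_epochs=5, rng=rng)
print(epochs)   # number of local epochs drawn for each participant
print(merged)   # the two participants whose models would be merged
```

The `fluke` classes below use the global `np.random` functions instead of a local generator, but the logic is the same.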

Let's start with the server. Given the characteristics of the algorithm, the only thing the server does differently from the standard FedAvg server is to select only two clients to be merged. The rest of the logic is the same.

In [None]:
from typing import Collection
from fluke.client import Client
from fluke.server import Server
import numpy as np

class MyServer(Server):

    # we override the aggregate method to implement our aggregation strategy
    def aggregate(self, eligible: Collection[Client]) -> None:

        # eligible contains the clients that participated in this round:
        # here we randomly keep only two of them (or all, if fewer than two)
        eligible = list(eligible)
        n_selected = min(2, len(eligible))
        idx = np.random.choice(len(eligible), size=n_selected, replace=False)
        selected = [eligible[i] for i in idx]

        # we call the parent class method to aggregate only the selected clients
        super().aggregate(selected)

Easy! The rest of the server's behaviour is inherited from `fluke.server.Server`, which already implements `FedAvg`.
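To make the aggregation step concrete, here is a sketch of FedAvg-style weighted averaging restricted to the two selected clients. It assumes that each client's model is flattened into a NumPy vector and that the weights are proportional to the number of local training samples (the kind of weighting that `weighted=True` in the server hyperparameters is meant to enable). `weighted_average`, the vectors and the sample counts are made up for illustration; this is not `fluke`'s implementation.

```python
import numpy as np


def weighted_average(params, n_samples):
    """Hypothetical helper: FedAvg-style average where each parameter vector
    is weighted by its client's share of the total number of samples."""
    weights = np.asarray(n_samples, dtype=float)
    weights /= weights.sum()            # normalise the weights so they sum to 1
    stacked = np.stack(params)          # shape: (n_clients, n_params)
    return weights @ stacked            # weighted sum over the clients


# Toy models of the two clients picked for aggregation (made-up numbers).
selected_params = [np.array([1.0, 2.0, 3.0]), np.array([3.0, 2.0, 1.0])]
selected_sizes = [100, 300]

global_params = weighted_average(selected_params, selected_sizes)
print(global_params.tolist())  # → [2.5, 2.0, 1.5]
```

The client with 300 samples gets weight 0.75, so the merged model is pulled towards its parameters.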

## Implementing the client-side logic

Let's now implement the client-side logic. Here too, we can start from the `FedAvg` client already implemented in `fluke.client.Client` and modify it to fit our needs.

In [None]:
class MyClient(Client):

    # we override the fit method to implement our training "strategy"
    def fit(self, override_local_epochs: int = 0) -> float:
        # draw the number of local epochs uniformly in [1, local_epochs]
        # (randint excludes its upper bound, hence the + 1) and pass it to the
        # parent class method as the override; the argument received here is ignored
        new_local_epochs = np.random.randint(1, self.hyper_params.local_epochs + 1)
        return super().fit(new_local_epochs)
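`MyClient.fit` relies on the parent `fit` using a positive `override_local_epochs` as the number of epochs to run. The toy classes below mimic that contract under this assumption (a value of `0` falls back to the configured `local_epochs`). `ToyClient` and `ToyRandomClient` are hypothetical and only return the number of epochs instead of training.

```python
import numpy as np


class ToyClient:
    """Hypothetical stand-in for a client: returns how many epochs it would train."""

    def __init__(self, local_epochs: int):
        self.local_epochs = local_epochs

    def fit(self, override_local_epochs: int = 0) -> int:
        # Assumed contract: a positive override wins, 0 means "use the configured value".
        epochs = override_local_epochs if override_local_epochs > 0 else self.local_epochs
        return epochs  # a real client would run `epochs` passes over its local data


class ToyRandomClient(ToyClient):
    """Same idea as MyClient: draw the epochs in [1, local_epochs] and delegate."""

    def fit(self, override_local_epochs: int = 0) -> int:
        new_local_epochs = np.random.randint(1, self.local_epochs + 1)
        return super().fit(int(new_local_epochs))


np.random.seed(0)
draws = [ToyRandomClient(local_epochs=5).fit() for _ in range(1000)]
print(sorted(set(draws)))  # → [1, 2, 3, 4, 5]
```

Over many calls every value between 1 and `local_epochs` shows up, confirming that the `+ 1` makes the upper bound reachable.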

## Implementing the new federated algorithm

Now we only need to put everything together in a new class that inherits from `fluke.algorithms.CentralizedFL`, specifying the server and client classes we just implemented.

In [None]:
from fluke.algorithms import CentralizedFL

class MyFLAlgorithm(CentralizedFL):

    def get_client_class(self) -> type[Client]:
        return MyClient

    def get_server_class(self) -> type[Server]:
        return MyServer

Everything is ready! Now we can test our new federated algorithm with `fluke`!
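Overriding `get_client_class` and `get_server_class` is a factory-method pattern: the base class builds the federation and asks the subclass which concrete classes to instantiate. The toy re-creation below (hypothetical `ToyCentralizedFL` and companions, not `fluke`'s actual code) shows why overriding just these two methods is enough to swap in custom client and server classes.

```python
class BaseClient:
    def __init__(self, index: int):
        self.index = index


class BaseServer:
    def __init__(self, clients):
        self.clients = clients


class ToyCentralizedFL:
    """Hypothetical miniature of the pattern: the base class builds the
    federation but lets subclasses choose the concrete classes."""

    def __init__(self, n_clients: int):
        client_cls = self.get_client_class()
        self.clients = [client_cls(i) for i in range(n_clients)]
        self.server = self.get_server_class()(self.clients)

    def get_client_class(self) -> type[BaseClient]:
        return BaseClient

    def get_server_class(self) -> type[BaseServer]:
        return BaseServer


class ToyMyClient(BaseClient):
    pass


class ToyMyServer(BaseServer):
    pass


class ToyMyAlgorithm(ToyCentralizedFL):
    # only the two factory methods change; the construction logic is inherited
    def get_client_class(self) -> type[BaseClient]:
        return ToyMyClient

    def get_server_class(self) -> type[BaseServer]:
        return ToyMyServer


algo = ToyMyAlgorithm(n_clients=3)
print(type(algo.server).__name__, [type(c).__name__ for c in algo.clients])
```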

## Ready to test the new federated algorithm

The rest of the code is the same as in the [First steps with `fluke` API](fluke_quick_api.ipynb) tutorial.

In [None]:
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke.evaluation import ClassificationEval
from fluke import FlukeENV

settings = FlukeENV()
settings.set_seed(42) # we set a seed for reproducibility
settings.set_device("cpu") # we use the CPU for this example

dataset = Datasets.get("mnist", path="./data")

# we set the evaluator to be used by both the server and the clients
settings.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

splitter = DataSplitter(dataset=dataset,
                        distribution="iid")

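client_hp = DDict(
    batch_size=10,
    local_epochs=5,  # maximum number of local epochs: MyClient draws a value in [1, 5]
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,  # with gamma=1 a StepLR-style schedule keeps the learning rate constant
        step_size=1)
)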

# we put together the hyperparameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")
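With `distribution="iid"`, every client should receive a sample that looks like the whole dataset. A minimal sketch of one common way to obtain this, assuming a uniform random shuffle of the example indices followed by a split into (almost) equally sized shards. `iid_split` is hypothetical and is not how `DataSplitter` is implemented internally; 60,000 is the size of the MNIST training set.

```python
import numpy as np


def iid_split(n_examples: int, n_clients: int, seed: int = 42) -> list[np.ndarray]:
    """Hypothetical helper: shuffle all example indices and cut them into
    n_clients shards whose sizes differ by at most one."""
    rng = np.random.default_rng(seed)
    permuted = rng.permutation(n_examples)
    return np.array_split(permuted, n_clients)


shards = iid_split(n_examples=60_000, n_clients=10)
print([len(s) for s in shards])  # → [6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000]
```

Because the shuffle ignores the labels, each shard's class distribution is, in expectation, the same as the full dataset's.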

Here is where the new federated algorithm comes into play.

In [None]:
algorithm = MyFLAlgorithm(n_clients=10, # 10 clients in the federation
                          data_splitter=splitter,
                          hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)

Now we just need to run it!

In [None]:
logger.init()

In [None]:
algorithm.run(n_rounds=10, eligible_perc=0.5)
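With `n_clients=10` and `eligible_perc=0.5`, only part of the federation trains in each round, and `MyServer` then merges just two of those clients. The sketch below counts how often each client ends up in the aggregation over the 10 rounds, assuming that the number of eligible clients is `int(n_clients * eligible_perc)` and that they are sampled uniformly without replacement; it mirrors the idea, not `fluke`'s exact sampling code.

```python
import numpy as np

n_clients, eligible_perc, n_rounds = 10, 0.5, 10
rng = np.random.default_rng(42)

# Assumption: the number of eligible clients per round is int(n_clients * eligible_perc).
n_eligible = max(1, int(n_clients * eligible_perc))
merge_counts = np.zeros(n_clients, dtype=int)

for _ in range(n_rounds):
    eligible = rng.choice(n_clients, size=n_eligible, replace=False)       # clients that train
    merged = rng.choice(eligible, size=min(2, n_eligible), replace=False)  # MyServer's pick
    merge_counts[merged] += 1

print(n_eligible)          # → 5
print(merge_counts.sum())  # → 20 (2 merged clients x 10 rounds)
print(merge_counts)        # per-client number of rounds in which its model was merged
```

Under these assumptions, in each round only 2 of the 5 locally trained models contribute to the aggregation, which is part of what makes this toy algorithm not particularly smart.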