In [None]:
# Enable IPython's autoreload so that edits to imported modules (e.g. a local
# copy of the package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
import sys
# Make the parent of the current working directory importable with top
# priority (presumably a local checkout of `fluke`; adjust if that is not the
# case). The check keeps the cell idempotent: re-running it does not keep
# prepending duplicate entries to sys.path.
module_path = os.path.abspath("..")
if sys.path[:1] != [module_path]:
    sys.path.insert(0, module_path)

# Using a custom model in `fluke`

This tutorial guides you through the steps required to use a custom neural network with the federated learning algorithms of `fluke`.

## Install `fluke` (if not already done)

```bash
pip install fluke-fl
```

## Define your neural network

For the purpose of this tutorial, we define a very simple neural network for the MNIST dataset:
a multi-layer perceptron with two hidden layers and ReLU activations.

In [None]:
import torch
from torch.functional import F

class MyMLP(torch.nn.Module):
    """Multi-layer perceptron with two ReLU hidden layers.

    With the default sizes this is the MNIST network used in this tutorial:
    784 (= 28*28) input features -> 100 -> 64 -> 10 output scores.

    Args:
        input_size: Number of input features per sample. Inputs are flattened
            to ``(-1, input_size)`` in :meth:`forward`.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        output_size: Number of output scores (one per class).
    """

    def __init__(self,
                 input_size: int = 28 * 28,
                 hidden1: int = 100,
                 hidden2: int = 64,
                 output_size: int = 10) -> None:
        super().__init__()
        self.input_size = input_size
        # Layer attribute names are kept as fc1/fc2/fc3 so state_dict keys are stable.
        self.fc1 = torch.nn.Linear(input_size, hidden1)
        self.fc2 = torch.nn.Linear(hidden1, hidden2)
        self.fc3 = torch.nn.Linear(hidden2, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the raw output scores (no softmax applied) for ``x``.

        Args:
            x: Input tensor whose number of elements is a multiple of
                ``input_size``; presumably a batch of MNIST images such as
                ``(batch, 1, 28, 28)`` for the default sizes (TODO confirm
                against the data loader).

        Returns:
            Tensor of shape ``(x.numel() // input_size, output_size)``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed tensors.
        x = x.reshape(-1, self.input_size)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

## FedAvg with your custom model

You are almost ready to use your custom model in `fluke`.
The only thing left to do is to set an instance of `MyMLP` as the model in the hyper-parameters of the algorithm.

Alternatively, you can provide the fully qualified name of your model class (e.g., `"__main__.MyMLP"`) as the model.
This is useful because it allows you to use a custom model with the `fluke` command-line interface.

To keep it simple, we are going to use FedAVG, but you can use any other algorithm available in `fluke` or even implement your own.

In [None]:
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke.algorithms.fedavg import FedAVG
from fluke.evaluation import ClassificationEval
from fluke import GlobalSettings

settings = GlobalSettings()
settings.set_seed(42)        # fixed seed so that runs are reproducible
settings.set_device("cpu")   # this example runs on the CPU only

dataset = Datasets.get("mnist", path="./data")

# A single evaluator shared by the server and the clients.
evaluator = ClassificationEval(eval_every=1, n_classes=dataset.num_classes)
settings.set_evaluator(evaluator)

splitter = DataSplitter(dataset=dataset, distribution="iid")

# Client-side hyper-parameters: optimizer and learning-rate scheduler settings
# are built separately and then plugged into the client configuration.
optimizer_hp = DDict(lr=0.01, momentum=0.9, weight_decay=0.0001)
scheduler_hp = DDict(gamma=1, step_size=1)
client_hp = DDict(batch_size=10,
                  local_epochs=5,
                  loss="CrossEntropyLoss",
                  optimizer=optimizer_hp,
                  scheduler=scheduler_hp)

Here is where you must set the model in the hyper-parameters.

In [None]:
# The model can be given either as an instance (below) or as the fully
# qualified name of its class, e.g. model="__main__.MyMLP", or
# model="mymodule.MyMLP" if the model is in a module called mymodule.
hyperparams = DDict(
    client=client_hp,
    server=DDict(weighted=True),
    model=MyMLP(),
)

Finally, let's initialize the algorithm and run it.

In [None]:
# FedAvg over a federation of 10 clients, using the splitter and the
# hyper-parameters defined above.
algorithm = FedAVG(
    n_clients=10,
    data_splitter=splitter,
    hyper_params=hyperparams,
)

# Register a Log instance as the algorithm's callback.
logger = Log()
algorithm.set_callbacks(logger)

In [None]:
# Run 10 federated rounds; eligible_perc presumably is the fraction of
# clients taking part in each round (here 50%) — confirm in the fluke docs.
algorithm.run(
    n_rounds=10,
    eligible_perc=0.5,
)