In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
module_path = os.path.abspath(os.path.join('..'))
sys.path.insert(0, module_path)

# Using your own data with `fluke`

This tutorial will guide you through the steps required to use a custom dataset with `fluke`.

Try this notebook: [![Open in Colab](https://img.shields.io/badge/Open_in_Colab-blue?style=flat-square&logo=google-colab&logoColor=yellow&labelColor=gray)](https://colab.research.google.com/github/makgyver/fluke/blob/main/tutorials/fluke_custom_dataset.ipynb)

## Install `fluke` (if not already done)

In [None]:
!pip install fluke-fl

## Define your dataset function

To make your dataset usable in `fluke`, you need to define a function that returns
a [DataContainer](../../fluke.data.datasets.md) object. A `DataContainer` is a simple class that
wraps your data, which is expected to be already split into training and test sets.

```{eval-rst}
    .. hint::
        You can have a dataset with no predefined test set. To make it work properly with ``fluke``,
        you must set the test examples and labels to two empty tensors. Then, in the configuration,
        you must set ``keep_test`` to ``False``.
```
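Following the hint above, "empty" can be read as tensors with zero rows. Below is a minimal sketch, assuming your features are in `X` and your labels in `y`. The helper `empty_test_tensors` is hypothetical (not a `fluke` function), and giving the empty tensors the same feature shape and dtypes as the training data is a defensive choice, not a documented `fluke` requirement.

```python
import torch

def empty_test_tensors(X_train: torch.Tensor, y_train: torch.Tensor):
    """Hypothetical helper: zero-row test tensors with the training data's feature shape and dtypes."""
    X_test = torch.empty((0, *X_train.shape[1:]), dtype=X_train.dtype)  # shape (0, n_features)
    y_test = torch.empty((0,), dtype=y_train.dtype)                     # shape (0,)
    return X_test, y_test

# A dataset with no predefined test set: all 100 points are training data
X = torch.randn(100, 2)
y = torch.randint(0, 2, (100,))
X_test, y_test = empty_test_tensors(X, y)
print(X_test.shape, y_test.shape)  # torch.Size([0, 2]) torch.Size([0])
```

The result would then be passed as `DataContainer(X_train=X, y_train=y, X_test=X_test, y_test=y_test, num_classes=2)`, with `keep_test` set to `False` in the configuration.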

The following is an example of a dataset function that returns a random dataset with 100 examples (80 for training and 20 for testing).

In [2]:
from fluke.data.datasets import DataContainer
import torch

def MyDataset() -> DataContainer:

    # Random dataset with 100 2D points from 2 classes
    X = torch.randn(100, 2)
    y = torch.randint(0, 2, (100,))

    return DataContainer(X_train=X[:80],
                         y_train=y[:80],
                         X_test=X[80:],
                         y_test=y[80:],
                         num_classes=2)
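The example above takes the first 80 rows of a freshly generated tensor for training. With real data, which may be sorted (for example, by class), you would typically shuffle before splitting. The following is a minimal sketch of such a split; the helper `train_test_split_tensors` and its parameters are hypothetical (not part of `fluke`), and its four outputs are meant to be passed to `DataContainer` exactly as in `MyDataset`.

```python
import torch

def train_test_split_tensors(X: torch.Tensor, y: torch.Tensor,
                             test_frac: float = 0.2, seed: int = 0):
    """Hypothetical helper: shuffle (X, y) with a fixed seed, then split into train/test parts."""
    if X.shape[0] != y.shape[0]:
        raise ValueError("X and y must have the same number of rows")
    gen = torch.Generator().manual_seed(seed)  # local RNG: the global seed is left untouched
    perm = torch.randperm(X.shape[0], generator=gen)
    n_test = int(round(X.shape[0] * test_frac))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]

# Example: 100 random 2D points from 2 classes, as in MyDataset
X = torch.randn(100, 2)
y = torch.randint(0, 2, (100,))
X_train, y_train, X_test, y_test = train_test_split_tensors(X, y, test_frac=0.2)
print(X_train.shape, X_test.shape)  # torch.Size([80, 2]) torch.Size([20, 2])
```

Using a dedicated `torch.Generator` keeps the split reproducible without depending on (or changing) the global seed that `fluke` sets later.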

## Using your dataset with `fluke` CLI

You can now use your dataset with the `fluke` CLI. In the configuration, set the name of the dataset
to the fully qualified name of the function. For example, if you saved the `MyDataset` function above
in a file called `my_dataset.py`, you can reference it as follows:

```yaml
dataset:
  name: my_dataset.MyDataset
  ...
```
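Because the name is a dotted path, the module part (`my_dataset`) presumably has to be importable by the Python process that runs `fluke`; if it is not found, adding the file's folder to `PYTHONPATH` is the usual fix. To check a name before launching an experiment, you can resolve it yourself with the standard-library `importlib` approach sketched below. This is an assumption about how dotted names are typically loaded, not `fluke`'s own loader, and `resolve_dotted` is a hypothetical helper.

```python
import importlib
from typing import Callable

def resolve_dotted(name: str) -> Callable:
    """Hypothetical helper: split 'module.function' at the last dot, import the module, return the attribute."""
    module_name, _, attr = name.rpartition(".")
    if not module_name:
        raise ValueError(f"'{name}' is not a fully qualified name (expected 'module.function')")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

# With my_dataset.py importable, you would check:
#   data = resolve_dotted("my_dataset.MyDataset")()
# Standard-library demo of the same mechanism:
join = resolve_dotted("os.path.join")
print(join("data", "my_dataset.py"))
```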

Then, you can run `fluke` as usual:

```bash
fluke --config config.yaml federation fedavg.yaml
```

where `config.yaml` is the experiment configuration file and `fedavg.yaml` is the configuration file for the federated averaging (FedAvg) algorithm.

```{eval-rst}
    .. tip::
       Make sure to configure the algorithm with a model that is compatible with the dataset!
```

## Using your dataset with `fluke` API

This use case is straightforward: instead of using `Datasets.get`, call your own function to get the dataset.

To run the example, we also define a tiny network that works with our dataset (2 input features, 2 classes).

In [3]:
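import torch
import torch.nn.functional as F

class MyMLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(2, 3)  # 2 input features -> 3 hidden units
        self.fc2 = torch.nn.Linear(3, 2)  # 3 hidden units -> 2 class scores

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        # Return raw logits: CrossEntropyLoss applies log-softmax itself,
        # so the output layer must not have an activation
        return self.fc2(x)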
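In line with the earlier tip about using a model compatible with the dataset, a cheap sanity check is to push a dummy batch through the network and compare the output width with the number of classes. This is a minimal sketch: `check_output_shape` is a hypothetical helper, and the `nn.Sequential` below simply mirrors the layer sizes of `MyMLP`.

```python
import torch
from torch import nn

def check_output_shape(model: nn.Module, n_features: int, n_classes: int,
                       batch_size: int = 4) -> None:
    """Hypothetical helper: raise ValueError unless (batch, n_features) maps to (batch, n_classes)."""
    was_training = model.training
    model.eval()                   # disable dropout / batch-norm updates during the check
    with torch.no_grad():
        out = model(torch.randn(batch_size, n_features))
    model.train(was_training)      # restore the original mode
    expected = (batch_size, n_classes)
    if tuple(out.shape) != expected:
        raise ValueError(f"expected output shape {expected}, got {tuple(out.shape)}")

# Same layer sizes as MyMLP: 2 input features -> 3 hidden units -> 2 class scores
net = nn.Sequential(nn.Linear(2, 3), nn.ReLU(), nn.Linear(3, 2))
check_output_shape(net, n_features=2, n_classes=2)  # passes silently
```

With the tutorial's own objects, the equivalent call would be `check_output_shape(MyMLP(), n_features=2, n_classes=dataset.num_classes)` once `dataset` has been created.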

Now, to run FedAVG (for example) on our dataset, we do the following:

In [4]:
from fluke.data import DataSplitter
from fluke import DDict
from fluke.utils.log import Log
from fluke.evaluation import ClassificationEval
from fluke import FlukeENV

settings = FlukeENV()
settings.set_seed(42) # we set a seed for reproducibility
settings.set_device("cpu") # we use the CPU for this example

dataset = MyDataset() # here is our dataset

# we set the evaluator to be used by both the server and the clients
settings.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

splitter = DataSplitter(dataset=dataset,
                        distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
      lr=0.01,
      momentum=0.9,
      weight_decay=0.0001),
    scheduler=DDict(
      gamma=1,
      step_size=1)
)

hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model=MyMLP()) # we use our network :)
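With `distribution="iid"`, the training examples are spread across clients uniformly at random, so each client's label distribution roughly matches the global one. The snippet below only illustrates that idea (a random permutation cut into near-equal shards); it is not `fluke`'s `DataSplitter` code, and `iid_partition` is a hypothetical name.

```python
import torch

def iid_partition(n_examples: int, n_clients: int, seed: int = 0) -> list[torch.Tensor]:
    """Hypothetical helper: assign example indices to clients uniformly at random, in near-equal shards."""
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n_examples, generator=gen)
    # tensor_split also handles sizes that are not divisible by n_clients
    return list(torch.tensor_split(perm, n_clients))

shards = iid_partition(80, n_clients=2)  # the 80 training examples of MyDataset
print([len(s) for s in shards])          # [40, 40]
```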

Next, we instantiate the federated algorithm (here, FedAVG) and attach a logger as a callback.

In [5]:
from fluke.algorithms.fedavg import FedAVG
algorithm = FedAVG(n_clients=2,
                   data_splitter=splitter,
                   hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)
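The hyperparameters above configure the server with `server=DDict(weighted=True)`, i.e., a weighted average of the client models. In FedAvg the weights are typically proportional to each client's number of training examples; assuming that is what `weighted=True` means here, the aggregation step can be sketched as below. `weighted_average` is a hypothetical helper, not `fluke`'s server code.

```python
import torch
from torch import nn

def weighted_average(state_dicts: list[dict], weights: list[float]) -> dict:
    """Hypothetical helper: average tensors key by key; client i contributes weights[i] / sum(weights)."""
    total = float(sum(weights))
    return {
        key: sum((w / total) * sd[key] for sd, w in zip(state_dicts, weights))
        for key in state_dicts[0]
    }

# Two hypothetical clients holding 30 and 10 training examples (weights 0.75 and 0.25)
client_a, client_b = nn.Linear(2, 2), nn.Linear(2, 2)
avg = weighted_average([client_a.state_dict(), client_b.state_dict()], weights=[30, 10])

global_model = nn.Linear(2, 2)
global_model.load_state_dict(avg)  # the server's new global parameters
```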

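Now we just need to run it! With `eligible_perc=0.5`, half of the clients (here, one of the two) take part in each of the 10 rounds.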

In [6]:
algorithm.run(n_rounds=10, eligible_perc=0.5)
