# Custom evaluation with `fluke`

This tutorial guides you through the steps required to implement a new evaluation metric and use it with `fluke`.

Try this notebook: [![Open in Colab](https://img.shields.io/badge/Open_in_Colab-blue?style=flat-square&logo=google-colab&logoColor=yellow&labelColor=gray)
](https://colab.research.google.com/github/makgyver/fluke/blob/main/tutorials/fluke_custom_eval.ipynb)

## Install `fluke` (if not already done)

```bash
pip install fluke-fl
```

## Weighted accuracy

In this tutorial, we show how to implement in `fluke` a metric that is quite common in Personalized Federated Learning. A common technique for evaluating a client's local model is to use a balanced test set and to weight each test sample according to how frequent its class is in the client's local data. Intuitively, the weighted accuracy takes the number of local samples of each class into account, so that errors on classes that are rare for the client are penalized less. This metric is taken from *Tackling Data Heterogeneity in Federated Learning with Class Prototypes* (Dai et al.) and is defined as follows:

```{eval-rst}

.. math::

    acc_i = \frac{\sum_{x_j,y_j\in \mathcal{D}_{test}}\alpha_i(y_j)\mathbb{1}(y_j = \hat{y}_j)}{\sum_{x_j,y_j\in \mathcal{D}_{test}}\alpha_i(y_j)}



where :math:`\alpha_i(\cdot)` is a positive-valued function that can be interpreted as the probability that a sample of the :math:`i^{th}` client belongs to a given class. Notice that for :math:`\alpha_i(\cdot) = 1` we recover the standard accuracy. In this tutorial, we interpret :math:`\alpha_i(y)` as the fraction of the local samples of client :math:`i` that belong to class :math:`y`. Specifically, for client :math:`i` and class :math:`y_j` we compute :math:`\alpha_i(y_j) = \frac{Y^i_j}{Y^i}`, where :math:`Y^i_j` is the number of training samples of class :math:`y_j` held by client :math:`i`, and :math:`Y^i` is the total number of training samples of client :math:`i`.

In our case, :math:`\mathcal{D}_{test}` is the test set held by the server, which is (usually) the original test set of the dataset.

```
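As a worked example of the formula, the snippet below computes the coefficients alpha from a client's training labels and then the weighted accuracy on a small balanced test set. It is plain Python for illustration only: the helpers `class_proportions` and `weighted_accuracy` are hypothetical (not part of `fluke`) and the labels are made up.

```python
from collections import Counter


def class_proportions(train_labels: list[int], n_classes: int) -> list[float]:
    """alpha_i(y) = Y^i_y / Y^i: share of the client's training samples with label y."""
    counts = Counter(train_labels)
    total = len(train_labels)
    return [counts[c] / total for c in range(n_classes)]


def weighted_accuracy(y_true: list[int], y_pred: list[int], alpha: list[float]) -> float:
    """sum_j alpha(y_j) * 1(y_j == y_hat_j) / sum_j alpha(y_j)."""
    num = sum(alpha[y] for y, y_hat in zip(y_true, y_pred) if y == y_hat)
    den = sum(alpha[y] for y in y_true)
    return num / den


# Toy client: class 0 dominates its training set, class 2 is never seen
alpha = class_proportions([0, 0, 0, 1], n_classes=3)  # [0.75, 0.25, 0.0]

# Balanced test set: two samples per class, three of six predicted correctly
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 1, 0]

print(weighted_accuracy(y_true, y_pred, alpha))            # 0.875
print(weighted_accuracy(y_true, y_pred, [1.0, 1.0, 1.0]))  # 0.5, the plain accuracy
```

The errors fall on the classes the client rarely or never sees, so the weighted accuracy (0.875) is higher than the plain accuracy (0.5). Any constant alpha gives back the plain accuracy, since the constant cancels between numerator and denominator; hence, when every client holds roughly the same share of each class (as in an IID split), the two metrics are expected to be close.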

## Implementing the server-side logic

Notice that `server.evaluate` is called in `server.fit` with only two arguments (the evaluator and the test set). As a consequence, if we wanted `server.evaluate` to take the class weights into account, we would have to modify `server.fit` as well, which is too verbose. The most straightforward solution is to leave `server.fit` and the signature of `server.evaluate` untouched, and instead extend the input arguments of `evaluator.evaluate` with the class weights.

```python
import torch

from fluke.data import FastDataLoader  # NOQA
from fluke.evaluation import Evaluator  # NOQA
from fluke.client import Client  # NOQA
from fluke.server import Server  # NOQA


class MyServer(Server):

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if test_set is not None:
            # On the server every class gets the same weight, so the weighted
            # accuracy must coincide with the standard accuracy (sanity check)
            return evaluator.evaluate(self.rounds + 1,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      weights=torch.ones(evaluator.n_classes))
        return {}
```

## Implementing the client-side logic

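Following the same logic as for the server, we modify `evaluator.evaluate` instead of rewriting `client.local_update` and changing the inputs of `client.evaluate`. Each client computes its class weights once, from the labels of its local training set, and passes them to the evaluator.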

```python
from typing import Any

import torch
from torch.nn import Module

from fluke.utils import OptimizerConfigurator  # NOQA


class MyClient(Client):

    def __init__(self,
                 index: int,
                 train_set: FastDataLoader,
                 test_set: FastDataLoader,
                 optimizer_cfg: OptimizerConfigurator,
                 loss_fn: Module,
                 local_epochs: int = 3,
                 **kwargs: dict[str, Any]):
        super().__init__(index,
                         train_set,
                         test_set,
                         optimizer_cfg,
                         loss_fn,
                         local_epochs,
                         **kwargs)
        # alpha_i(y): fraction of the local training samples that belong to class y.
        # `minlength` guarantees one entry per class, even for classes this client never sees.
        labels = self.train_set.tensors[1]
        self.class_weights = torch.bincount(labels, minlength=self.train_set.num_labels).float()
        self.class_weights /= self.class_weights.sum()

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if self.model is not None:
            # Same call as the default client, plus the client-specific class weights
            return evaluator.evaluate(self._last_round,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      weights=self.class_weights)
        return {}
```
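The client builds alpha_i with `torch.bincount`. Without `minlength`, `torch.bincount` returns only `max(labels) + 1` bins, so a client that never observes the highest class indices gets a weight vector shorter than the number of classes, and looking up the weight of such a label at evaluation time fails. A quick check in plain PyTorch, with made-up labels:

```python
import torch

n_classes = 10
# Made-up local labels: this client only holds classes 0, 1 and 3
labels = torch.tensor([0, 0, 1, 3, 3, 3])

short = torch.bincount(labels).float()
print(short.shape)  # torch.Size([4]), no entry for classes 4..9

counts = torch.bincount(labels, minlength=n_classes).float()
class_weights = counts / counts.sum()  # alpha_i(y) for every class y
print(class_weights.shape)  # torch.Size([10])

y_test = torch.tensor([9, 3, 0])  # test labels, including a class unseen by the client
print(class_weights[y_test])  # weights 0.0, 0.5 and about 0.333

try:
    short[y_test]  # label 9 has no bin in the short vector
except (IndexError, RuntimeError) as err:
    print(type(err).__name__, err)
```

Normalizing by `counts.sum()` rather than by a separately stored dataset size keeps the weights summing to one by construction.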

## Implementing the metric

In the following, we start from the `ClassificationEval` evaluator provided by `fluke` and modify it to take the weight of each class into account. As a sanity check, in the global (server-side) evaluation `accuracy` and `weighted_accuracy` must be the same, because the server passes uniform weights.

```python
from typing import Iterable, Optional, Union

import numpy as np
import torch
from torchmetrics import Accuracy

from fluke.utils import clear_cuda_cache  # NOQA


class WeightedClassificationEval(Evaluator):

    def __init__(self, eval_every: int, n_classes: int):
        super().__init__(eval_every=eval_every)
        self.n_classes: int = n_classes

    def evaluate(self,
                 round: int,
                 model: torch.nn.Module,
                 eval_data_loader: Union[FastDataLoader,
                                         Iterable[FastDataLoader]],
                 loss_fn: Optional[torch.nn.Module] = None,
                 device: torch.device = torch.device("cpu"),
                 weights: Optional[torch.Tensor] = None) -> dict:

        if round % self.eval_every != 0:
            return {}

        if (model is None) or (eval_data_loader is None):
            return {}

        # No weights means alpha(.) = 1, i.e., the standard accuracy
        if weights is None:
            weights = torch.ones(self.n_classes)
        # The weights are indexed with the labels, so they must be on the same device
        weights = weights.to(device)

        model.eval()
        model.to(device)
        task = "multiclass"  # if self.n_classes >= 2 else "binary"
        accs, losses, weight_accs = [], [], []

        if not isinstance(eval_data_loader, list):
            eval_data_loader = [eval_data_loader]

        for data_loader in eval_data_loader:
            accuracy = Accuracy(task=task,
                                num_classes=self.n_classes,
                                top_k=1,
                                average="micro")
            loss = 0
            # Per-sample weights alpha(y_j) and correctness flags, reset for each loader
            true_weights, correct = [], []
            for X, y in data_loader:
                X, y = X.to(device), y.to(device)
                with torch.no_grad():
                    y_hat = model(X)
                    if loss_fn is not None:
                        loss += loss_fn(y_hat, y).item()
                    y_hat = torch.max(y_hat, dim=1)[1]
                true_weights.append(weights[y])
                correct.append(torch.eq(y, y_hat))
                accuracy.update(y_hat.cpu(), y.cpu())

            true_weights = torch.cat(true_weights, dim=0)
            correct = torch.cat(correct, dim=0)
            # sum_j alpha(y_j) * 1(y_j == y_hat_j) / sum_j alpha(y_j)
            weight_accs.append((true_weights * correct).sum().item() / true_weights.sum().item())

            accs.append(accuracy.compute().item())
            losses.append(loss / len(data_loader))  # average batch loss of this loader

        model.cpu()
        clear_cuda_cache()

        result = {
            "accuracy": np.round(sum(accs) / len(accs), 5).item(),
            "weighted_accuracy": np.round(sum(weight_accs) / len(weight_accs), 5).item(),
        }
        if loss_fn is not None:
            result["loss"] = np.round(sum(losses) / len(losses), 5).item()

        return result

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(eval_every={self.eval_every}" + \
               f", n_classes={self.n_classes})[accuracy, weighted_accuracy]"

    def __repr__(self) -> str:
        return str(self)
```
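The evaluator above stores one weight per test sample and concatenates them at the end. Since the metric is a ratio of two sums, an alternative is to accumulate the numerator and the denominator batch by batch, which keeps memory constant in the size of the test set. The sketch below checks that both forms agree, using random tensors in place of a model and a data loader; `streaming_weighted_accuracy` is a hypothetical helper, not part of `fluke`.

```python
import torch


def streaming_weighted_accuracy(batches, weights: torch.Tensor) -> float:
    """Accumulate numerator and denominator of the weighted accuracy batch by batch."""
    num, den = 0.0, 0.0
    for y, y_hat in batches:
        w = weights[y]  # alpha(y_j) for every sample of the batch
        num += (w * (y == y_hat)).sum().item()  # sum_j alpha(y_j) * 1(y_j == y_hat_j)
        den += w.sum().item()                   # sum_j alpha(y_j)
    return num / den


torch.manual_seed(0)
n_classes = 10
weights = torch.rand(n_classes)               # arbitrary non-negative class weights
y = torch.randint(0, n_classes, (1000,))      # synthetic "true" labels
y_hat = torch.randint(0, n_classes, (1000,))  # synthetic "predicted" labels

# Split into batches of 64 samples, as a data loader would
batches = list(zip(y.split(64), y_hat.split(64)))
streamed = streaming_weighted_accuracy(batches, weights)

# Same ratio computed on the full tensors, as the evaluator does after torch.cat
w = weights[y]
full = ((w * (y == y_hat)).sum() / w.sum()).item()
print(abs(streamed - full) < 1e-4)  # True

# Uniform weights (what MyServer passes) give back the plain accuracy
plain = (y == y_hat).float().mean().item()
print(abs(streaming_weighted_accuracy(batches, torch.ones(n_classes)) - plain) < 1e-6)  # True
```

Both versions compute the same quantity up to floating-point rounding; the streaming one simply trades the final `torch.cat` for two running sums that are reset for each data loader.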

## Defining the new federated algorithm

Now we plug `MyClient` and `MyServer` into a new federated algorithm by extending `CentralizedFL`.

```python
from fluke.algorithms import CentralizedFL


class MyFLAlgorithm(CentralizedFL):

    def get_client_class(self) -> type[Client]:
        return MyClient

    def get_server_class(self) -> type[Server]:
        return MyServer
```

## Ready to test the new federated algorithm

The rest of the code is similar to the [First steps with `fluke` API](fluke_quick_api.ipynb) tutorial. We just replace `ClassificationEval` with our custom evaluator.

```python
from fluke import DDict, FlukeENV
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke.utils.log import Log

env = FlukeENV()
env.set_seed(42)  # we set a seed for reproducibility
env.set_device("cpu")  # we use the CPU for this example
# we set the evaluation configuration
env.set_eval_cfg(DDict(pre_fit=True, post_fit=True))

# we set the evaluator to be used by both the server and the clients (MNIST has 10 classes)
env.set_evaluator(WeightedClassificationEval(eval_every=1, n_classes=10))

dataset = Datasets.get("mnist", path="./data")
splitter = DataSplitter(dataset=dataset, distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)

# we put together the hyperparameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")
```

Here is where the new federated algorithm comes into play.

```python
algorithm = MyFLAlgorithm(n_clients=10,  # 10 clients in the federation
                          data_splitter=splitter,
                          hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)
```

Now we just need to run it!

```python
algorithm.run(n_rounds=100, eligible_perc=0.5)
```