# FLSim Tutorial: Image classification with CIFAR-10

## Introduction

In this tutorial, we will train a simple CNN image classifier on CIFAR-10 with federated learning using FLSim.

### Prerequisites

To get the most out of this tutorial, you should be comfortable training machine learning models with **PyTorch** and familiar with the concept of **federated learning (FL)**. If you are unfamiliar with either of them or could use a refresher, please take a look at the following resources before proceeding with the tutorial:

- McMahan & Ramage (2017): [Federated Learning: Collaborative Machine Learning without Centralized Training Data](https://ai.googleblog.com/2017/04/federated-learning-collaborative.html). A short blog post from Google AI introducing the main idea of FL in a beginner-friendly way.
- McMahan et al. (2017): [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/pdf/1602.05629.pdf). This paper first proposes the approach of federated learning. The described algorithm is now known as federated averaging (or FedAvg for short).
- PyTorch has [extensive tutorials](https://pytorch.org/tutorials/) on their website. In particular, take a look at their [image classification tutorial using CIFAR-10](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

Now that you're familiar with PyTorch and FL, let's move on!

### Objectives 

By the end of this tutorial, we will have learnt how to

1. Build a data pipeline for federated learning with FLSim,
2. Create an image classification model compatible with FL training,
3. Set hyperparameters for FL training, 
4. Create a metrics reporter to collect metrics, and
5. Launch an FL training flow using FLSim.

## Training an image classifier with FLSim

### Installation
First, let's install FLSim via pip with the command below.

In [None]:
!pip install --quiet flsim


In [None]:
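USE_CUDA = True  # use a GPU if one is available
LOCAL_BATCH_SIZE = 32  # batch size for each client's local training
EXAMPLES_PER_USER = 500  # number of training examples assigned to each client
IMAGE_SIZE = 32  # CIFAR-10 images are 32x32 pixels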

### 0. About the dataset

For this tutorial, we will use the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). The CIFAR-10 dataset consists of 60k 32x32-pixel color (3-channel) images in 10 classes, with 6k images per class.
There are 50k training images (5k training images per class) and 10k test images (1k test images per class).
The classes are ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, and ‘truck’.
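Combining the 50k training images with `EXAMPLES_PER_USER` and `LOCAL_BATCH_SIZE` from the cell above tells us how large the simulated federation will be. The sketch below assumes sequential sharding and batching that keeps the last partial batch (`drop_last=False`), as set up in the data pipeline below; `federation_stats` is a hypothetical helper, not part of FLSim.

```python
import math


def federation_stats(num_train_images: int, examples_per_user: int, local_batch_size: int) -> dict:
    """Hypothetical helper: size of the simulated federation under sequential sharding."""
    # Consecutive blocks of `examples_per_user` rows form one client each;
    # the last client gets fewer rows if the division is not exact.
    num_clients = math.ceil(num_train_images / examples_per_user)
    # For a client holding a full block, drop_last=False keeps a partial final batch.
    batches_per_client = math.ceil(examples_per_user / local_batch_size)
    last_batch_size = examples_per_user - (batches_per_client - 1) * local_batch_size
    return {
        "num_clients": num_clients,
        "batches_per_client": batches_per_client,
        "last_batch_size": last_batch_size,
    }


stats = federation_stats(num_train_images=50_000, examples_per_user=500, local_batch_size=32)
print(stats)  # {'num_clients': 100, 'batches_per_client': 16, 'last_batch_size': 20}
```

So each of the 100 clients trains on 15 full batches of 32 images plus one batch of 20.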

We can get the CIFAR-10 dataset from `torchvision.datasets`.

![img](https://pytorch.org/tutorials/_images/cifar10.png)

In [None]:
from torchvision.datasets.cifar import CIFAR10

### 1. Data pipeline

First, let's define how to build the data pipeline for federated learning:

1. We create data transforms and training, eval, and test datasets. This step is identical to preparing data in non-federated learning.

There are a few extra steps to enable training with federated learning. In particular, we need to

2. Create a sharder, which defines a mapping from examples in the training data to clients. In other words, **a sharder groups rows of data into client datasets**, returning the rows grouped into one list per client. FLSim provides a number of sharding strategies, such as random or column-based sharding. In this tutorial, we use sequential sharding, which assigns the first `examples_per_user` rows to user 0, the next `examples_per_user` rows to user 1, and so on.

3. Create a data loader, which will shard and batchify training, eval, and test data. For each dataset, the data loader first assigns rows to clients using the sharder and then splits each client's data into batches of size `batch_size`. We choose not to drop the last batch.

4. Lastly, wrap the data loader with a data provider and return it. The data provider creates clients from the groupings in the data loader and adds metadata (e.g., the number of examples and batches). The data is now in the format the trainer expects.

Note that the concept of a client or device only applies to the training data; the eval and test data are consumed as a single stream of batches, as in non-federated learning.
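To make steps 2 and 3 concrete, here is a plain-Python sketch of sequential sharding followed by per-client batching. It is not FLSim code: `sequential_shard` and `batchify_rows` are hypothetical helpers that only mirror the behaviour described above (row `i` goes to client `i // examples_per_user`, and `drop_last=False` keeps a partial final batch).

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def sequential_shard(rows: Sequence[T], examples_per_shard: int) -> Dict[int, List[T]]:
    """Hypothetical stand-in for sequential sharding: row i goes to shard i // examples_per_shard."""
    shards: Dict[int, List[T]] = {}
    for index, row in enumerate(rows):
        shards.setdefault(index // examples_per_shard, []).append(row)
    return shards


def batchify_rows(rows: Sequence[T], batch_size: int, drop_last: bool = False) -> List[List[T]]:
    """Split one client's rows into consecutive batches of at most `batch_size` rows."""
    batches = [list(rows[i : i + batch_size]) for i in range(0, len(rows), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete final batch
    return batches


# Toy example: 1,100 rows, 500 rows per client, local batches of 32.
rows = list(range(1100))
clients = sequential_shard(rows, examples_per_shard=500)
client_batches = {cid: batchify_rows(data, batch_size=32) for cid, data in clients.items()}
print({cid: len(b) for cid, b in client_batches.items()})  # {0: 16, 1: 16, 2: 4}
```

The third client only receives the 100 leftover rows, so it ends up with three full batches and one batch of 4.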

In [None]:
import random
from typing import Any, Dict, Generator, Iterable, Iterator, List, Tuple

import torch
from flsim.data.data_provider import IFLDataProvider, IFLUserData
from flsim.data.data_sharder import FLDataSharder
from flsim.interfaces.data_loader import IFLDataLoader
from flsim.utils.data.data_utils import batchify
from torchvision.datasets import VisionDataset
from tqdm import tqdm


def collate_fn(row: Tuple) -> Dict[str, Any]:
    """Turn one (image, label) dataset item into a dict keyed by field name."""
    feature, label = row
    return {"features": feature, "labels": label}


class DataLoader(IFLDataLoader):
    SEED = 2137
    random.seed(SEED)

    def __init__(
        self,
        train_dataset: VisionDataset,
        eval_dataset: VisionDataset,
        test_dataset: VisionDataset,
        sharder: FLDataSharder,
        batch_size: int,
        drop_last: bool = False,
        collate_fn=collate_fn,
    ):
        assert batch_size > 0, "Batch size should be a positive integer."
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sharder = sharder
        self.collate_fn = collate_fn

    def fl_train_set(self, **kwargs) -> Iterable[Dict[str, Generator]]:
        rank = kwargs.get("rank", 0)
        world_size = kwargs.get("world_size", 1)
        yield from self._batchify(self.train_dataset, self.drop_last, world_size, rank)

    def fl_eval_set(self, **kwargs) -> Iterable[Dict[str, Generator]]:
        yield from self._batchify(self.eval_dataset, drop_last=False)

    def fl_test_set(self, **kwargs) -> Iterable[Dict[str, Generator]]:
        yield from self._batchify(self.test_dataset, drop_last=False)

    def _batchify(
        self,
        dataset: VisionDataset,
        drop_last: bool = False,
        world_size: int = 1,
        rank: int = 0,
    ) -> Generator[Dict[str, Generator], None, None]:
        # Convert every (image, label) pair into a {"features": ..., "labels": ...} row.
        data_rows: List[Dict[str, Any]] = [self.collate_fn(row) for row in dataset]
        # The sharder groups rows by client; each client's rows are then split into batches.
        for _, user_data in self.sharder.shard_rows(data_rows):
            batch = {}
            for key in user_data[0].keys():
                # batchify yields lists of at most `batch_size` values for this key.
                batch[key] = batchify(
                    [row[key] for row in user_data],
                    self.batch_size,
                    drop_last,
                )
            yield batch


class UserData(IFLUserData):
    def __init__(self, user_data: Dict[str, Generator]):
        self._user_batches = []
        self._num_batches = 0
        self._num_examples = 0
        for features, labels in zip(user_data["features"], user_data["labels"]):
            self._num_batches += 1
            self._num_examples += UserData.get_num_examples(labels)
            self._user_batches.append(UserData.fl_training_batch(features, labels))

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Iterator to return a user batch data
        """
        for batch in self._user_batches:
            yield batch

    def num_examples(self) -> int:
        """
        Returns the number of examples
        """
        return self._num_examples

    def num_batches(self) -> int:
        """
        Returns the number of batches
        """
        return self._num_batches

    @staticmethod
    def get_num_examples(batch: List) -> int:
        return len(batch)

    @staticmethod
    def fl_training_batch(
        features: List[torch.Tensor], labels: List[float]
    ) -> Dict[str, torch.Tensor]:
        return {"features": torch.stack(features), "labels": torch.Tensor(labels)}


class DataProvider(IFLDataProvider):
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.train_users = self._create_fl_users(data_loader.fl_train_set())
        self.eval_users = self._create_fl_users(data_loader.fl_eval_set())
        self.test_users = self._create_fl_users(data_loader.fl_test_set())

    def user_ids(self) -> List[int]:
        return list(self.train_users.keys())

    def num_users(self) -> int:
        return len(self.train_users)

    def get_user_data(self, user_index: int) -> IFLUserData:
        if user_index in self.train_users:
            return self.train_users[user_index]
        else:
            raise IndexError(
                f"Index {user_index} is out of bound for list with len {self.num_users()}"
            )

    def train_data(self) -> Iterable[IFLUserData]:
        for user_data in self.train_users.values():
            yield user_data

    def eval_data(self) -> Iterable[Dict[str, torch.Tensor]]:
        for user_data in self.eval_users.values():
            for batch in user_data:
                yield batch

    def test_data(self) -> Iterable[Dict[str, torch.Tensor]]:
        for user_data in self.test_users.values():
            for batch in user_data:
                yield batch

    def _create_fl_users(self, iterator: Iterator) -> Dict[int, IFLUserData]:
        return {
            user_index: UserData(user_data)
            for user_index, user_data in tqdm(
                enumerate(iterator), desc="Creating FL User", unit="user"
            )
        }


In [None]:
from flsim.data.data_sharder import SequentialSharder
from torchvision import transforms


def build_data_provider(local_batch_size, examples_per_user):

    # 1. Create datasets as in non-federated learning (the test split doubles as the eval set).
    transform = transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    train_dataset = CIFAR10(
        root="./cifar10", train=True, download=True, transform=transform
    )
    test_dataset = CIFAR10(
        root="./cifar10", train=False, download=True, transform=transform
    )

    # 2. Create a sharder, which maps samples in the training data to clients.
    sharder = SequentialSharder(examples_per_shard=examples_per_user)

    # 3. Shard and batchify training, eval, and test data.
    fl_data_loader = DataLoader(
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        test_dataset=test_dataset,
        sharder=sharder,
        batch_size=local_batch_size,
        drop_last=False,
    )

    # 4. Wrap the data loader with a data provider.
    data_provider = DataProvider(fl_data_loader)
    print(f"Clients in total: {data_provider.num_users()}")
    return data_provider


### 2. Create the model

Now, let's see how we can create a model that is compatible with FL training.

1. Define a standard, non-FL PyTorch `nn.Module` for image classification; in this tutorial we use a simple CNN.

2. Create a `torch.device` and choose where the model will be allocated (CUDA or CPU).

3. Wrap the PyTorch module in `FLModel`, a class we define below that implements FLSim's `IFLModel` interface for FL-friendly models.

4. Move the model to the GPU if CUDA is enabled.

For step 1, we create a standard `nn.Module` with four convolution layers, each followed by group norm, ReLU, and max pooling, then dropout and a final linear layer.

In [None]:
import torch.nn as nn

class SimpleConvNet(nn.Module):
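    def __init__(self, in_channels, num_classes, dropout_rate=0, image_size=32):
        super().__init__()
        self.out_channels = 32
        self.stride = 1
        self.padding = 2
        # Four 3x3 convolutions; the first maps `in_channels` to 32 channels.
        layers = []
        in_dim = in_channels
        for _ in range(4):
            layers.append(nn.Conv2d(in_dim, self.out_channels, 3, self.stride, self.padding))
            in_dim = self.out_channels
        self.layers = nn.ModuleList(layers)

        # Shared block applied after every convolution: group norm, ReLU, 2x2 max pooling.
        self.gn_relu = nn.Sequential(
            nn.GroupNorm(self.out_channels, self.out_channels, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Track the spatial side length through the four conv + pool blocks;
        # for a 32x32 input it goes 32 -> 17 -> 9 -> 5 -> 3.
        side = image_size
        for _ in range(4):
            side = (side + 2 * self.padding - 3) // self.stride + 1  # 3x3 convolution
            side //= 2  # 2x2 max pooling with stride 2
        num_features = self.out_channels * side * side
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(num_features, num_classes)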

    def forward(self, x):
        for conv in self.layers:
            x = self.gn_relu(conv(x))

        x = x.view(-1, self.num_flat_features(x))
        x = self.fc(self.dropout(x))
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


# 1. Define our model, a simple CNN.
model = SimpleConvNet(in_channels=3, num_classes=10)
model

SimpleConvNet(
  (layers): ModuleList(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  )
  (gn_relu): Sequential(
    (0): GroupNorm(32, 32, eps=1e-05, affine=True)
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dropout): Dropout(p=0, inplace=False)
  (fc): Linear(in_features=288, out_features=10, bias=True)
)
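The `in_features=288` of the final linear layer comes from tracing a 32x32 input through the four conv + max-pool blocks. The arithmetic below uses the standard output-size formulas for `Conv2d` and `MaxPool2d` (dilation 1, no ceil mode); `GroupNorm` and `ReLU` do not change the shape, and the helper names are illustrative only.

```python
def conv2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output side length of a convolution with dilation 1.
    return (size + 2 * padding - kernel) // stride + 1


def maxpool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    # Output side length of max pooling without padding (floor mode).
    return (size - kernel) // stride + 1


side = 32
trace = [side]
for _ in range(4):
    side = conv2d_out(side, kernel=3, stride=1, padding=2)  # grows the side by 2
    side = maxpool_out(side)  # halves it (rounding down)
    trace.append(side)

in_features = 32 * side * side  # 32 output channels times the final 3x3 map
print(trace, in_features)  # [32, 17, 9, 5, 3] 288
```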

Below, we wrap the PyTorch module in `FLModel`, which implements FLSim's `IFLModel` interface. An `IFLModel` is an abstract, FL-friendly model class: it handles metrics, local training batch creation, and the forward passes for training and evaluation.

In [None]:
import torch
from typing import Optional
import torch.nn.functional as F
from flsim.interfaces.model import IFLModel
from flsim.utils.simple_batch_metrics import FLBatchMetrics


class FLModel(IFLModel):
    def __init__(self, model: nn.Module, device: Optional[str] = None):
        self.model = model
        self.device = device

    def fl_forward(self, batch) -> FLBatchMetrics:
        features = batch["features"]  # [B, C, 28, 28]
        batch_label = batch["labels"]
        stacked_label = batch_label.view(-1).long().clone().detach()
        if self.device is not None:
            features = features.to(self.device)

        output = self.model(features)

        if self.device is not None:
            output, batch_label, stacked_label = (
                output.to(self.device),
                batch_label.to(self.device),
                stacked_label.to(self.device),
            )

        loss = F.cross_entropy(output, stacked_label)
        num_examples = self.get_num_examples(batch)
        output = output.detach().cpu()
        stacked_label = stacked_label.detach().cpu()
        del features
        return FLBatchMetrics(
            loss=loss,
            num_examples=num_examples,
            predictions=output,
            targets=stacked_label,
            model_inputs=[],
        )

    def fl_create_training_batch(self, **kwargs):
        features = kwargs.get("features", None)
        labels = kwargs.get("labels", None)
        return UserData.fl_training_batch(features, labels)

    def fl_get_module(self) -> nn.Module:
        return self.model

    def fl_cuda(self) -> None:
        self.model = self.model.to(self.device)

    def get_eval_metrics(self, batch) -> FLBatchMetrics:
        with torch.no_grad():
            return self.fl_forward(batch)

    def get_num_examples(self, batch) -> int:
        return UserData.get_num_examples(batch["labels"])


# 2. Choose where the model will be allocated.
cuda_enabled = torch.cuda.is_available() and USE_CUDA
device = torch.device(f"cuda:{0}" if cuda_enabled else "cpu")

# 3. Wrap the model in FLModel.
global_model = FLModel(model, device)

# 4. Enable CUDA if desired.
if cuda_enabled:
    global_model.fl_cuda()


### 3. Metrics Reporting

After creating our data pipeline and FL model, we define a metrics reporter. The metrics reporter collects metrics and logs them, for example to TensorBoard.

There are three methods we need to implement:

1. `compare_metrics`: compares the current eval metrics, as returned by `create_eval_metrics` (defined below), with the best metrics so far, and returns `True` if the current ones are better.

2. `compute_scores`: computes the metrics we care about. In this case, we report top-1 accuracy.

3. `create_eval_metrics`: builds the eval metrics dictionary used by `compare_metrics` above.
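The logic behind `compute_scores` and `compare_metrics` does not depend on FLSim. Below is a torch-free sketch of the same two ideas on plain lists, using hypothetical helpers `top1_accuracy` and `is_better`: accuracy is pooled over all examples across batches (not averaged per batch), and a new result counts as the best only if it is strictly higher.

```python
from typing import Dict, Optional, Sequence

ACCURACY = "Accuracy"


def top1_accuracy(
    predictions: Sequence[Sequence[Sequence[float]]], targets: Sequence[Sequence[int]]
) -> Dict[str, float]:
    """Percentage of examples whose highest-scoring class equals the label, over all batches."""
    correct = 0
    total = 0
    for batch_logits, batch_targets in zip(predictions, targets):
        for logits, label in zip(batch_logits, batch_targets):
            predicted = max(range(len(logits)), key=lambda k: logits[k])  # argmax
            correct += int(predicted == label)
            total += 1
    return {ACCURACY: 100.0 * correct / total}


def is_better(eval_metrics: Dict[str, float], best_metrics: Optional[Dict[str, float]]) -> bool:
    """True if there is no best result yet, or if the current accuracy is strictly higher."""
    if best_metrics is None:
        return True
    return eval_metrics.get(ACCURACY, float("-inf")) > best_metrics.get(ACCURACY, float("-inf"))


# Two batches of 3-class logits: 3 of the 4 predictions match their labels.
preds = [[[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]], [[0.1, 0.2, 3.0], [0.9, 0.1, 0.0]]]
labels = [[0, 1], [2, 1]]
scores = top1_accuracy(preds, labels)
print(scores)  # {'Accuracy': 75.0}
print(is_better(scores, {ACCURACY: 80.0}))  # False
```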

In [None]:
from typing import Any, Dict, List, Optional

import torch
from flsim.common.timeline import Timeline
from flsim.interfaces.metrics_reporter import Channel, TrainingStage
from flsim.metrics_reporter.tensorboard_metrics_reporter import FLMetricsReporter
from flsim.utils.fl.stats import (
    AverageType,
)

class MetricsReporter(FLMetricsReporter):
    ACCURACY = "Accuracy"

    def __init__(
        self,
        channels: List[Channel],
        target_eval: float = 0.0,
        window_size: int = 5,
        average_type: str = "sma",
        log_dir: Optional[str] = None,
    ):
        super().__init__(channels, log_dir)
        self.set_summary_writer(log_dir=log_dir)
        self._round_to_target = float(1e10)

    def compare_metrics(self, eval_metrics, best_metrics):
        current_accuracy = eval_metrics.get(self.ACCURACY, float("-inf"))
        if best_metrics is None:
            print(f"Current eval accuracy: {current_accuracy:.2f}%")
            return True

        best_accuracy = best_metrics.get(self.ACCURACY, float("-inf"))
        print(f"Current eval accuracy: {current_accuracy:.2f}%, best so far: {best_accuracy:.2f}%")
        return current_accuracy > best_accuracy

    def compute_scores(self) -> Dict[str, Any]:
        # compute accuracy
        correct = torch.Tensor([0])
        for i in range(len(self.predictions_list)):
            all_preds = self.predictions_list[i]
            pred = all_preds.data.max(1, keepdim=True)[1]

            assert pred.device == self.targets_list[i].device, (
                f"Pred and targets moved to different devices: "
                f"pred >> {pred.device} vs. targets >> {self.targets_list[i].device}"
            )
            if i == 0:
                correct = correct.to(pred.device)

            correct += pred.eq(self.targets_list[i].data.view_as(pred)).sum()

        # total number of examples across all batches
        total = sum(len(batch_targets) for batch_targets in self.targets_list)

        accuracy = 100.0 * correct.item() / total
        return {self.ACCURACY: accuracy}

    def create_eval_metrics(
        self, scores: Dict[str, Any], total_loss: float, **kwargs
    ) -> Any:
        timeline: Timeline = kwargs.get("timeline", Timeline(global_round=1))
        stage: TrainingStage = kwargs.get("stage", None)
        accuracy = scores[self.ACCURACY]
        return {
            self.ACCURACY: accuracy
        }

### 4. Hyperparameters

We represent the hyperparameters for FL training as a JSON config. In particular, we configure FedAvg with 5 users per round.

This config is passed to the FL trainer.
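For intuition about what the server does each round, the sketch below shows the textbook FedAvg aggregation step from McMahan et al. (2017): the new global parameters are the average of the selected clients' parameters, weighted by each client's number of training examples. This is a plain-Python illustration with a hypothetical `fedavg` helper operating on lists of floats, not FLSim's implementation, which may organize the computation differently (for example, by aggregating model updates and applying the configured server optimizer).

```python
from typing import Dict, List, Tuple

Params = Dict[str, List[float]]


def fedavg(client_updates: List[Tuple[Params, int]]) -> Params:
    """Textbook FedAvg aggregation: average client parameters, weighted by local example counts."""
    total_examples = sum(n for _, n in client_updates)
    aggregated: Params = {}
    for params, num_examples in client_updates:
        weight = num_examples / total_examples
        for name, values in params.items():
            acc = aggregated.setdefault(name, [0.0] * len(values))
            for i, v in enumerate(values):
                acc[i] += weight * v
    return aggregated


# Two clients: the one with 3x more data pulls the average toward its parameters.
new_global = fedavg([({"w": [1.0, 2.0]}, 300), ({"w": [5.0, 6.0]}, 100)])
print(new_global)  # {'w': [2.0, 3.0]}
```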

In [None]:
import flsim.configs
from flsim.utils.config_utils import fl_config_from_json
from omegaconf import OmegaConf


json_config = {
    "trainer": {
        "_base_": "base_sync_trainer",
        "server": {
            "_base_": "base_sync_server",
            # There are different types of aggregators (server optimizers):
            # FedAvg doesn't require a learning rate, while others such as
            # fed_avg_with_lr or fed_adam do.
            "server_optimizer": {
                "_base_": "base_fed_avg",
            },
            # type of user selection sampling
            "active_user_selector": {"_base_": "base_uniformly_random_active_user_selector"},
        },
        "client": {
            # number of local epochs on each client
            "epochs": 1,
            "optimizer": {
                "_base_": "base_optimizer_sgd",
                # client's local learning rate
                "lr": 1,
                # client's local momentum
                "momentum": 0.0,
            },
        },
        # number of users per round for aggregation
        "users_per_round": 5,
        # total number of global epochs
        # total #rounds = ceil(total_users / users_per_round) * epochs
        "epochs": 1,
        # frequency of reporting train metrics
        "train_metrics_reported_per_epoch": 100,
        # frequency of evaluation per epoch
        "eval_epoch_frequency": 1,
        "do_eval": True,
        # should we report train metrics after global aggregation
        "report_train_metrics_after_aggregation": True,
    }
}
cfg = fl_config_from_json(json_config)
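Applying the round-count formula from the config comments to this tutorial's setup shows how many aggregation rounds to expect. The function name `total_rounds` is hypothetical, and the 100 clients assume the 50,000-image training split sharded into 500 examples per client.

```python
import math


def total_rounds(num_users: int, users_per_round: int, epochs: int) -> int:
    # Formula from the config comment: ceil(total_users / users_per_round) * epochs.
    return math.ceil(num_users / users_per_round) * epochs


# 100 clients (50,000 images / 500 per client), 5 clients per round, 1 global epoch.
print(total_rounds(num_users=100, users_per_round=5, epochs=1))  # 20
```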

### 5. Training

Finally, we put everything together. To launch the FL training flow, we

1. Build the data provider,
2. Create an FL model,
3. Create a metrics reporter,
4. Instantiate the trainer,
5. Launch training,
6. Test the trained model.

In [None]:
# 1. Build the data provider.
data_provider = build_data_provider(
    local_batch_size=LOCAL_BATCH_SIZE,
    examples_per_user=EXAMPLES_PER_USER
)

In [None]:
# 2. We already defined the FL model earlier.
global_model

We can access the underlying `nn.Module` with `fl_get_module()`:

In [None]:
global_model.fl_get_module()

In [None]:
from flsim.interfaces.metrics_reporter import Channel

# 3. Create a metric reporter.
metrics_reporter = MetricsReporter([Channel.TENSORBOARD, Channel.STDOUT])

In [None]:
from hydra.utils import instantiate

# 4. Instantiate the trainer.
trainer_config = cfg.trainer
trainer = instantiate(trainer_config, model=global_model, cuda_enabled=cuda_enabled)   

We run FL training with the JSON config above and store the returned evaluation metrics in `eval_score`.

In [None]:
# 5. Launch FL training.
final_model, eval_score = trainer.train(
    data_provider=data_provider,
    metric_reporter=metrics_reporter,
    num_total_users=data_provider.num_users(),
    distributed_world_size=1
)

After training finishes, we evaluate the model and report the test set accuracy before finishing this tutorial.

In [None]:
# 6. We can now test our model.
trainer.test(
    data_iter=data_provider.test_data(),
    metric_reporter=MetricsReporter([Channel.STDOUT]),
)


## Summary

In this tutorial, we first showed how to get the data. We then built a data provider by sharding the data to simulate multiple client devices, each with its own data, and split each client's data into batches. For our model, we defined a simple CNN, wrapped it in a model class compatible with FL training, and moved it to the GPU when available. Lastly, we defined a metrics reporter and the hyperparameters, kicked off training, and evaluated the trained model on the test set.

### Additional resources

- Kairouz et al. (2021): [Advances and Open Problems in Federated Learning](https://arxiv.org/pdf/1912.04977.pdf). As the title suggests, an in-depth overview of advances and open problems in FL.

