# Federated Learning with SynapseMNIST3D Dataset
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/fl-tutorial/blob/gh-pages/tutorial_notebooks/Tutorial_2_Medmnist_3D_exercises.ipynb)

### Dependencies and Packages
Let's get these out of the way first.

In [None]:
!python -m pip install -U pip
!python -m pip install tqdm torch torchvision medmnist acsconv matplotlib

In [None]:
# To access example workspaces and director/envoy scripts
!rm -rf openfl
!git clone -b miccai_fl_tutorial https://github.com/intel/openfl.git
!cd openfl && python -m pip install .

### Hacks

A few duct-tape fixes so the notebook can run end to end with a single click.

In [None]:
import os
import logging

# Better CPU utilization: use half the cores, but never fewer than one
os.environ['OMP_NUM_THREADS'] = str(max(1, (os.cpu_count() or 2) // 2))

# Logging fix for Google Colab
log = logging.getLogger()
log.setLevel(logging.INFO)

# Switch to the workspace directory
tutorial_dir = os.path.abspath(
    'openfl/openfl-tutorials/interactive_api/PyTorch_MedMNIST_3D')
os.chdir(tutorial_dir)

### Imports

In [None]:
import numpy as np
from tqdm import tqdm
from pprint import pprint

import torch
import medmnist

print('PyTorch', torch.__version__)
print('MedMNIST', medmnist.__version__)

### Familiarize yourself with the Dataset

MedMNIST is a large-scale MNIST-like collection of standardized biomedical images, including 12 datasets for 2D and 6 datasets for 3D. MedMNIST is designed to perform classification on lightweight 2D and 3D images with various data scales (from 100 to 100,000) and diverse tasks (binary/multi-class, ordinal regression and multi-label).

![Datasets in MedMNIST](https://raw.githubusercontent.com/MedMNIST/MedMNIST/main/assets/medmnistv2.jpg)

Source: https://github.com/MedMNIST/MedMNIST

Jiancheng Yang, Rui Shi, Donglai Wei, Zequan Liu, Lin Zhao, Bilian Ke, Hanspeter Pfister, Bingbing Ni. "MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification". arXiv preprint arXiv:2110.14795, 2021.

### Dataset Configuration

In [None]:
# Train/test options
NUM_EPOCHS = 3
BATCH_SIZE = 16
DEVICE = 'cpu'

# Dataset
DATASET_NAME = 'synapsemnist3d'
DATASET_PATH = './data'
ds_info = medmnist.INFO[DATASET_NAME]
pprint(ds_info)

### Visualize a Sample

Let's download the raw volumes and plot one of them as a 3D voxel grid.

In [None]:
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader

from envoy.medmnist_shard_descriptor import MedMNISTShardDescriptor

# Download raw numpy dataset
sd = MedMNISTShardDescriptor(datapath=DATASET_PATH, dataname=DATASET_NAME)
(x_train, y_train), (x_test, y_test) = sd.load_data()

# Visualize a sample
sample_id = 42
label2str = list(ds_info['label'].values())
volume = x_train[sample_id]
label = label2str[np.squeeze(y_train[sample_id])]

# Plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
cmap = plt.get_cmap("gray")
norm = plt.Normalize(volume.min(), volume.max())
ax.voxels(volume, facecolors=cmap(norm(volume)))
plt.title(label)
plt.show()


### Define Dataset/Dataloader Classes

We'll create a simple PyTorch-style dataset that returns a single `numpy` element as a `torch.Tensor`. The class to subclass for this is `torch.utils.data.Dataset`.

We will then wrap this dataset object in a `torch.utils.data.DataLoader`, which batches and shuffles the elements.
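
As a rough illustration of the two classes named above, here is a minimal sketch that wraps in-memory `numpy` volumes in a `Dataset` and batches them with a `DataLoader`. The class name `VolumeDataset` and the random stand-in data are hypothetical; only the 28x28x28 volume shape matches the MedMNIST 3D datasets.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class VolumeDataset(Dataset):
    """Hypothetical map-style dataset over in-memory numpy volumes."""

    def __init__(self, volumes: np.ndarray, labels: np.ndarray):
        self.volumes = volumes
        self.labels = labels

    def __len__(self) -> int:
        return len(self.volumes)

    def __getitem__(self, index: int):
        # Add a leading channel axis: (D, H, W) -> (1, D, H, W)
        volume = torch.from_numpy(self.volumes[index]).float().unsqueeze(0)
        label = torch.as_tensor(self.labels[index]).long()
        return volume, label


# Random stand-in data (an assumption, not the real dataset)
rng = np.random.default_rng(0)
fake_x = rng.random((10, 28, 28, 28), dtype=np.float32)
fake_y = rng.integers(0, 2, size=(10, 1))

# The DataLoader shuffles the indices and stacks samples into batches
loader = DataLoader(VolumeDataset(fake_x, fake_y), batch_size=4, shuffle=True)
x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)
```

With 10 samples and `batch_size=4`, the loader yields three batches, the last one holding the remaining two samples.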

### Model Definition (3D CNN)

In [None]:
import torch.nn as nn
import torch.nn.functional as F


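class Net(nn.Module):
    """Small CNN whose layers are declared in 2D.

    The model is converted to 3D later with `Conv3dConverter`,
    which is why `fc` expects 64 * 4 * 4 * 4 input features.
    """

    def __init__(self, in_channels, num_classes):
        super(Net, self).__init__()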

        self.layer1 = nn.Sequential(nn.Conv2d(in_channels, 16, kernel_size=3),
                                    nn.BatchNorm2d(16), nn.ReLU())

        self.layer2 = nn.Sequential(nn.Conv2d(16, 16, kernel_size=3),
                                    nn.BatchNorm2d(16), nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))

        self.layer3 = nn.Sequential(nn.Conv2d(16, 64, kernel_size=3),
                                    nn.BatchNorm2d(64), nn.ReLU())

        self.layer4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3),
                                    nn.BatchNorm2d(64), nn.ReLU())

        self.layer5 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),
                                    nn.BatchNorm2d(64), nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))

        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4 * 4, 128), 
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(128, num_classes))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
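
The `64 * 4 * 4 * 4` input size of `fc` can be checked with a little arithmetic. The sketch below assumes a 28x28x28 input volume (the MedMNIST 3D shape) and that, after the conversion to 3D later in this notebook, every layer applies the same kernel, stride and padding along all three axes. The helper `conv_out` is a hypothetical name introduced for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one axis for a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # assumed input length per axis
size = conv_out(size, kernel=3)             # layer1 conv          -> 26
size = conv_out(size, kernel=3)             # layer2 conv          -> 24
size = conv_out(size, kernel=2, stride=2)   # layer2 max-pool      -> 12
size = conv_out(size, kernel=3)             # layer3 conv          -> 10
size = conv_out(size, kernel=3)             # layer4 conv          -> 8
size = conv_out(size, kernel=3, padding=1)  # layer5 conv (padded) -> 8
size = conv_out(size, kernel=2, stride=2)   # layer5 max-pool      -> 4

flattened = 64 * size ** 3                  # 64 channels x 4 x 4 x 4
print(size, flattened)                      # 4 4096
```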

# Go Federated
### Imports

In [None]:
import os
import time
import yaml
from typing import Dict, List, Union

### Configure `Director`

This is the entity that orchestrates the tasks and the aggregation of models from participants. The cells below configure the `yaml` and start the `Director` service.

In [None]:
# Should be the same as defined in `director_config.yaml`
director_node_fqdn = 'localhost'
director_port = 50051

director_workspace_path = os.path.join(tutorial_dir, 'director')
director_config_file = os.path.join(director_workspace_path, 'director_config.yaml')
director_logfile = os.path.join(director_workspace_path, 'director.log')

# Start director in the background; send stdout and stderr to the logfile
os.system(f'cd {director_workspace_path} && '
          f'fx director start --disable-tls -c {director_config_file} '
          f'>{director_logfile} 2>&1 &')
!sleep 5 && tail -n5 $director_logfile
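
A fixed `sleep 5` can be too short on a slow machine. One possible alternative, sketched here with the standard library only, is to poll the director's port until it accepts TCP connections. The helper name `wait_for_port` and its timeout values are assumptions made for this illustration, not part of OpenFL.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 30.0) -> bool:
    """Return True once `host:port` accepts a TCP connection, else False."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connect means something is listening on the port
            with socket.create_connection((host, port), timeout=1.0):
                return True
        except OSError:
            time.sleep(0.5)  # not up yet, retry shortly
    return False


# Example call for this notebook (left commented out on purpose):
# wait_for_port(director_node_fqdn, director_port)
```

An open port only shows that the process is listening, so tailing the log afterwards is still a useful sanity check.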

### Configure `Envoys`

For the sake of simplicity, an `Envoy` can be thought of as a collaborator. Technically, an `Envoy` defines the data-loading interface for each participant and runs the Python code (called a `task`) that it receives from this notebook.

We create one config file for each participant that we intend to simulate.

In [None]:
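def generate_envoy_configs(
        config: dict,
        n_cols: int,
        datapath: str,
        dataname: str,
        save_path: str) -> list:
    """Writes one envoy config per collaborator and returns the file paths.

    `config` is updated in place; collaborator `i` of `n_cols` gets
    `rank_worldsize` set to 'i,n_cols'.
    """
    config_paths = list()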
    for i in range(1, n_cols+1):
        path = os.path.abspath(os.path.join(save_path, f'{i}_envoy_config.yaml'))
        config['shard_descriptor']['params']['datapath'] = datapath
        config['shard_descriptor']['params']['dataname'] = dataname    
        config['shard_descriptor']['params']['rank_worldsize'] = f'{i},{n_cols}'
        with open(path, 'w') as f:
            yaml.safe_dump(config, f)
        config_paths.append(path)
    return config_paths
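
The `rank_worldsize` string is what lets each envoy pick its own slice of the data. The actual split is implemented in `envoy/medmnist_shard_descriptor.py` and may differ; the sketch below only assumes a simple strided split, and `shard_indices` is a hypothetical helper written for this illustration.

```python
def shard_indices(n_samples: int, rank_worldsize: str) -> list:
    """Return the sample indices owned by one collaborator.

    Assumes a strided split: collaborator `rank` (1-based) of `worldsize`
    takes every `worldsize`-th sample, starting at index `rank - 1`.
    """
    rank, worldsize = (int(part) for part in rank_worldsize.split(','))
    return list(range(n_samples))[rank - 1::worldsize]


print(shard_indices(10, '1,2'))  # [0, 2, 4, 6, 8]
print(shard_indices(10, '2,2'))  # [1, 3, 5, 7, 9]
```

Under this assumption the shards are disjoint and together cover every sample exactly once.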

### Generate configs and start `Envoys`

In [None]:
# Read the original envoy config file content
original_config_path = os.path.join(tutorial_dir, 'envoy', 'envoy_config.yaml')
with open(original_config_path, 'r') as f:
    original_config = yaml.safe_load(f)

# Generate one config per envoy
config_paths = generate_envoy_configs(original_config,
                                      n_cols=2,
                                      datapath=DATASET_PATH,
                                      dataname=DATASET_NAME,
                                      save_path=os.path.dirname(original_config_path))
# Start envoys in a loop
cwd = os.getcwd()
for i, path in enumerate(config_paths):
    print(f'Starting Envoy {i+1}')
    os.chdir(os.path.dirname(path))

    # Start the envoy in the background, logging to env_<i>.log
    os.system(f'fx envoy start -n env_{i+1} --disable-tls '
              f'--envoy-config-path {path} -dh {director_node_fqdn} -dp {director_port} '
              f'>env_{i+1}.log 2>&1 &')

    # Wait until envoy loads dataset
    !grep -q "MedMNIST data was loaded" <( tail -f env_{i+1}.log )

    os.chdir(cwd)

### Connect this Notebook to the Infrastructure

This is where you take the seat of a Data Scientist, who controls the `model`, `train()`, `validate()` and other logic that the `Director` and `Envoys` help you execute across participants.

In [None]:
from openfl.interface.interactive_api.federation import Federation

## Fill: Create a `Federation` that connects to the director started above (TLS is disabled)
federation = None

# Wait till all envoys publish their shard registry.
pprint(federation.get_shard_registry())

### Ingredients of a Federated Learning Experiment in OpenFL

* `DataInterface`: This class defines the dataloading primitives for OpenFL. We'll reuse some of our previous logic.
* `ModelInterface`: Registers model graph and optimizer; serializes them and sends them to collaborator nodes.
* `TaskInterface`: Registers the Python functions that make up each task, such as `training` and `validation`.

In [None]:
from openfl.interface.interactive_api.experiment import DataInterface

# A fix to access the following module
os.chdir(os.path.join(tutorial_dir, 'workspace'))
from wspace_utils.utils import Transform3D

class TransformDataset(Dataset):
    """Applies transforms to each element of the Dataset"""

    def __init__(self, dataset, transform=None, target_transform=None):
        """Initializes Dataset"""
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Returns length of the dataset"""
        return len(self.dataset)

    def __getitem__(self, index):
        """Returns img, label by index, with transforms if any"""
        img, label = self.dataset[index]
        img = np.stack([img/255.], axis=0)
        
        if self.target_transform:
            label = self.target_transform(label)
        
        if self.transform:
            img = self.transform(img)

        return img, label

# Transforms
shape_transform = False
train_transform = Transform3D(mul='random') if shape_transform else Transform3D()
eval_transform = Transform3D(mul='0.5') if shape_transform else Transform3D()

class MedMnistFedDataset(DataInterface):

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor  will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        self.train_set = TransformDataset(
            self._shard_descriptor.get_dataset('train'),
            transform=train_transform)

        self.valid_set = TransformDataset(
            self._shard_descriptor.get_dataset('val'),
            transform=eval_transform)

    def get_train_loader(self, **kwargs):
        """Output of this method will be provided to tasks with optimizer in contract"""
        return DataLoader(self.train_set,
                          num_workers=8,
                          batch_size=self.kwargs['train_bs'],
                          shuffle=True)

    def get_valid_loader(self, **kwargs):
        """Output of this method will be provided to tasks without optimizer in contract"""
        return DataLoader(self.valid_set,
                          num_workers=8,
                          batch_size=self.kwargs['valid_bs'])

    def get_train_data_size(self):
        """Information for aggregation"""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Information for aggregation"""
        return len(self.valid_set)


In [None]:
fed_dataset = MedMnistFedDataset(train_bs=BATCH_SIZE, valid_bs=BATCH_SIZE)
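
The two `get_*_data_size` methods are documented as "Information for aggregation". A common way such sizes are used is a FedAvg-style weighted mean, where collaborators with more samples contribute more to the global model. The sketch below assumes that scheme and uses plain Python lists as stand-ins for model parameters; it is an illustration, not OpenFL's aggregation code, and `weighted_average` is a hypothetical helper.

```python
def weighted_average(updates: list, sizes: list) -> list:
    """Average per-collaborator parameter vectors, weighted by sample count."""
    total = sum(sizes)
    n_params = len(updates[0])
    return [
        sum(update[i] * size for update, size in zip(updates, sizes)) / total
        for i in range(n_params)
    ]


# Two hypothetical collaborators holding 300 and 100 training samples
params_a = [1.0, 2.0]
params_b = [5.0, 6.0]
print(weighted_average([params_a, params_b], sizes=[300, 100]))  # [2.0, 3.0]
```

The result sits closer to `params_a` because that collaborator holds three times as much data.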

### `ModelInterface`

In [None]:
from openfl.interface.interactive_api.experiment import ModelInterface

from acsconv.converters import Conv3dConverter
from wspace_utils.utils import model_to_syncbn

def get_3d_cnn():
    ## Fill: Instantiate a model
    model = None

    ## Fill: Convert model to 3D using Conv3dConverter
    model = None

    ## Fill: Convert all BatchNorm layers to SyncBN layers
    model = None
    return model

model = get_3d_cnn()
optimizer = None  ## Fill: Instantiate an `Adam` Optimizer
criterion = None  ## Fill: Instantiate a CrossEntropyLoss function

framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model.model,
                    optimizer=optimizer,
                    framework_plugin=framework_adapter)


### `TaskInterface`
We register our tasks with a `TaskInterface` object.
OpenFL decides which model is the best based on an *increasing* metric, that is, it treats a higher metric value as better.
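
To see why the direction of the metric matters, consider this small sketch. It is purely illustrative and not OpenFL's internal logic: `best_round` is a hypothetical helper that picks the round with the highest score, and the per-round numbers are made up.

```python
def best_round(scores: list) -> int:
    """Index of the round with the highest score (higher is better)."""
    best_index, best_score = 0, scores[0]
    for index, score in enumerate(scores):
        if score > best_score:
            best_index, best_score = index, score
    return best_index


val_acc = [0.61, 0.70, 0.68]   # hypothetical per-round accuracy
val_loss = [0.66, 0.58, 0.60]  # hypothetical per-round loss

print(best_round(val_acc))                        # 1
print(best_round(val_loss))                       # 0, the worst round
print(best_round([-loss for loss in val_loss]))   # 1
```

Accuracy works as is, while a raw loss would select the worst round; negating the loss restores the intended ordering.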

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface

# Task interface currently supports only standalone functions.
TI = TaskInterface()
extra_args = {'criterion': criterion}

# Train task
@TI.add_kwargs(**extra_args)
@TI.register_fl_task(model='model', data_loader='data_loader', device='device', optimizer='optimizer')
def train(model: nn.Module, 
          data_loader: torch.utils.data.DataLoader,
          device: str,
          optimizer: torch.optim.Optimizer, 
          criterion: nn.Module) -> dict:
    """Trains `model` for 1 epoch on `data_loader`

    Args:
        model (nn.Module): PyTorch Model.
        data_loader (torch.utils.data.DataLoader): Training Dataloader.
        device (str): 'cpu' or 'cuda'
        optimizer (torch.optim.Optimizer): Optimizer instance.
        criterion (nn.Module): Loss function instance.

    Returns:
        dict: `train_acc` and `train_loss` metrics over the dataloader
    """
    ## Set the model to training mode
    

    ## Initialize counters for accuracy/loss
    losses = []
    correct = 0
    total = 0

    ## Create a `for` loop that iterates over the dataloader
    for x, y in tqdm(data_loader, desc='training'):
        ## Push `x` and `y` tensors to the device
        
        ## Squeeze and Convert `y` to a `long` tensor

        ## Clear optimizer gradients
        
        ## Forward pass `x` through the model to get `preds`
        preds = None
        
        ## Calculate `loss` using the criterion
        loss = None

        ## Backpropagate `loss` to compute gradients
        
        ## Apply gradients with step
        
        ## Record metrics
        losses.append(loss.item())
        correct += torch.sum(preds.max(1)[1] == y).item()
        total += y.size(0)

    #############################################################
    return {
        'train_acc': np.round(correct/total, 3),
        'train_loss': np.round(np.mean(losses), 3),
    }

@TI.add_kwargs(**extra_args)
@TI.register_fl_task(model='model', data_loader='data_loader', device='device')
def validate(model: nn.Module, 
             data_loader: torch.utils.data.DataLoader,
             device: str,
             criterion: nn.Module) -> dict:
    """Computes `acc` and `loss` of the `model` on `data_loader`

    Args:
        model (nn.Module): PyTorch Model.
        data_loader (torch.utils.data.DataLoader): Validation Dataloader.
        device (str): 'cpu' or 'cuda'
        criterion (nn.Module): Loss function instance.

    Returns:
        dict: `val_acc` and `val_loss` metrics over the dataloader
    """
    ## Set the model to evaluation mode


    ## Initialize counters for accuracy/loss
    losses = []
    correct = 0
    total = 0

    ## Define a scope that disables gradient calculation
    with torch.no_grad():
        ## Create a `for` loop that iterates over the dataloader
        for x, y in tqdm(data_loader, desc='validating'):
            ## Push `x` and `y` tensors to the device
            
            ## Squeeze and Convert `y` to a `long` tensor
            
            ## Forward pass `x` through the model to get `preds`
            preds = None
            
            ## Calculate `loss` using the criterion
            loss = None
            
            ## Record metrics
            losses.append(loss.item())
            correct += torch.sum(preds.max(1)[1] == y).item()
            total += y.size(0)

    #############################################################

    return {
        'val_acc': np.round(correct/total, 3),
        'val_loss': np.round(np.mean(losses), 3),
    }

### Run the Experiment

In [None]:
from openfl.interface.interactive_api.experiment import FLExperiment

## Create a unique FL Experiment under this infrastructure:{Notebook, Director, Envoy}
fl_experiment = None
fl_experiment.start(model_provider=MI,
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=3,
                    device_assignment_policy='CUDA_PREFERRED')

# This method streams logs from the director, and also saves logs in the tensorboard format (by default)
fl_experiment.stream_metrics()

### Cleanup

In [None]:
# Stop all services and remove the generated envoy configs
!pkill fx
for path in config_paths:
    os.remove(path)