# Federated PyTorch 3dUNET Tutorial 

We will use the MONAI BraTS [tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/brats_segmentation_3d.ipynb) as a template.

In [1]:
# Install dependencies if not already installed.
# Each line first tries to import the package and only runs `pip install`
# when that import fails.
!python -c "import torch" || pip install torch
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
import openfl.native as fx

# Initialize OpenFL with the 'torch_3dunet_brats' workspace template:
# set up the default workspace, logging, etc., and install the additional requirements.
fx.init('torch_3dunet_brats')

Creating Workspace Directories
Creating Workspace Templates


  return torch._C._cuda_getDeviceCount() > 0


Successfully installed packages from /home/maksim/.local/workspace/requirements.txt.

New workspace directory structure:
workspace
├── agg_to_col_one_signed_cert.zip
├── final_model
│   ├── assets
│   ├── saved_model.pb
│   └── variables
│       ├── variables.data-00000-of-00001
│       └── variables.index
├── code
│   ├── data_loader.py
│   ├── pt_cnn.py
│   ├── fed_unet_runner.py
│   ├── mnist_utils.py
│   ├── keras_cnn.py
│   ├── ptmnist_inmemory.py
│   ├── tfmnist_inmemory.py
│   ├── pt_unet_parts.py
│   ├── __init__.py
│   └── fed_3dunet_runner.py
├── plan
│   ├── data.yaml
│   ├── plan.yaml
│   ├── cols.yaml
│   └── defaults
│       ├── .tasks_torch.yaml.swp
│       ├── aggregator.yaml
│       ├── network.yaml
│       ├── assigner.yaml
│       ├── collaborator.yaml
│       ├── tasks_torch.yaml
│       ├── tasks_tensorflow.yaml
│       ├── tasks_keras.yaml
│       ├── tasks_fast_estimator.yaml
│       ├── data_loader.yaml
│       ├── task_runner.yaml
│       └── defaults
├── logs


In [3]:
import os
import json
from tqdm import tqdm
from hashlib import sha384
from os import path
import torch
import matplotlib.pyplot as plt
import numpy as np

from monai.data import DataLoader
from monai.data import CacheDataset
from monai.data import load_decathlon_datalist
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AsDiscrete,
    CenterSpatialCropd,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    ToTensord,
)

from openfl.federated import FederatedModel, FederatedDataSet
from openfl.utilities import TensorKey

Download BraTS dataset

In [4]:
# !wget -c --tries=0 --retry-connrefused --timeout=2 --wait=1  --continue "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task01_BrainTumour.tar" -O brats.tar
# TAR_SHA384 = '049f8e1425d9e47a4cdabe03c5c2ff68aa01b6298a307'\
#     '304638abd9b1341f0639d015357ca315d402984bc1cffa16bbf'
# assert sha384(open('./brats.tar', 'rb').read(
#     path.getsize('./brats.tar'))).hexdigest() == TAR_SHA384
# !tar -xvf brats.tar -C ./data
# !rm ./data/Task01_BrainTumour/imagesTr/.*.nii.gz

Prepare the preprocessing transforms (copied from the MONAI [tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/brats_segmentation_3d.ipynb)):

In [5]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Turn a BraTS label map into three stacked binary channels.

    Source labels:
        1 - peritumoral edema
        2 - GD-enhancing tumor
        3 - necrotic and non-enhancing tumor core

    Output channels, in this order: TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor). Each channel is a float32 mask.
    """

    def __call__(self, data):
        # Work on a shallow copy so the caller's mapping is not modified.
        converted = dict(data)
        for key in self.keys:
            labels = converted[key]
            # TC: label 2 merged with label 3
            tumor_core = np.logical_or(labels == 2, labels == 3)
            # WT: labels 1, 2 and 3 merged
            whole_tumor = np.logical_or(tumor_core, labels == 1)
            # ET: label 2 only
            enhancing_tumor = labels == 2
            channels = [tumor_core, whole_tumor, enhancing_tumor]
            converted[key] = np.stack(channels, axis=0).astype(np.float32)
        return converted


# Training pipeline: loading, resampling, random crop/flip and intensity augmentation.
train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        # expand the label map into the three TC / WT / ET channels
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # one interpolation mode per key: "bilinear" for the image, "nearest" for the label
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # spatial augmentations list both keys, so image and label are handled together
        RandSpatialCropd(
            keys=["image", "label"], roi_size=[128, 128, 64], random_size=False
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        # intensity transforms list only "image" in `keys`; the label is not passed to them
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ]
)

# Validation pipeline: same loading/resampling/normalization steps as the training one,
# but with a center crop and without the random flip / intensity augmentations.
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        # expand the label map into the three TC / WT / ET channels
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # fixed crop with the same roi_size as the training crop
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ]
)

Let's define a dataset that contains the brain tumor images for one collaborator (train or validation).

In [6]:
class BraTSDataset(CacheDataset):
    """
    Cached dataset of 3D brain tumor samples for one collaborator (train or val).

    Args:
        data_list: items forwarded to ``CacheDataset`` as its data.
        transform: transform sequence forwarded to ``CacheDataset``.

    The parent is always created with ``cache_num=1`` and ``num_workers=4``.
    """

    def __init__(self, data_list, transform):
        super().__init__(data_list, transform, cache_num=1, num_workers=4)

    def __getitem__(self, index):
        """Return only the ``(image, label)`` pair of the sample at ``index``."""
        sample = super().__getitem__(index)
        return sample['image'], sample['label']

Here we redefine the `FederatedDataSet` methods, since we do not want to use the default batch generator from `FederatedDataSet`.

In [7]:
class FederatedDataSetWrapper(FederatedDataSet):
    """
    ``FederatedDataSet`` base whose sample lists come from overridable hooks.

    Subclasses implement ``generate_train_list`` and ``generate_val_list``.
    Both hooks receive exactly the positional and keyword arguments given to
    this constructor; their results are stored in ``self.train_list`` and
    ``self.val_list``. The parent is then initialised with four empty lists,
    a batch size of 1 and the same keyword arguments.
    """

    def __init__(self, *args, **kwargs):
        # The lists are built before the parent constructor runs.
        self.train_list = self.generate_train_list(*args, **kwargs)
        self.val_list = self.generate_val_list(*args, **kwargs)
        super().__init__([], [], [], [], 1, **kwargs)

    def generate_train_list(self, *args, **kwargs):
        """Return the training items; must be overridden by subclasses."""
        raise NotImplementedError

    def generate_val_list(self, *args, **kwargs):
        """Return the validation items; must be overridden by subclasses."""
        raise NotImplementedError

In [18]:
class BraTSFederatedDataset(FederatedDataSetWrapper):
    """BraTS training/validation data for one collaborator of the federation.

    The decathlon "training" list is sharded round-robin between collaborators
    (``data[collaborator_num::collaborator_count]``); the last eighth of each
    shard is used for validation and the rest for training.
    """

    # Directory that contains the extracted Task01_BrainTumour data (with dataset.json).
    DATASET_DIR = './data/Task01_BrainTumour/'
    # Upper bounds on the number of samples taken from each collaborator's shard.
    MAX_TRAIN_SAMPLES = 11
    MAX_VAL_SAMPLES = 5

    def __init__(self, collaborator_count=1, collaborator_num=0, batch_size=1, **kwargs):
        """Instantiate the federated data object
        Args:
            collaborator_count: total number of collaborators
            collaborator_num: number of current collaborator
            batch_size:  the batch size of the data loader
            **kwargs: additional arguments, passed to super init
        """
        collaborator_count = int(collaborator_count)
        collaborator_num = int(collaborator_num)

        # The parent constructor hands its arguments to generate_train_list /
        # generate_val_list, which require the shard coordinates. Without
        # forwarding them here those calls fail with a TypeError.
        super().__init__(collaborator_count, collaborator_num, num_classes=2, **kwargs)

        self.collaborator_num = collaborator_num

        self.batch_size = batch_size

        self.training_set = BraTSDataset(self.train_list, transform=train_transform)
        self.valid_set = BraTSDataset(self.val_list, transform=val_transform)

        self.train_loader = self.get_train_loader()
        self.val_loader = self.get_valid_loader()

    def generate_name_list(self, collaborator_count, collaborator_num, is_validation):
        """Return the train or validation items of one collaborator's shard.

        Args:
            collaborator_count: total number of collaborators (slice step)
            collaborator_num: index of the current collaborator (slice start)
            is_validation: if true return the last eighth of the shard,
                otherwise everything before it
        """
        data = load_decathlon_datalist(
            os.path.join(self.DATASET_DIR, "dataset.json"), True, "training")
        # split all data for current collaborator
        data = data[collaborator_num::collaborator_count]
        # more than 8 items guarantees a non-empty validation part below
        assert len(data) > 8, "collaborator shard must contain more than 8 items"
        validation_size = len(data) // 8
        if is_validation:
            return data[-validation_size:]
        return data[:-validation_size]

    # Overrides of the FederatedDataSetWrapper hooks -----------------------------------
    def generate_train_list(self, collaborator_count, collaborator_num, *args, **kwargs):
        """Return at most ``MAX_TRAIN_SAMPLES`` training items of the shard."""
        names = self.generate_name_list(collaborator_count, collaborator_num, False)
        return names[:self.MAX_TRAIN_SAMPLES]

    def generate_val_list(self, collaborator_count, collaborator_num, *args, **kwargs):
        """Return at most ``MAX_VAL_SAMPLES`` validation items of the shard."""
        names = self.generate_name_list(collaborator_count, collaborator_num, True)
        return names[:self.MAX_VAL_SAMPLES]
    # -----------------------------------------------------------------------------------

    def get_valid_loader(self, num_batches=None):
        """Return a loader over the validation set (``num_batches`` is not used)."""
        return DataLoader(self.valid_set, num_workers=2, batch_size=self.batch_size)

    def get_train_loader(self, num_batches=None):
        """Return a shuffling loader over the training set (``num_batches`` is not used)."""
        return DataLoader(
            self.training_set, num_workers=2, batch_size=self.batch_size, shuffle=True
        )

    def get_train_data_size(self):
        """Return the number of training samples."""
        return len(self.training_set)

    def get_valid_data_size(self):
        """Return the number of validation samples."""
        return len(self.valid_set)

    def get_feature_shape(self):
        """Return the shape of the first validation image."""
        return self.valid_set[0][0].shape

    def split(self, collaborator_count, shuffle=True, equally=True):
        """Create one dataset per collaborator, each with this object's batch size.

        ``shuffle`` and ``equally`` are accepted but not used here.
        """
        return [
            BraTSFederatedDataset(collaborator_count,
                                  collaborator_num, self.batch_size)
            for collaborator_num in range(collaborator_count)
        ]

Our U-Net model is based on the MONAI `UNet`. We define our own validation function in order to use a custom metric (Dice).

In [19]:
class UnetWrapper(UNet):
    """MONAI 3D ``UNet`` (4 input channels, 3 output channels) with a Dice-based ``validate`` task."""

    def __init__(self):
        super().__init__(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,)

    def validate(
        self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
    ):
        """Validate. Redefines the function from PyTorchTaskRunner to use our own validation.

        NOTE(review): ``rebuild_model``, ``data_loader`` and ``device`` are not
        defined in this class; presumably the OpenFL task runner provides them -
        confirm against ``FederatedModel``.

        Args:
            col_name: collaborator name, used as the origin of the result key
            round_num: current round number
            input_tensor_dict: tensors passed to ``rebuild_model``
            use_tqdm: wrap the validation loader in a progress bar
            **kwargs: must contain ``apply``; the value ``"local"`` selects the
                ``validate_local`` tag, any other value ``validate_agg``

        Returns:
            A tuple of (dict mapping the ``dice_coef`` TensorKey to the mean
            Dice value as a numpy array, empty dict).

        Raises:
            KeyError: if ``apply`` is missing from ``kwargs``.
            ZeroDivisionError: if no batch produced a non-NaN Dice value.
        """
        self.rebuild_model(round_num, input_tensor_dict, validation=True)
        loader = self.data_loader.get_valid_loader()
        if use_tqdm:
            # The notebook imports `from tqdm import tqdm`, so `tqdm` is the
            # progress-bar callable itself; `tqdm.tqdm` would raise AttributeError.
            loader = tqdm(loader, desc="validate")
# -------------Usual validation code---------------------------------------------------------------------------
        self.eval()
        self.to(self.device)
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        post_trans = Compose(
            [Activations(sigmoid=True), AsDiscrete(threshold_values=True)]
        )
        metric_sum = 0.0
        metric_count = 0

        with torch.no_grad():
            for val_inputs, val_labels in loader:
                val_inputs = val_inputs.to(self.device)
                val_labels = val_labels.to(self.device)
                val_outputs = self(val_inputs)
                val_outputs = post_trans(val_outputs)
                # compute overall mean dice, weighting each batch by its non-NaN count
                value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                not_nans = not_nans.item()
                metric_count += not_nans
                metric_sum += value.item() * not_nans

        metric = metric_sum / metric_count
# --------------------------------------------------------------------------

        origin = col_name
        suffix = "validate"
        if kwargs["apply"] == "local":
            suffix += "_local"
        else:
            suffix += "_agg"
        tags = ("metric", suffix)
        output_tensor_dict = {
            TensorKey("dice_coef", origin, round_num, True, tags): np.array(
                metric
            )
        }
        return output_tensor_dict, {}

Loss function and optimizer which will be passed to `FederatedModel`

In [20]:
# Adapter: the training loop passes (output, target) to the loss, while
# DiceLoss.forward takes its first argument under the name `input`.

class DiceLossHeir(DiceLoss):
    """``DiceLoss`` that accepts the prediction under the name ``output``."""

    __name__ = 'DiceLoss'

    def forward(self, output, target):
        """Delegate to ``DiceLoss.forward`` with ``output`` passed as ``input``."""
        return DiceLoss.forward(self, input=output, target=target)


# Loss instance handed to FederatedModel: sigmoid is applied to the predictions,
# targets are not one-hot encoded again, and squared predictions are used.
loss_function = DiceLossHeir(
    to_onehot_y=False, sigmoid=True, squared_pred=True)


def optimizer(x):
    """Build the Adam optimizer (lr 1e-4, weight decay 1e-5, AMSGrad) for parameters ``x``."""
    return torch.optim.Adam(x, lr=1e-4, weight_decay=1e-5, amsgrad=True)

Create a `BraTSFederatedDataset`. The federated datasets for the individual collaborators will be created by the `split()` method of this object.

In [21]:
# Root dataset with the default collaborator_count=1 and collaborator_num=0;
# the batch size of 6 is reused by the per-collaborator datasets created in split().
fl_data = BraTSFederatedDataset(batch_size=6)

100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
100%|██████████| 1/1 [00:01<00:00,  1.27s/it]


The `FederatedModel` object is a wrapper around your Keras, Tensorflow or PyTorch model that makes it compatible with OpenFL. It provides built-in federated training function which will be used while training. Using its `setup` function, collaborator models and datasets can be automatically obtained for the experiment.

In [22]:
# Wrap the model class, optimizer factory, loss and data loader in a FederatedModel.
fl_model = FederatedModel(build_model=UnetWrapper, optimizer=optimizer,
                          loss_fn=loss_function, data_loader=fl_data)

In [23]:
# Create one model per collaborator and give each a name for the experiment.
collaborator_models = fl_model.setup(num_collaborators=2)
collaborators = {'one': collaborator_models[0], 'two': collaborator_models[1]}

100%|██████████| 1/1 [00:01<00:00,  1.15s/it]
100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
100%|██████████| 1/1 [00:01<00:00,  1.16s/it]
100%|██████████| 1/1 [00:01<00:00,  1.17s/it]


We can see the current FL plan values by running the `fx.get_plan()` function

In [24]:
# Get the current values of the FL plan and print them as sorted, indented JSON.
# Each of these can be overridden.
print(json.dumps(fx.get_plan(), indent=4, sort_keys=True))

{
    "aggregator.settings.best_state_path": "save/torch_3dunet_brats_best.pbuf",
    "aggregator.settings.db_store_rounds": 1,
    "aggregator.settings.init_state_path": "save/torch_3dunet_brats_init.pbuf",
    "aggregator.settings.last_state_path": "save/torch_3dunet_brats_last.pbuf",
    "aggregator.settings.rounds_to_train": 1,
    "aggregator.template": "openfl.component.Aggregator",
    "assigner.settings.task_groups": [
        {
            "name": "train_and_validate",
            "percentage": 1.0,
            "tasks": [
                "aggregated_model_validation",
                "train",
                "locally_tuned_model_validation"
            ]
        }
    ],
    "assigner.template": "openfl.component.RandomGroupedAssigner",
    "collaborator.settings.db_store_rounds": 1,
    "collaborator.settings.delta_updates": false,
    "collaborator.settings.opt_treatment": "RESET",
    "collaborator.template": "openfl.component.Collaborator",
    "data_loader.settings.batch_

Above you can see the common plan with all of its options. Let's concentrate on the options directly related to the federation.

`aggregator.settings.db_store_rounds` - rounds to store model weights.

`aggregator.settings.rounds_to_train` - number of training rounds.

`collaborator.settings.delta_updates` - send only the model delta (or the full model if false).

`collaborator.settings.opt_treatment` - The optimizer state treatment:
    """
    RESET tells each collaborator to reset the optimizer state at the beginning
    of each round.
    CONTINUE_LOCAL tells each collaborator to continue with the local optimizer
    state from the previous round.
    CONTINUE_GLOBAL tells each collaborator to continue with the federally
    averaged optimizer state from the previous round.
    """
    
`tasks.aggregated_model_validation.aggregation_type` - the aggregation function for this task. It may be the name of a NumPy function. The default is "weighted_average", which corresponds to `np.average` with weights.

We can override these options by defining an override config and passing it to the `run_experiment` function:

In [25]:
# Plan overrides passed to fx.run_experiment.
# Keys are the flattened option names printed by fx.get_plan().
override_config = {
    'aggregator.settings.db_store_rounds': 1,
    'aggregator.settings.rounds_to_train': 1,
    'collaborator.settings.delta_updates': False,
    'collaborator.settings.opt_treatment': "RESET",
    'tasks.aggregated_model_validation.aggregation_type': ["weighted_average"]
}

In [26]:
# Run the experiment with the named collaborators and the plan overrides;
# returns the trained FederatedModel.
final_fl_model = fx.run_experiment(
    collaborators, override_config=override_config)

  data, target = pt.tensor(data).to(self.device), pt.tensor(
  target).to(self.device, dtype=pt.float32)


<openfl.federated.task.fl_model.FederatedModel at 0x7f430d05ffd0>

Let's validate the final model on the combined validation data of all collaborators.

In [27]:
# Evaluate the final federated model on the validation loaders of all collaborators.
model = final_fl_model.model
model.eval()
device = final_fl_model.runner.device
model.to(device)
dice_metric = DiceMetric(include_background=True, reduction="mean")
post_trans = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold_values=True)]
)
metric_sum = 0.0
metric_count = 0
with torch.no_grad():
    for collaborator in collaborator_models:
        loader = collaborator.runner.data_loader.get_valid_loader()
        for val_inputs, val_labels in tqdm(loader):
            print(val_inputs.shape)
            val_inputs = val_inputs.to(device)
            val_labels = val_labels.to(device)
            val_outputs = model(val_inputs)
            val_outputs = post_trans(val_outputs)
            # weight each batch's mean Dice by its number of non-NaN values
            value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
            not_nans = not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans

metric = metric_sum / metric_count
# Report the result; previously the value was computed but never shown.
print(f"Mean Dice over all collaborators' validation data: {metric:.4f}")

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([5, 4, 128, 128, 64])


100%|██████████| 1/1 [00:05<00:00,  5.91s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([5, 4, 128, 128, 64])


100%|██████████| 1/1 [00:06<00:00,  6.36s/it]


In [38]:
#Checking

def generate_name_list(collaborator_count, collaborator_num, is_validation):
    """Return the train or validation items of one collaborator's shard.

    The decathlon "training" list is sliced as
    ``[collaborator_num::collaborator_count]``; the last eighth of that shard
    is the validation part and everything before it is the training part.
    """
    json_path = os.path.join('./data/Task01_BrainTumour/', "dataset.json")
    # keep only the items that belong to the current collaborator
    shard = load_decathlon_datalist(json_path, True, "training")[
        collaborator_num::collaborator_count
    ]
    assert(len(shard) > 8)
    holdout = len(shard) // 8
    return shard[-holdout:] if is_validation else shard[:-holdout]

def OUR_generate_train_list(self,  *args, **kwargs):
    """Return up to 11 training items of shard 1 of 2, ignoring all arguments."""
    train_items = generate_name_list(
        collaborator_count=2, collaborator_num=1, is_validation=False
    )
    return train_items[:11]

# Note: the validation list is limited to 4 items here.
def OUR_generate_val_list(self,  *args, **kwargs):
    """Return up to 4 validation items of shard 1 of 2, ignoring all arguments."""
    val_items = generate_name_list(
        collaborator_count=2, collaborator_num=1, is_validation=True
    )
    return val_items[:4]  # create only 4 images

# Replace the list-generating methods on the class, so every dataset created
# from now on (including the ones built in split()) uses the functions above.
BraTSFederatedDataset.generate_train_list = OUR_generate_train_list
BraTSFederatedDataset.generate_val_list = OUR_generate_val_list

# The usual setup: dataset, federated model, collaborator models.
fl_data = BraTSFederatedDataset(batch_size=6)
fl_model = FederatedModel(build_model=UnetWrapper, optimizer=optimizer,
                          loss_fn=loss_function, data_loader=fl_data)
collaborator_models = fl_model.setup(num_collaborators=2)
collaborators = {'one': collaborator_models[0], 'two': collaborator_models[1]}


100%|██████████| 1/1 [00:01<00:00,  1.19s/it]
100%|██████████| 1/1 [00:01<00:00,  1.19s/it]


100%|██████████| 1/1 [00:01<00:00,  1.16s/it]
100%|██████████| 1/1 [00:01<00:00,  1.19s/it]
100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
100%|██████████| 1/1 [00:01<00:00,  1.16s/it]


In [39]:
# Show what the new collaborators load: print the shape of every validation batch.
for collaborator in collaborator_models:
    loader = collaborator.runner.data_loader.get_valid_loader()
    for val_inputs, val_labels in tqdm(loader):
        print(val_inputs.shape)

100%|██████████| 1/1 [00:03<00:00,  3.64s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([4, 4, 128, 128, 64])


100%|██████████| 1/1 [00:03<00:00,  3.63s/it]

torch.Size([4, 4, 128, 128, 64])





In [None]:
# Each batch contains 4 samples, so the loaders are using our replaced validation list.