# Federated PyTorch 3dUNET Tutorial 

In [1]:
# Install dependencies if not already installed
!pip install torch



In [2]:
import openfl.native as fx

# Setup default workspace, logging, etc. Install additional requirements
fx.init('torch_3dunet_brats')

Creating Workspace Directories
Creating Workspace Templates
Successfully installed packages from /home/radionov/.local/workspace/requirements.txt.

New workspace directory structure:
workspace
├── hyper-kvasir-segmented-images.zip
├── .workspace
├── data
│   ├── segmented-images
│   │   ├── images
│   │   ├── bounding-boxes.json
│   │   ├── license.txt
│   │   └── masks
│   ├── MICCAI_BraTS2020_TrainingData
│   │   ├── BraTS20_Training_210
│   │   ├── BraTS20_Training_023
│   │   ├── BraTS20_Training_101
│   │   ├── BraTS20_Training_227
│   │   ├── BraTS20_Training_097
│   │   ├── BraTS20_Training_046
│   │   ├── BraTS20_Training_038
│   │   ├── BraTS20_Training_296
│   │   ├── BraTS20_Training_302
│   │   ├── BraTS20_Training_211
│   │   ├── BraTS20_Training_040
│   │   ├── BraTS20_Training_099
│   │   ├── BraTS20_Training_140
│   │   ├── BraTS20_Training_272
│   │   ├── BraTS20_Training_027
│   │   ├── BraTS20_Training_098
│   │   ├── BraTS20_Training_356
│   │   ├── BraTS20_Training

│   │   ├── BraTS20_Training_133
│   │   ├── BraTS20_Training_062
│   │   ├── BraTS20_Training_201
│   │   ├── BraTS20_Training_228
│   │   ├── BraTS20_Training_295
│   │   ├── BraTS20_Training_214
│   │   ├── BraTS20_Training_122
│   │   ├── BraTS20_Training_030
│   │   ├── BraTS20_Training_320
│   │   ├── BraTS20_Training_222
│   │   ├── BraTS20_Training_304
│   │   ├── BraTS20_Training_103
│   │   ├── BraTS20_Training_032
│   │   ├── BraTS20_Training_330
│   │   ├── BraTS20_Training_060
│   │   ├── BraTS20_Training_276
│   │   ├── BraTS20_Training_364
│   │   ├── BraTS20_Training_029
│   │   ├── BraTS20_Training_034
│   │   ├── BraTS20_Training_166
│   │   ├── BraTS20_Training_205
│   │   ├── BraTS20_Training_093
│   │   ├── BraTS20_Training_041
│   │   ├── BraTS20_Training_002
│   │   ├── BraTS20_Training_188
│   │   ├── BraTS20_Training_213
│   │   ├── BraTS20_Training_322
│   │   ├── BraTS20_Training_033
│   │   ├── BraTS20_Training_335
│   │   ├── name_mapping.csv
│   │   ├── Br

In [3]:
import os
import json
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
from skimage.transform import resize
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

from openfl.federated import FederatedModel, FederatedDataSet
from openfl.utilities import TensorKey

import warnings
warnings.simplefilter("ignore")

Path to brats dataset

In [4]:
BRATS_PATH = './data/MICCAI_BraTS2020_TrainingData/'

Lets define dataset, which will get data_list brain tumor images for one collaborator (train or validation).

In [14]:
class BraTSDataset():
    """
    This dataset contains brain tumor 3d images for one collaborator train or val.
    Args:
        data_list: list of image paths
    """

    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        images = []
        for i in range(1, 5):
            img = nib.load(self.data_list[index]['image{}'.format(i)])
            img = np.asanyarray(img.dataobj)
            img = self.resize(img, (160, 160, 128))
            img = self.normalize(img)
            images.append(img)
        img = np.stack(images)
        img = img.astype(np.float32)

        mask = nib.load(self.data_list[index]['label'])
        mask = np.asanyarray(mask.dataobj)
        mask = self.resize(mask, (160, 160, 128)).astype(np.uint8)
        mask = self.classify(mask)
        return (img, mask)

    def normalize(self, data):
        data_min = np.min(data)
        return (data - data_min) / (np.max(data) - data_min)

    def resize(self, data, sizes):
        data = resize(data, sizes, mode='edge',
                      anti_aliasing=False,
                      anti_aliasing_sigma=None,
                      preserve_range=True,
                      order=0)
        return data

    def classify(self, inputs):
        result = []
        # merge label 2 and label 3 to construct TC
        result.append(np.logical_or(inputs == 2, inputs == 3))
        # merge labels 1, 2 and 3 to construct WT
        result.append(
            np.logical_or(
                np.logical_or(inputs == 2, inputs == 3), inputs == 1
            )
        )
        # label 2 is ET
        result.append(inputs == 2)
        return np.stack(result, axis=0).astype(np.float32)

Wrapper, to define generate_*_list interfaces. Further it will be replaced by another function to pass test data list

In [15]:
class FederatedDataSetWrapper(FederatedDataSet):
    def __init__(self, *args, **kwargs):

        self.train_list = self.generate_train_list(*args, **kwargs)
        self.val_list = self.generate_val_list(*args, **kwargs)
        super().__init__([], [], [], [], 1, **kwargs)

    def generate_train_list(self, *args, **kwargs):
        raise NotImplementedError

    def generate_val_list(self, *args, **kwargs):
        raise NotImplementedError

Here we redefine `FederatedDataSet` methods, if we don't want to use default batch generator from `FederatedDataSet`.<br> Also we should override `generate_train_list` and `generate_val_list` methods.<br> We should use `self.train_list` and `self.val_list` form `FederatedDataSetWrapper` as data paths, because such pipeline will be use with test data. This fields will be initialized after calling `FederatedDataSetWrapper` constructor.

In [16]:
class BraTSFederatedDataset(FederatedDataSetWrapper):
    def __init__(self, collaborator_count=1, collaborator_num=0, batch_size=1, data_list=[], **kwargs):
        """Instantiate the federated data object
        Args:
            collaborator_count: total number of collaborators
            collaborator_num: number of current collaborator
            batch_size:  the batch size of the data loader
            data_list: general list of all image paths, in current implementation 
                it should be created once so that each colaborator gets its own data
            **kwargs: additional arguments, passed to super init
        """
        self.data_list = data_list
        # Call super().__init__ to call generate_*_list methods,
        # so self.train_list and val_list fields will be initialized
        # You should use this fields as data paths, because such pipeline will be use with test data.
        super().__init__(collaborator_count, collaborator_num, num_classes=2, **kwargs)

        self.batch_size = batch_size

        self.training_set = BraTSDataset(self.train_list)
        self.valid_set = BraTSDataset(self.val_list)

        self.train_loader = self.get_train_loader()
        self.val_loader = self.get_valid_loader()

    def generate_name_list(self, collaborator_count, collaborator_num, is_validation):
        if self.data_list == []:
            data_dir = BRATS_PATH
            self.data_list = [
                {
                    'image1': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_flair.nii.gz',
                    'image2': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_t1ce.nii.gz',
                    'image3': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_t1.nii.gz',
                    'image4': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_t2.nii.gz',
                    'label': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_seg.nii.gz'
                } for i in range(1001, 1370)]
            random.shuffle(self.data_list)

        # split all data for current collaborator
        data = self.data_list[collaborator_num:: collaborator_count]
        assert(len(data) > 7)
        validation_size = len(data) // 7
        if is_validation:
            data = data[-validation_size:]
        else:
            data = data[: -validation_size]
        return data

    # Override--------------------------------------------------------------------------
    def generate_train_list(self, collaborator_count, collaborator_num, *args, **kwargs):
        return self.generate_name_list(collaborator_count, collaborator_num, False)

    def generate_val_list(self, collaborator_count, collaborator_num, *args, **kwargs):
        return self.generate_name_list(collaborator_count, collaborator_num, True)
    # -----------------------------------------------------------------------------------

    def get_valid_loader(self, num_batches=None):
        return DataLoader(self.valid_set, num_workers=2, batch_size=self.batch_size)

    def get_train_loader(self, num_batches=None):
        return DataLoader(
            self.training_set, num_workers=2, batch_size=self.batch_size, shuffle=True
        )

    def get_train_data_size(self):
        return len(self.training_set)

    def get_valid_data_size(self):
        return len(self.valid_set)

    def get_feature_shape(self):
        return self.valid_set[0][0].shape

    def split(self, collaborator_count, shuffle=True, equally=True):
        return [
            BraTSFederatedDataset(collaborator_count,
                                  collaborator_num, self.batch_size, self.data_list)
            for collaborator_num in range(collaborator_count)
        ]

Our Unet model. Define `validate` method, to use special metric (default metric - accuracy). 

In [17]:
def soft_dice_loss(output, target):
    num = target.size(0)
    m1 = output.view(num, -1)
    m2 = target.view(num, -1)
    intersection = m1 * m2
    score = 2.0 * (intersection.sum(1) + 1) / (m1.sum(1) + m2.sum(1) + 1)
    score = 1 - score.sum() / num
    return score


def soft_dice_coef(output, target):
    num = target.size(0)
    m1 = output.view(num, -1)
    m2 = target.view(num, -1)
    intersection = m1 * m2
    score = 2.0 * (intersection.sum(1) + 1) / (m1.sum(1) + m2.sum(1) + 1)
    return score.sum()


class DoubleConv(nn.Module):
    """(Conv3D -> BN -> ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv3d(out_channels, out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.MaxPool3d(2, 2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.encoder(x)


class Up(nn.Module):

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()

        if trilinear:
            self.up = nn.Upsample(
                scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY //
                        2, diffY - diffY // 2, diffZ // 2, diffZ - diffZ // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class Out(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet3d(nn.Module):
    def __init__(self, in_channels=4, n_classes=3, n_channels=10):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels

        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 8 * n_channels)

        self.dec1 = Up(16 * n_channels, 4 * n_channels)
        self.dec2 = Up(8 * n_channels, 2 * n_channels)
        self.dec3 = Up(4 * n_channels, n_channels)
        self.dec4 = Up(2 * n_channels, n_channels)
        self.out = Out(n_channels, n_classes)

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.enc1(x1)
        x3 = self.enc2(x2)
        x4 = self.enc3(x3)
        x5 = self.enc4(x4)

        mask = self.dec1(x5, x4)
        mask = self.dec2(mask, x3)
        mask = self.dec3(mask, x2)
        mask = self.dec4(mask, x1)
        mask = self.out(mask)
        mask = torch.sigmoid(mask)

        return mask

    def validate(
        self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
    ):
        """ Validate. Redifine function from PyTorchTaskRunner, to use our validation"""
        self.rebuild_model(round_num, input_tensor_dict, validation=True)
        loader = self.data_loader.get_valid_loader()
        if use_tqdm:
            loader = tqdm.tqdm(loader, desc="validate")
    # -------------Usual validation code---------------------------------------------------------------------------
        self.eval()
        self.to(self.device)
        metric = 0.0
        sample_num = 0

        with torch.no_grad():
            for val_inputs, val_labels in loader:
                val_inputs = val_inputs.to(self.device)
                val_labels = val_labels.to(self.device)
                val_outputs = self(val_inputs)
                val_outputs = (val_outputs >= 0.5).float()
                value = soft_dice_coef(val_outputs, val_labels)
                sample_num += val_labels.shape[0]
                metric += value.cpu().numpy()

            metric = metric / sample_num
    # --------------------------------------------------------------------------

        origin = col_name
        suffix = "validate"
        if kwargs["apply"] == "local":
            suffix += "_local"
        else:
            suffix += "_agg"
        tags = ("metric", suffix)
        output_tensor_dict = {
            TensorKey("dice_coef", origin, round_num, True, tags): np.array(
                metric
            )
        }
        return output_tensor_dict, {}


def optimizer(x): return torch.optim.Adam(
    x, 5e-4)

Create BraTSFederatedDataset, federated datasets for each collaborator will be created in `split()` method of this object

In [18]:
fl_data = BraTSFederatedDataset(batch_size=3)

The `FederatedModel` object is a wrapper around your Keras, Tensorflow or PyTorch model that makes it compatible with OpenFL. It provides built-in federated training function which will be used while training. Using its `setup` function, collaborator models and datasets can be automatically obtained for the experiment.

In [19]:
fl_model = FederatedModel(build_model=UNet3d, optimizer=optimizer,
                          loss_fn=soft_dice_loss, data_loader=fl_data)

In [20]:
collaborator_models = fl_model.setup(num_collaborators=2)
collaborators = {'one': collaborator_models[0], 'two': collaborator_models[1]}

We can see the current FL plan values by running the `fx.get_plan()` function

In [21]:
# Get the current values of the FL plan. Each of these can be overridden
print(json.dumps(fx.get_plan(), indent=4, sort_keys=True))

{
    "aggregator.settings.best_state_path": "save/torch_3dunet_brats_best.pbuf",
    "aggregator.settings.db_store_rounds": 1,
    "aggregator.settings.init_state_path": "save/torch_3dunet_brats_init.pbuf",
    "aggregator.settings.last_state_path": "save/torch_3dunet_brats_last.pbuf",
    "aggregator.settings.rounds_to_train": 40,
    "aggregator.template": "openfl.component.Aggregator",
    "assigner.settings.task_groups": [
        {
            "name": "train_and_validate",
            "percentage": 1.0,
            "tasks": [
                "aggregated_model_validation",
                "train",
                "locally_tuned_model_validation"
            ]
        }
    ],
    "assigner.template": "openfl.component.RandomGroupedAssigner",
    "collaborator.settings.db_store_rounds": 1,
    "collaborator.settings.delta_updates": false,
    "collaborator.settings.opt_treatment": "RESET",
    "collaborator.template": "openfl.component.Collaborator",
    "data_loader.settings.batch

You can see common plan with all options. Lets concentrated on options directly related with federation. 

`aggregator.settings.db_store_rounds` - rounds to store model weights.

`aggregator.settings.rounds_to_train` - number of training rounds.

`collaborator.settings.delta_updates` - Sent only model delta (or full model if false)

`collaborator.settings.opt_treatment` - The optimizer state treatment:
    """
    RESET tells each collaborator to reset the optimizer state at the beginning
    of each round.
    CONTINUE_LOCAL tells each collaborator to continue with the local optimizer
    state from the previous round.
    CONTINUE_GLOBAL tells each collaborator to continue with the federally
    averaged optimizer state from the previous round.
    """
    
`tasks.aggregated_model_validation.aggregation_type` - aggregation function for current task. It maybe name of np function. Default - "weighted_average", which related to np.average with weights

We can override this option by definnig overriding config and pass it to run_expirement function:

In [24]:
# To override  aggregator.settings.db_store_rounds
# number of rounds
override_config = {
        'aggregator.settings.db_store_rounds': 1,
        'aggregator.settings.rounds_to_train': 40,
        'collaborator.settings.delta_updates': False,
        'collaborator.settings.opt_treatment': "RESET",
        'tasks.aggregated_model_validation.aggregation_type': ["weighted_average"]
}

In [None]:
# Run experiment, return trained FederatedModel
final_fl_model = fx.run_experiment(
    collaborators, override_config=override_config)

Lets validate final model on common validation dataset.

In [14]:
model = final_fl_model.model
model.eval()
device = final_fl_model.runner.device
model.to(device)
metric = 0.0
sample_num = 0
with torch.no_grad():
    for collaborator in collaborator_models:
        loader = collaborator.runner.data_loader.get_valid_loader()
        for val_inputs, val_labels in tqdm(loader):
            val_inputs = val_inputs.to(device)
            val_labels = val_labels.to(device)
            val_outputs = model(val_inputs)
            val_outputs = (val_outputs >= 0.5).float()
            value = soft_dice_coef(val_outputs, val_labels)
            sample_num += val_labels.shape[0]
            metric += value.cpu().numpy()

metric = metric / sample_num

100%|██████████| 9/9 [00:24<00:00,  2.73s/it]
100%|██████████| 9/9 [00:24<00:00,  2.75s/it]


In [15]:
# Metric for final model on whole validation dataset
metric

0.6985069100673382

Inference final model on test data. We should replace `generate_*_list` methods to pass test data to user datasets

In [28]:
# Checking

def generate_name_list(collaborator_count, collaborator_num, is_validation):
    data_dir = './data/MICCAI_BraTS2020_TrainingData/'
    data = [
        {
            'image1': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_flair.nii.gz',
            'image2': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_t1ce.nii.gz',
            'image3': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_t1.nii.gz',
            'image4': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_t2.nii.gz',
            'label': data_dir + 'BraTS20_Training_'+str(i)[1:]+'/BraTS20_Training_'+str(i)[1:]+'_seg.nii.gz'
        } for i in range(1001, 1370)]
    data = data[collaborator_num:: collaborator_count]
    assert(len(data) > 8)
    validation_size = len(data) // 8
    if is_validation:
        data = data[-validation_size:]
    else:
        data = data[: -validation_size]
    return data


def OUR_generate_train_list(self,  *args, **kwargs):
    return generate_name_list(2, 1, False)

# Take in attention, we create only 4 object in valid dataset


def OUR_generate_val_list(self,  *args, **kwargs):
    return generate_name_list(2, 1, True)[:4]  # create only 4 images


# Replace functions:
BraTSFederatedDataset.generate_train_list = OUR_generate_train_list
BraTSFederatedDataset.generate_val_list = OUR_generate_val_list

# Usual training process:
fl_data2 = BraTSFederatedDataset(batch_size=6)
fl_model2 = FederatedModel(build_model=UNet3d, optimizer=optimizer,
                           loss_fn=loss_function, data_loader=fl_data2)
collaborator_models2 = fl_model2.setup(num_collaborators=2)
collaborators2 = {
    'one': collaborator_models2[0], 'two': collaborator_models2[1]}

In [29]:
# Just show how new collaborators work:
for collaborator in collaborator_models2:
    loader = collaborator.runner.data_loader.get_valid_loader()
    for val_inputs, val_labels in tqdm(loader):
        print(val_inputs.shape)

100%|██████████| 1/1 [00:06<00:00,  6.98s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([4, 4, 160, 160, 128])


100%|██████████| 1/1 [00:07<00:00,  7.06s/it]

torch.Size([4, 4, 160, 160, 128])





In [None]:
# We get 4 - size batch, so this is our new data