# Federated PyTorch 3dUNET Tutorial
## Using the low-level Python API

In [None]:
# Install dependencies if not already installed
!pip install torchvision
!pip install torch
!pip install scikit-image
!pip install dill
!pip install nibabel 
!pip install cloudpickle

### Describe the model and optimizer

In [2]:
import torch.nn as nn
import torch.optim as optim
import os
from hashlib import sha384
import PIL
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tsf
from skimage import io
import numpy as np
import random
import nibabel as nib
from skimage.transform import resize
import torch.nn.functional as F

from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
from openfl.interface.interactive_api.federation import Federation

In [3]:
"""
3dUNet model definition
"""
import torch
from layers import soft_dice_loss, soft_dice_coef, DoubleConv, Down, Up, Out

class UNet3d(nn.Module):
    """
    3D U-Net: a 4-level encoder/decoder with skip connections and a sigmoid output.

    Args:
        in_channels: number of input channels (default 4).
        n_classes: number of output channels produced by the final ``Out`` layer (default 3).
        n_channels: channel width of the first level; deeper levels use multiples of it.
    """

    def __init__(self, in_channels=4, n_classes=3, n_channels=10):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels

        width = n_channels
        # Encoder path. Submodules are created in this fixed order so parameter
        # registration (and random initialisation order) stays stable.
        self.conv = DoubleConv(in_channels, width)
        self.enc1 = Down(width, 2 * width)
        self.enc2 = Down(2 * width, 4 * width)
        self.enc3 = Down(4 * width, 8 * width)
        self.enc4 = Down(8 * width, 8 * width)

        # Decoder path: each Up is called with the previous output and a skip map.
        # NOTE(review): the input widths (16w, 8w, 4w, 2w) look like the sum of
        # both inputs' channels, i.e. Up concatenates them — confirm in layers.Up.
        self.dec1 = Up(16 * width, 4 * width)
        self.dec2 = Up(8 * width, 2 * width)
        self.dec3 = Up(4 * width, width)
        self.dec4 = Up(2 * width, width)
        self.out = Out(width, n_classes)

    def forward(self, x):
        # Encoder: keep every intermediate feature map for the skip connections.
        features = [self.conv(x)]
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features.append(encoder(features[-1]))

        # The deepest map seeds the decoder; the remaining maps are consumed
        # deepest-first as skip inputs.
        mask = features.pop()
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            mask = decoder(mask, features.pop())

        return torch.sigmoid(self.out(mask))
    
# Seed torch's RNG so the randomly initialised weights (which are also saved
# later as the "initial model" baseline) are reproducible across runs.
torch.manual_seed(0)
model_unet = UNet3d()

In [4]:
# Adam over all U-Net parameters; handed to ModelInterface when registering the model.
optimizer_adam = optim.Adam(params=model_unet.parameters(), lr=1e-4)

### Prepare data

Keep all test data in the `data/` folder under the workspace: this folder is not sent to the collaborators.

Path to the BraTS 2020 training dataset:

In [5]:
# Location of the BraTS 2020 training data; override it with the BRATS_PATH
# environment variable. To use the default location, copy the data first, e.g.:
#   cp -r ~/brain_tumor/MICCAI_BraTS2020_TrainingData ./data/
# os.path.join(..., '') appends a trailing separator if one is missing. This is
# needed because case file paths are built by plain string concatenation
# (BRATS_PATH + 'BraTS20_Training_...').
BRATS_PATH = os.path.join(
    os.environ.get('BRATS_PATH', './data/MICCAI_BraTS2020_TrainingData/'), '')

In [6]:
class BraTSDataset():
    """
    Map-style dataset of BraTS 2020 3D MRI cases for one collaborator split (train or val).

    ``dataset[i]`` returns ``(img, mask)``:
      * ``img``: the four volumes under keys ``'image1'``..``'image4'``. Each is
        resized to ``output_shape`` (nearest-neighbour, ``order=0``) and min-max
        normalized; they are stacked along a new first axis and cast to float32.
      * ``mask``: the volume under ``'label'``, resized the same way, cast to uint8
        and converted by ``classify`` into 3 binary float32 channels (TC, WT, ET).

    Args:
        data_list: list of dicts mapping ``'image1'``..``'image4'`` and ``'label'``
            to NIfTI file paths.
    """

    # Spatial size every volume is resized to.
    output_shape = (160, 160, 128)

    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        case = self.data_list[index]

        images = []
        for i in range(1, 5):
            img = nib.load(case['image{}'.format(i)])
            img = np.asanyarray(img.dataobj)
            img = self.resize(img, self.output_shape)
            img = self.normalize(img)
            images.append(img)
        img = np.stack(images).astype(np.float32)

        mask = nib.load(case['label'])
        mask = np.asanyarray(mask.dataobj)
        mask = self.resize(mask, self.output_shape).astype(np.uint8)
        mask = self.classify(mask)
        return (img, mask)

    def normalize(self, data):
        """
        Min-max scale ``data`` to [0, 1].

        A constant volume (max == min) would divide by zero and produce NaNs, so it
        is mapped to all zeros instead.
        """
        data_min = np.min(data)
        data_range = np.max(data) - data_min
        if data_range == 0:
            return np.zeros(np.shape(data), dtype=np.float64)
        return (data - data_min) / data_range

    def resize(self, data, sizes):
        """Resize ``data`` to ``sizes`` with nearest-neighbour interpolation (order=0),
        without anti-aliasing and preserving the original value range."""
        data = resize(data, sizes, mode='edge',
                      anti_aliasing=False,
                      anti_aliasing_sigma=None,
                      preserve_range=True,
                      order=0)
        return data

    def classify(self, inputs):
        """
        Convert a BraTS label volume into 3 stacked binary float32 channels.

        BraTS 2018-2021 segmentation labels are 1 = necrotic / non-enhancing tumor
        core (NCR/NET), 2 = peritumoral edema (ED) and 4 = GD-enhancing tumor (ET).
        The evaluated sub-regions are:
          * channel 0, TC (tumor core):   labels 1 and 4
          * channel 1, WT (whole tumor):  labels 1, 2 and 4
          * channel 2, ET (enhancing):    label 4
        (The previous mapping used the MSD Task01 convention with labels 1/2/3;
        BraTS 2020 has no label 3.)
        """
        tumor_core = np.logical_or(inputs == 1, inputs == 4)
        whole_tumor = np.logical_or(tumor_core, inputs == 2)
        enhancing_tumor = inputs == 4
        return np.stack([tumor_core, whole_tumor, enhancing_tumor], axis=0).astype(np.float32)

### Register model

In [7]:
from copy import deepcopy

framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'

# Register the model and its optimizer with OpenFL through the PyTorch framework adapter.
MI = ModelInterface(
    model=model_unet,
    optimizer=optimizer_adam,
    framework_plugin=framework_adapter,
)

# Independent copy of the model as it is before any training, kept as a baseline.
initial_model = deepcopy(model_unet)

### Register dataset

In [8]:
class FedDataset(DataInterface):
    """
    OpenFL data interface that shards the BraTS 2020 case list across collaborators.

    The case list is a class attribute that is built and shuffled once, when the
    class is defined, so every instance slices the same ordering. ``_delayed_init``
    selects this collaborator's shard, splits it into train / validation lists and
    wraps each list in ``UserDatasetClass``.

    Args:
        UserDatasetClass: dataset class constructed from a list of case dicts.
        **kwargs: loader settings; ``train_bs`` and ``valid_bs`` are read by
            ``get_train_loader`` and ``get_valid_loader``.
    """

    # Seed for the one-off shuffle below, so the train/validation split is
    # reproducible across notebook runs.
    shuffle_seed = 0

    # One dict of NIfTI file paths per case, BraTS20_Training_001 .. BraTS20_Training_369.
    # The (key, file suffix) pairs are written inline because class-level names are
    # not visible inside a comprehension nested in a class body.
    data_list = [
        {
            key: '{root}BraTS20_Training_{case:03d}/BraTS20_Training_{case:03d}_{suffix}.nii.gz'.format(
                root=BRATS_PATH, case=case, suffix=suffix)
            for key, suffix in (('image1', 'flair'), ('image2', 't1ce'),
                                ('image3', 't1'), ('image4', 't2'),
                                ('label', 'seg'))
        }
        for case in range(1, 370)
    ]
    # Shuffle once, at class-definition time, so all shards share one ordering.
    random.Random(shuffle_seed).shuffle(data_list)

    def __init__(self, UserDatasetClass, **kwargs):
        self.UserDatasetClass = UserDatasetClass
        self.kwargs = kwargs

    def _delayed_init(self, data_path='1,1'):
        """
        Build this collaborator's train and validation datasets.

        Args:
            data_path: ``"<rank>,<world_size>"`` with a 1-based rank, e.g. ``"1,2"``
                and ``"2,2"`` for two collaborators. Defaults to the single
                collaborator ``"1,1"``.

        Raises:
            ValueError: if the rank is not in ``1..world_size``.
        """
        self.rank, self.world_size = [int(part)
                                      for part in data_path.split(',')]
        if not 1 <= self.rank <= self.world_size:
            raise ValueError(
                'rank must be in 1..world_size, got data_path={!r}'.format(data_path))
        # Shards are selected with a 0-based slice start. Using the 1-based rank
        # directly would silently drop the first case of the shuffled list.
        shard_index = self.rank - 1

        self.train_list = self.generate_train_list(self.world_size, shard_index)
        self.val_list = self.generate_val_list(self.world_size, shard_index)

        self.train_set = self.UserDatasetClass(self.train_list)
        self.valid_set = self.UserDatasetClass(self.val_list)

    def generate_name_list(self, collaborator_count, collaborator_num, is_validation):
        """
        Return the train or validation slice of one collaborator's shard.

        The shard is every ``collaborator_count``-th case starting at
        ``collaborator_num``. Its last ``len(shard) // 7`` cases form the
        validation list and the rest form the training list.

        Args:
            collaborator_count: total number of shards.
            collaborator_num: 0-based index of the shard to use.
            is_validation: return the validation slice if True, else the training slice.

        Raises:
            ValueError: if the shard has 7 or fewer cases (empty validation slice).
        """
        shard = self.data_list[collaborator_num::collaborator_count]
        if len(shard) <= 7:
            raise ValueError(
                'shard {} of {} has only {} cases; need more than 7 to hold out '
                'a validation split'.format(collaborator_num, collaborator_count, len(shard)))
        validation_size = len(shard) // 7
        if is_validation:
            return shard[-validation_size:]
        return shard[:-validation_size]

    def generate_train_list(self, collaborator_count, collaborator_num, *args, **kwargs):
        """Training slice of shard ``collaborator_num`` (0-based)."""
        return self.generate_name_list(collaborator_count, collaborator_num, False)

    def generate_val_list(self, collaborator_count, collaborator_num, *args, **kwargs):
        """Validation slice of shard ``collaborator_num`` (0-based)."""
        return self.generate_name_list(collaborator_count, collaborator_num, True)

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return DataLoader(
            self.train_set, num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True
        )

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return DataLoader(self.valid_set, num_workers=8, batch_size=self.kwargs['valid_bs'])

    def get_train_data_size(self):
        """
        Information for aggregation: number of training cases on this collaborator.
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation: number of validation cases on this collaborator.
        """
        return len(self.valid_set)

In [9]:
# Wrap BraTSDataset in the federated data interface; batch size 3 for both loaders.
fed_dataset = FedDataset(
    BraTSDataset,
    train_bs=3,
    valid_bs=3,
)

### Register tasks

In [10]:
import tqdm
import torch
# Registry for the federated tasks (train / validate) declared with its decorator.
TI = TaskInterface()


@TI.register_fl_task(model='unet_model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss, some_parameter=None):
    """
    Run one pass over ``train_loader`` and return the mean batch loss.

    Args:
        unet_model: model to train; switched to train mode and moved to ``device``.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``unet_model``'s parameters.
        device: target device; replaced by ``'cpu'`` when CUDA is unavailable.
        loss_fn: called as ``loss_fn(output=..., target=...)``; defaults to ``soft_dice_loss``.
        some_parameter: unused; kept for interface compatibility.

    Returns:
        ``{'train_loss': mean of the per-batch losses}``.
    """
    if not torch.cuda.is_available():
        device = 'cpu'

    unet_model.train()
    unet_model.to(device)

    losses = []

    for data, target in tqdm.tqdm(train_loader, desc="train"):
        # DataLoader batches are usually already tensors: torch.tensor() on a tensor
        # copies it and emits a UserWarning, while as_tensor avoids the copy and
        # still accepts numpy arrays.
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device, dtype=torch.float32)
        optimizer.zero_grad()
        output = unet_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())

    return {'train_loss': np.mean(losses), }


@TI.register_fl_task(model='unet_model', data_loader='val_loader', device='device')
def validate(unet_model, val_loader, device):
    """
    Score ``unet_model`` on ``val_loader`` with the Dice coefficient.

    Predictions are binarized at 0.5 before scoring. The values returned by
    ``soft_dice_coef`` are accumulated and divided by the total number of samples.
    This presumes ``soft_dice_coef`` returns a sum over the batch (TODO confirm
    against ``layers``).

    Args:
        unet_model: model to evaluate; switched to eval mode and moved to ``device``.
        val_loader: iterable of ``(inputs, labels)`` tensor batches.
        device: target device; replaced by ``'cpu'`` when CUDA is unavailable.

    Returns:
        ``{'dice_coef': accumulated Dice / number of samples}``.
    """
    # Same fallback as ``train``. Without it, a CUDA device passed in by the
    # federation would make validation crash on a CPU-only node.
    if not torch.cuda.is_available():
        device = 'cpu'

    unet_model.eval()
    unet_model.to(device)
    metric = 0.0
    sample_num = 0

    with torch.no_grad():
        for val_inputs, val_labels in tqdm.tqdm(val_loader, desc="validate"):
            val_inputs = val_inputs.to(device)
            val_labels = val_labels.to(device)
            val_outputs = unet_model(val_inputs)
            # Threshold the outputs (sigmoid probabilities for UNet3d) into a binary mask.
            val_outputs = (val_outputs >= 0.5).float()
            value = soft_dice_coef(val_outputs, val_labels)
            sample_num += val_labels.shape[0]
            metric += value.cpu().numpy()

    return {'dice_coef': metric / sample_num}

## Time to start a federated learning experiment

In [11]:
# Create a federation
federation = Federation(central_node_fqdn='localhost', disable_tls=True)

# Each collaborator's data path is "<rank>,<world_size>" with a 1-based rank;
# FedDataset._delayed_init parses it to pick that collaborator's data shard.
# The rank is also passed as a CUDA device identifier.
# For two collaborators use: {'one': '1,2', 'two': '2,2'}
col_data_paths = {'one': '1,1'}
federation.register_collaborators(col_data_paths=col_data_paths)

In [12]:
# Create an experiment within the federation.
fl_experiment = FLExperiment(
    federation=federation,
)

In [None]:
# NOTE: enabling %autoreload caused a pickling error in this cell.

# Zip the workspace and Python requirements so they can be transferred to the
# collaborator nodes, overriding some training parameters along the way.
fl_experiment.prepare_workspace_distribution(
    model_provider=MI,
    task_keeper=TI,
    data_loader=fed_dataset,
    rounds_to_train=2,
    opt_treatment='CONTINUE_GLOBAL',
)

# Start the aggregator server.
fl_experiment.start_experiment(model_provider=MI)

tried to remove tensor: __opt_state_needed not present in the tensor dict
tried to remove tensor: __opt_state_needed not present in the tensor dict
gRPC is running on insecure channel with TLS disabled.


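## Validate the best model

Compare the initial (untrained) model with the best model from the federated experiment on the local validation split.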

In [None]:
# Fetch the best model produced by the federated experiment.
best_model = fl_experiment.get_best_model()

In [None]:
# Load the data locally as collaborator 1 of 1 (the default data path), so the
# validation cells below use that collaborator's validation split.
fed_dataset._delayed_init(data_path='1,1')

In [None]:
# Baseline: score the model as it was before training, on CPU.
initial_valid_loader = fed_dataset.get_valid_loader()
validate(initial_model, initial_valid_loader, 'cpu')

In [None]:
# Score the best federated model on the same validation split, on CPU.
best_valid_loader = fed_dataset.get_valid_loader()
validate(best_model, best_valid_loader, 'cpu')