In [None]:
from fedbiomed.researcher.strategies.default_strategy import DefaultStrategy
from fedbiomed.common.constants import ErrorNumbers
from fedbiomed.common.exceptions import FedbiomedStrategyError
from fedbiomed.common.logger import logger


class MedicalFolderStrategy(DefaultStrategy):
    """Strategy that weights node updates by the size of their demographics table.

    Node sampling is inherited from ``DefaultStrategy``. In ``refine``, every
    sampled node must have answered and trained successfully. Each node's
    aggregation weight is its number of ``demographics`` rows (read from the
    federated dataset metadata) divided by the total over all nodes.
    """

    def __init__(self, data, modalities=None):
        """
        Args:
            data: federated dataset, forwarded to ``DefaultStrategy``.
            modalities: list of modality names; defaults to ``['T1']``.
                Stored in ``self._modalities``; not otherwise used here.
        """
        super().__init__(data)
        # A list literal as default would be one object shared by every
        # instance created without ``modalities``; build a fresh one instead.
        self._modalities = ['T1'] if modalities is None else modalities

    def refine(self, training_replies, round_i):
        """Check the replies of ``round_i`` and build the aggregation inputs.

        Args:
            training_replies: replies of the round. ``.data()`` yields dicts
                with a ``node_id`` key; iterating it yields dict-like entries
                with ``node_id``, ``success`` and ``params``.
            round_i: index of the current round.

        Returns:
            ``(models_params, weights)``: a list of ``{node_id: params}`` for
            the successful nodes, and a list of ``{node_id: weight}`` where
            weight is the node's share of all demographics rows.

        Raises:
            FedbiomedStrategyError: FB407 if no sampled node answered, FB408
                if only some did, FB402 if any node reported a failure.
        """
        # Sample once and reuse the result for both the count and the check.
        sampled_nodes = self.sample_nodes(round_i)

        # check that all nodes answered
        answered = {reply['node_id'] for reply in training_replies.data()}
        answers_count = 0
        for node_id in sampled_nodes:
            if node_id in answered:
                answers_count += 1
            else:
                # this node did not answer
                logger.error(f'{ErrorNumbers.FB408.value} (node = {node_id})')

        if len(sampled_nodes) != answers_count:
            # FB407: none of the nodes answered; FB408: only some of them did.
            msg = ErrorNumbers.FB407.value if answers_count == 0 else ErrorNumbers.FB408.value
            logger.critical(msg)
            raise FedbiomedStrategyError(msg)

        # check that all nodes that answered could successfully train
        self._success_node_history[round_i] = []
        models_params = []
        all_success = True
        for reply in training_replies:
            if reply['success'] is True:
                models_params.append({reply['node_id']: reply['params']})
                self._success_node_history[round_i].append(reply['node_id'])
            else:
                # node did not succeed
                all_success = False
                logger.error(f'{ErrorNumbers.FB409.value} (node = {reply["node_id"]})')

        if not all_success:
            raise FedbiomedStrategyError(ErrorNumbers.FB402.value)

        # so far, everything is OK: weight each node by its demographics row count
        rows_per_node = {node_id: metadata[0]["shape"]['demographics'][0]
                         for node_id, metadata in self._fds.data().items()}
        total_rows = sum(rows_per_node.values())
        weights = [{node_id: rows / total_rows} for node_id, rows in rows_per_node.items()]

        logger.info(f'Nodes that successfully reply in round {round_i} '
                    f'{self._success_node_history[round_i]}')
        return models_params, weights

In [None]:
import numpy as np
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager, MedicalFolderDataset
import torch
from torch import nn
from torch.optim import Adam, SGD
from monai.losses.dice import DiceLoss
from monai.networks.nets import UNet, SegResNet
from monai.transforms import Compose, Resize, NormalizeIntensity, AddChannel, AsDiscrete, Lambda, CenterSpatialCrop
import torch.nn.functional as F


class UNetTrainingPlan(TorchTrainingPlan):
    """Fed-BioMed training plan for multi-modal 3D segmentation.

    Inputs are the T1, T1CE, FLAIR and T2 images of a MedicalFolder dataset,
    stacked into 4 channels. The target is the SEG modality, one-hot encoded
    into 4 classes. The loss is MONAI's Dice loss on softmax outputs.
    """

    def init_model(self, model_args):
        """Instantiate the network described by ``model_args`` (see ``Net``)."""
        return self.Net(model_args)

    def init_optimizer(self, optimizer_args):
        """Build the local optimizer named by ``model_args['optimizer_name']``.

        Args:
            optimizer_args: keyword arguments forwarded to the optimizer
                constructor (e.g. ``lr``, ``momentum``).

        Returns:
            ``Adam`` (the default when no name is given) or ``SGD`` over the
            model parameters.

        Raises:
            ValueError: if ``optimizer_name`` is neither ``'adam'`` nor ``'sgd'``.
        """
        optimizers = {'adam': Adam, 'sgd': SGD}
        optimizer_name = self.model().model_arguments.get('optimizer_name', 'adam')
        if optimizer_name not in optimizers:
            raise ValueError(f"Unsupported optimizer_name {optimizer_name!r}; "
                             f"expected one of {sorted(optimizers)}")
        return optimizers[optimizer_name](self.model().parameters(), **optimizer_args)

    def init_dependencies(self):
        """Return the import statements this plan's code depends on."""
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from monai.transforms import (Compose, NormalizeIntensity, AddChannel, CenterSpatialCrop, "
                "AsDiscrete, Lambda)",
                "import torch.nn as nn",
                'import torch.nn.functional as F',
                "from fedbiomed.common.data import MedicalFolderDataset",
                'import numpy as np',
                'from monai.losses.dice import DiceLoss',
                'from torch.optim import AdamW, Adam, SGD',
                "from monai.networks.nets import UNet, SegResNet"]
        return deps

    class Net(nn.Module):
        """MONAI 3D network (4 input / 4 output channels) followed by a channel softmax.

        ``model_args['network_type']`` selects ``'unet'`` or ``'segresnet'``.
        The whole ``model_args`` dict is kept in ``self.model_arguments``.
        """

        def __init__(self, model_args: "dict | None" = None):
            super().__init__()
            # Avoid a shared mutable ``{}`` default.
            model_args = {} if model_args is None else model_args
            self.CHANNELS_DIMENSION = 1
            network_type = model_args.get('network_type')
            if network_type == 'unet':
                print("Selected model is unet")
                net = UNet(
                    spatial_dims=3,
                    in_channels=4,
                    out_channels=4,
                    channels=(30, 30 * 2, 30 * 4, 30 * 8, 30 * 16),
                    strides=(2, 2, 2, 2),
                    num_res_units=2,
                    kernel_size=3,
                    dropout=0.3,
                )
            elif network_type == 'segresnet':
                print("Selected model is segresnet")
                net = SegResNet(spatial_dims=3,
                                init_filters=16,
                                in_channels=4,
                                out_channels=4,
                                dropout_prob=0.2,
                                act=('RELU', {'inplace': True}),
                                norm=('GROUP', {'num_groups': 8}),
                                norm_name='',
                                num_groups=8,
                                use_conv_final=True,
                                blocks_down=(1, 2, 2, 4),
                                blocks_up=(1, 1, 1)
                                )
            else:
                raise ValueError(f'The model_args dictionary contains network_type={network_type!r}, '
                                 f"which is not an allowed value ('unet' or 'segresnet')")
            self.net = net
            self.model_arguments = model_args

        def forward(self, x):
            """Run the wrapped network and apply softmax over the channel dimension."""
            x = self.net.forward(x)
            x = F.softmax(x, dim=self.CHANNELS_DIMENSION)
            return x

    @staticmethod
    def get_dice_loss(output, target):
        """Dice loss (background included) between softmax ``output`` and one-hot ``target``."""
        loss = DiceLoss(include_background=True, sigmoid=False)
        return loss(output, target)

    @staticmethod
    def demographics_transform(demographics: dict):
        """Discard the demographics: every subject gets an empty dict."""
        return {}

    @staticmethod
    def _stack_modalities(data):
        """Concatenate T1CE, T1, T2 and FLAIR along dim 1 (``Net.CHANNELS_DIMENSION``)."""
        images = data[0]  # data[1] holds the (unused) demographics
        return torch.cat((images['T1CE'], images['T1'], images['T2'], images['FLAIR']), 1)

    def training_data(self, batch_size=4):
        """Build the ``DataManager`` used by Fed-BioMed for local training.

        Each input modality gets a channel axis, is center-cropped to
        ``(240, 240, 128)`` and intensity-normalised. The SEG target gets the
        same crop, has label 4 remapped to 3, and is one-hot encoded into
        4 classes. Batches are not shuffled.
        """
        common_shape = (240, 240, 128)
        modalities = ('T1', 'T1CE', 'FLAIR', 'T2')
        training_transform = {modality: Compose([AddChannel(), CenterSpatialCrop(common_shape), NormalizeIntensity(), ])
                              for modality in modalities}
        target_transform = Compose([AddChannel(), CenterSpatialCrop(common_shape),
                                    Lambda(func=lambda x: torch.where(x == 4, 3, x)),
                                    AsDiscrete(to_onehot=4)])

        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities=list(modalities),
            target_modalities='SEG',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=UNetTrainingPlan.demographics_transform)
        loader_arguments = {'batch_size': batch_size, 'shuffle': False}
        return DataManager(dataset, **loader_arguments)

    def training_step(self, data, target):
        """Return the mean Dice loss of one batch (to be back-propagated)."""
        torch.cuda.empty_cache()
        img = self._stack_modalities(data)
        output = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(output, target['SEG'])
        return loss.mean()

    def testing_step(self, data, target):
        """Return the mean Dice loss of one validation batch as a Python float."""
        torch.cuda.empty_cache()
        img = self._stack_modalities(data)
        prediction = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(prediction, target['SEG'])
        # Average the loss tensor first, then convert: ``.item()`` returns a
        # float, which has no ``.mean()``.
        return loss.mean().item()


In [None]:
import os
import sys
import configparser
from mmap import mmap

import pytz
from datetime import datetime
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage
from fedbiomed.researcher.aggregators.scaffold import Scaffold
from fedbiomed.researcher.environ import environ
from fedbiomed.common.logger import logger
from declearn.optimizer.modules import AdamModule, AdaGradModule, YogiModule
from declearn.optimizer.regularizers import FedProxRegularizer
from fedbiomed.common.optimizers.optimizer import Optimizer
from plan_segmentation import UNetTrainingPlan
import torch
import pickle
import json
import shutil
from pathlib import Path

# Current user's home directory, and the project root used below to locate
# the data splits and the pre-trained parameters.
HOME = os.fspath(Path.home())
PROJECT_DIR = '/'.join((HOME, 'fets_analysis'))


def get_num_rounds(configuration: str, config_args: dict, dim_training_set: int = 1000) -> int:
    """Return the number of rounds to run for ``configuration``.

    'LocalNode' and 'centralized' runs (matched case-insensitively) use a
    single round. Otherwise the result is::

        (dim_training_set // num_clients // batch_size) * num_epochs_local // num_updates

    i.e. the local batches per epoch on an even split, times the local
    epochs, divided by the number of updates per round.

    Args:
        configuration: configuration name (config.ini section).
        config_args: values read from config.ini (strings). Optional keys are
            ``num_updates`` (default 20), ``num_epochs_local`` (25),
            ``num_clients`` (23) and ``batch_size`` (8).
        dim_training_set: assumed total number of training samples.

    Raises:
        ValueError: if ``num_updates``, ``num_clients`` or ``batch_size`` is
            not strictly positive.
    """
    if configuration.lower() in ('localnode', 'centralized'):
        return 1

    num_updates = int(config_args.get('num_updates', 20))
    num_epochs_local = int(config_args.get('num_epochs_local', 25))
    num_clients = int(config_args.get('num_clients', 23))
    batch_size = int(config_args.get('batch_size', 8))
    # These are divisors below; report a clear error instead of ZeroDivisionError.
    for name, value in (('num_updates', num_updates), ('num_clients', num_clients),
                        ('batch_size', batch_size)):
        if value <= 0:
            raise ValueError(f'{name} must be a positive integer, got {value}')

    batches_per_client = dim_training_set // num_clients // batch_size
    return batches_per_client * num_epochs_local // num_updates


def mapcount(filename):
    """Return the number of lines in ``filename``.

    Lines are split on ``b'\\n'``. A final line without a trailing newline is
    counted, and an empty file has 0 lines.
    """
    # Read-only binary mode: no write permission is needed and no decoding
    # errors are possible. ``with`` guarantees the handle is closed. A plain
    # buffered iteration also avoids mmap's ValueError on empty files.
    with open(filename, 'rb') as handle:
        return sum(1 for _ in handle)


def get_training_size(dataset: str, fold: str = 'kfold_0') -> int:
    """Return the line count of ``splits/<dataset>/participants_train_<fold>.csv``.

    The file is resolved under ``PROJECT_DIR``. NOTE(review): if the CSV has
    a header row, it is included in this count. TODO confirm whether callers
    expect that.

    Args:
        dataset: dataset sub-folder under ``PROJECT_DIR/splits``.
        fold: split identifier; defaults to ``'kfold_0'``.
    """
    file_csv = os.path.join(PROJECT_DIR, 'splits', dataset, f'participants_train_{fold}.csv')
    return mapcount(file_csv)


In [None]:
id_reference = "kfold_0"
# Must be a section of config.ini. The cells below react to 'LocalNode',
# 'FedProx', 'FedYogi', 'FedAdam', 'FedAdagrad', 'scaffold' and 'centralized';
# any other value (e.g. 'FedAvg') trains with plain FedAverage.
configuration = "LocalNode"
configuration_file = "config.ini"
config = configparser.ConfigParser()
# ConfigParser.read silently skips files it cannot open and returns the list
# of files it did parse. Fail here with a clear message rather than with a
# NoSectionError on the next line.
if not config.read(configuration_file):
    raise FileNotFoundError(f"Could not read configuration file '{configuration_file}'")
config_args = dict(config.items(configuration))
config_args['configuration'] = configuration

tags = [config_args['tags'], id_reference]
FEDBIOMED_DIR = ""
if configuration == "LocalNode":
    tags.append('Local0')  # or local name, has to match the tag used in the node


In [None]:
def _config_value(key, cast, default):
    """Return ``cast(config_args[key])`` if ``key`` is configured, else ``default`` as-is."""
    return cast(config_args[key]) if key in config_args else default


# Arguments forwarded to the nodes' training loop; optional entries fall back
# to the defaults below when absent from config.ini.
training_args = {
    'batch_size': _config_value('batch_size', int, 8),
    'dry_run': False,
    'log_interval': _config_value('log_interval', int, 3),
    'test_ratio': _config_value('test_ratio', float, 0.1),
    'test_on_global_updates': False,
    'test_on_local_updates': False,
    'num_updates': _config_value('num_updates', int, None),
    'optimizer_args': {
        'lr': _config_value('lr', float, 0.001),
    },
    'use_gpu': bool(torch.cuda.is_available()),
}

# Arguments handed to the training plan's model (``Net``) and optimizer.
model_args = {
    'use_gpu': bool(torch.cuda.is_available()),
    'dropout': _config_value('dropout', float, 0),
    'optimizer_name': config_args['optimizer_name'],
    'network_type': 'segresnet',
}

In [None]:
# Local epochs: either derived from a target number of local updates
# ('num_upgrades_local_run') and the local training-set size, or read as-is.
if 'num_upgrades_local_run' in config_args:
    local_dim = get_training_size(config_args['used_datasets'])
    target_updates = int(config_args['num_upgrades_local_run'])
    training_args['epochs'] = int(target_updates * training_args['batch_size'] / local_dim) + 1
elif 'epochs' in config_args:
    training_args['epochs'] = int(config_args['epochs'])

# Scaffold has its own aggregator; every other configuration uses FedAverage,
# FedProx additionally passing its proximal coefficient to the nodes.
aggregator = Scaffold(server_lr=1) if configuration == 'scaffold' else FedAverage()
if configuration == 'FedProx':
    training_args['fedprox_mu'] = float(config_args['fedprox_mu'])

if config_args['optimizer_name'] == 'sgd':
    training_args['optimizer_args']['momentum'] = float(config_args.get('momentum', 0.))

num_rounds = get_num_rounds(configuration, config_args)
config_args['rounds'] = num_rounds

In [None]:
# Build the federated experiment. Set tensorboard / save_breakpoints to True
# to enable them.
exp = Experiment(
    tags=tags,
    model_args=model_args,
    training_plan_class=UNetTrainingPlan,
    training_args=training_args,
    round_limit=num_rounds,
    aggregator=aggregator,
    tensorboard=False,
    save_breakpoints=False,
)

# SegResNet starts from pre-trained weights, loaded on CPU and installed in
# the researcher-side training plan.
if model_args['network_type'] == 'segresnet':
    logger.info('loading pre-trained model')
    pretrained_path = f"{PROJECT_DIR}/pretraining_params.pt"
    # torch.load unpickles the file: only use trusted local checkpoints.
    model_params = torch.load(pretrained_path, map_location=torch.device('cpu'))
    exp.training_plan().set_model_params(model_params)
    # NOTE(review): private Experiment attribute; presumably refreshes the
    # parameters the job will send to the nodes. TODO confirm.
    exp._job.update_parameters()

In [None]:
# Server-side declearn optimizer modules for the adaptive configurations.
# For any other configuration, set_agg_optimizer is not called.
_agg_optimizer_modules = {
    'FedYogi': YogiModule,
    'FedAdagrad': AdaGradModule,
    'FedAdam': AdamModule,
}
if configuration in _agg_optimizer_modules:
    server_module = _agg_optimizer_modules[configuration]()
    exp.set_agg_optimizer(Optimizer(lr=.9, modules=[server_module]))

In [None]:
# Run the experiment and record its wall-clock duration in seconds.
_rome_tz = pytz.timezone('Europe/Rome')
dt1 = datetime.now(tz=_rome_tz)
exp.run()
dt2 = datetime.now(tz=_rome_tz)
elapsed = dt2 - dt1
config_args['training_time'] = elapsed.total_seconds()

In [None]:
# The notebook never defines ``model``: take the torch module held by the
# researcher-side training plan and load the parameters aggregated at round
# ``num_rounds - 1`` (assumed to be the last round; TODO confirm that
# aggregated_params() is keyed by 0-based round index).
model = exp.training_plan().model()
model.load_state_dict(exp.aggregated_params()[num_rounds - 1]['params'])
config_args['n_parameters'] = sum(param.numel() for param in model.parameters())
config_args['n_trainable_parameters'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Saving model...")

# 'dp' is a config.ini string. Parse it like ConfigParser.getboolean does
# instead of eval()-ing configuration text.
use_dp = str(config_args.get('dp', 'False')).strip().lower() in ('true', '1', 'yes', 'on')
results_folder = os.path.join(config_args['results_folder'],
                              f"{id_reference}_dp" if use_dp else id_reference)

folder_name = ""
output_dir = os.path.join(results_folder, folder_name)

# Start from an empty output directory. The model is written into this same
# directory below, so it must be the one (re)created here.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

# ``Net`` stores the MONAI network in its ``net`` attribute. Only that
# sub-module's weights are saved, under the file name 'unet'.
torch.save(model.net.state_dict(), os.path.join(output_dir, 'unet'))