In [None]:
from fedbiomed.researcher.strategies.default_strategy import DefaultStrategy
from fedbiomed.common.constants import ErrorNumbers
from fedbiomed.common.exceptions import FedbiomedStrategyError
from fedbiomed.common.logger import logger


class MedicalFolderStrategy(DefaultStrategy):
    """Strategy that requires every sampled node to answer and succeed.

    On top of the checks, `refine` weights each node by its share of the
    total number of demographics rows found in the federated dataset.
    """

    def __init__(self, data, modalities=None):
        """Build the strategy.

        Args:
            data: forwarded unchanged to `DefaultStrategy.__init__`.
            modalities: modality names stored on the instance.
                Defaults to ``['T1']`` when omitted.
        """
        super().__init__(data)
        # Build the default per instance: a list literal used as default value
        # would be one single object shared (and mutable) across all instances.
        self._modalities = ['T1'] if modalities is None else modalities

    def refine(self, training_replies, round_i):
        """Validate the replies of a round and compute per-node weights.

        Args:
            training_replies: replies of the round; must expose `.data()`
                (entries carrying a 'node_id') and be iterable over entries
                carrying 'success', 'node_id' and 'params'.
            round_i: index of the current round.

        Returns:
            A tuple ``(models_params, weights)`` where `models_params` is a
            list of ``{node_id: params}`` dicts (one per successful reply) and
            `weights` is a list of ``{node_id: fraction}`` dicts (one per node
            of the federated dataset).

        Raises:
            FedbiomedStrategyError: if at least one sampled node did not
                answer, or if at least one node reported a failed training.
        """
        models_params = []

        # Sample once so that the loop and the count check below are
        # guaranteed to look at the same set of nodes.
        sampled_nodes = self.sample_nodes(round_i)

        # check that all nodes answered
        answered_nodes = [reply['node_id'] for reply in training_replies.data()]

        answers_count = 0
        for node_id in sampled_nodes:
            if node_id in answered_nodes:
                answers_count += 1
            else:
                # this node did not answer
                logger.error(f'{ErrorNumbers.FB408.value} (node = {node_id})')

        if answers_count != len(sampled_nodes):
            # FB407 when none of the nodes answered, FB408 when only some did
            msg = ErrorNumbers.FB407.value if answers_count == 0 else ErrorNumbers.FB408.value
            logger.critical(msg)
            raise FedbiomedStrategyError(msg)

        # check that all nodes that answer could successfully train
        # NOTE(review): `_success_node_history` is not created here; presumably
        # it is initialized by `DefaultStrategy` - confirm.
        self._success_node_history[round_i] = []
        all_success = True
        for tr in training_replies:
            if tr['success'] is True:
                models_params.append({tr['node_id']: tr['params']})
                self._success_node_history[round_i].append(tr['node_id'])
            else:
                # node did not succeed
                all_success = False
                logger.error(f'{ErrorNumbers.FB409.value} (node = {tr["node_id"]})')

        if not all_success:
            raise FedbiomedStrategyError(ErrorNumbers.FB402.value)

        # so far, everything is OK: weight every node by its share of the
        # demographics rows (first dimension of the 'demographics' shape of
        # the first dataset entry of each node).
        rows_per_node = {
            node_id: entries[0]["shape"]['demographics'][0]
            for node_id, entries in self._fds.data().items()
        }
        total_rows = sum(rows_per_node.values())
        weights = [{node_id: rows / total_rows} for node_id, rows in rows_per_node.items()]

        logger.info('Nodes that successfully reply in round ' +
                    str(round_i) + ' ' +
                    str(self._success_node_history[round_i]))
        return models_params, weights

In [None]:
import numpy as np
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager, MedicalFolderDataset
import torch
from torch import nn
from torch.optim import AdamW, SGD
from monai.losses.dice import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import Compose, SpatialPad, CenterSpatialCrop, NormalizeIntensity, AddChannel, AsDiscrete, Lambda
import torch.nn.functional as F


class UNetTrainingPlan(TorchTrainingPlan):
    """Training plan wrapping a 3D MONAI UNet trained with a Dice loss."""

    def init_model(self, model_args):
        """Instantiate the network from the model arguments."""
        return self.Net(model_args)

    def init_optimizer(self, optimizer_args):
        """Build the optimizer selected by the 'optimizer_name' model argument.

        Args:
            optimizer_args: keyword arguments forwarded to the optimizer.

        Returns:
            An `AdamW` instance for 'adam' (the default when the key is
            missing) or a `torch.optim.SGD` instance for 'sgd'.

        Raises:
            ValueError: if 'optimizer_name' is neither 'adam' nor 'sgd'.
        """
        net_args = self.model().model_arguments
        optimizer_name = net_args.get('optimizer_name', 'adam')
        if optimizer_name == 'adam':
            # NOTE: the 'adam' name maps to AdamW, not to plain Adam.
            return AdamW(self.model().parameters(), **optimizer_args)
        if optimizer_name == 'sgd':
            return torch.optim.SGD(self.model().parameters(), **optimizer_args)
        # Fail loudly: a sentinel such as -1 is not an optimizer and would only
        # produce a confusing error later, far from the actual cause.
        raise ValueError(
            f"Unsupported optimizer_name {optimizer_name!r}: expected 'adam' or 'sgd'")

    def init_dependencies(self):
        """Return the import statements needed by this training plan."""
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = [
            "from monai.transforms import Compose, SpatialPad, CenterSpatialCrop, NormalizeIntensity, "
            "AddChannel, Resize, AsDiscrete, Lambda",
            "import torch.nn as nn",
            "import torch",
            "from monai.losses.dice import DiceLoss",
            "import torch.nn.functional as F",
            "from fedbiomed.common.data import MedicalFolderDataset",
            "import numpy as np",
            "from torch.optim import AdamW",
            "from monai.networks.nets import UNet"
        ]
        return deps

    class Net(nn.Module):
        """3D UNet (1 input channel, 2 output channels) followed by a softmax."""

        def __init__(self, model_args: dict = None):
            """Build the network.

            Args:
                model_args: model arguments; the keys 'num_res_units' and
                    'dropout' are required (a `KeyError` is raised otherwise).
                    The dict is kept on the instance as `model_arguments`.
            """
            super().__init__()
            # Avoid a mutable default value shared between instances.
            if model_args is None:
                model_args = {}
            self.CHANNELS_DIMENSION = 1
            self.unet = UNet(
                spatial_dims=3,
                in_channels=1,
                out_channels=2,
                channels=(16, 32, 64, 128, 256),
                strides=(2, 2, 2, 2),
                num_res_units=model_args['num_res_units'],
                norm="batch",
                dropout=model_args['dropout'],
            )
            self.model_arguments = model_args

        def forward(self, x):
            """Return the softmax over the channel dimension of the UNet output."""
            # Call the module rather than `.forward` so that registered hooks run.
            x = self.unet(x)
            x = F.softmax(x, dim=self.CHANNELS_DIMENSION)
            return x

    @staticmethod
    def get_dice_loss(output, target):
        """Compute the Dice loss between `output` and `target`.

        The background channel is excluded and no sigmoid is applied by the
        loss (the network already outputs a softmax).
        """
        loss = DiceLoss(include_background=False, sigmoid=False)
        return loss(output, target)

    @staticmethod
    def demographics_transform(demographics):
        """Convert a demographics mapping into a numpy array of floats.

        An empty dict is returned unchanged; otherwise every value of the
        mapping is converted with `float`, in iteration order.
        """
        if isinstance(demographics, dict) and len(demographics) == 0:
            # when input is empty dict, we don't want to transform anything
            return demographics

        # every value is kept and converted to float (keys are dropped)
        return np.array([float(val) for val in demographics.values()])

    def training_data(self, batch_size=4):
        """Create the `DataManager` used for training.

        Images and labels are given a channel axis, center-cropped and padded
        to a common shape; images are intensity-normalized, labels are
        binarized (non-zero -> 1) and one-hot encoded on 2 classes.
        """
        # The training_data creates the Dataloader to be used for training in the general class Torchnn of fedbiomed
        common_shape = (320, 320, 16)

        training_transform = Compose([AddChannel(), CenterSpatialCrop(common_shape),
                                      SpatialPad(common_shape), NormalizeIntensity()])
        target_transform = Compose([AddChannel(), CenterSpatialCrop(common_shape), SpatialPad(common_shape),
                                    Lambda(func=lambda x: torch.where(x != 0, 1, 0)),
                                    AsDiscrete(to_onehot=2)
                                    ])

        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities='image',
            target_modalities='label',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=UNetTrainingPlan.demographics_transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset, **train_kwargs)

    def training_step(self, data, target):
        """Return the loss to backpropagate for one batch.

        Args:
            data: batch input; its first element is a mapping holding the
                'image' modality.
            target: mapping holding the 'label' modality.
        """
        # this function must return the loss to backward it
        img = data[0]['image']
        target = target['label']
        output = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(output, target)
        avg_loss = loss.mean()
        return avg_loss

    def testing_step(self, data, target):
        """Return the Dice loss of the model on one batch (same inputs as `training_step`)."""
        img = data[0]['image']
        target = target['label']
        prediction = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(prediction, target)
        avg_loss = loss.mean()  # average per batch
        return avg_loss


In [None]:
def get_num_rounds(configuration: str, config_args: dict) -> int:
    """Compute the number of rounds to run for a configuration.

    Args:
        configuration: name of the experiment configuration.
        config_args: configuration values; the optional keys 'num_updates'
            (default 20), 'num_epochs_local' (default 100), 'num_clients'
            (default 4) and 'batch_size' (default 8) are converted with `int`.

    Returns:
        1 for the 'LocalNode' and 'centralized' configurations (matched
        case-insensitively), otherwise
        ``(dim_training_set // num_clients // batch_size) * num_epochs_local // num_updates``
        where `dim_training_set` is 210 for 3 clients and 230 otherwise.
    """
    # Configuration names are spelled with varying capitalization (for example
    # 'Centralized'); compare case-insensitively so that such a spelling does
    # not silently fall through to the federated formula.
    if configuration.lower() in ('localnode', 'centralized'):
        return 1

    num_updates = int(config_args.get('num_updates', 20))
    num_epochs_local = int(config_args.get('num_epochs_local', 100))
    num_clients = int(config_args.get('num_clients', 4))
    batch_size = int(config_args.get('batch_size', 8))
    dim_training_set = 210 if num_clients == 3 else 230

    batches_per_client = dim_training_set // num_clients // batch_size
    return batches_per_client * num_epochs_local // num_updates


def get_training_size(dataset: str):
    """Return the hard-coded size registered for `dataset`.

    Raises:
        KeyError: if `dataset` is not one of the known dataset names.
    """
    sizes_by_dataset = dict(
        decathlon=32,
        promise_no_coil=23,
        promise_coil=27,
        prostatex_skyra=174,
    )
    return sizes_by_dataset[dataset]


In [None]:
# Identifier of this run; it is used later to name the results folder.
id_reference = "kfold_0"

# Section of the configuration file to load.
# Other values: FedAvg, FedProx, Yogi, FedAdam, FedAdagrad, Centralized
configuration = "LocalNode"
configuration_file = "config.ini"

config = configparser.ConfigParser()
config.read(configuration_file)

# Flatten the selected section into a plain dict and record the section name in it.
config_args = {**dict(config.items(configuration)), 'configuration': configuration}

In [None]:
# Arguments handed to the model; 'num_res_units' and 'optimizer_name' are
# mandatory configuration entries, 'dropout' falls back to 0.
model_args = {'use_gpu': True}
model_args['dropout'] = float(config_args['dropout']) if 'dropout' in config_args else 0
model_args['num_res_units'] = int(config_args['num_res_units'])
model_args['optimizer_name'] = config_args['optimizer_name']

# NOTE(review): `training_args` is not created in this cell; it must already
# exist (with 'batch_size' and 'optimizer_args') when this code runs.
# Number of local epochs: derived from a target number of local updates when
# 'num_upgrades_local_run' is configured, otherwise read from 'epochs' if present.
if 'num_upgrades_local_run' in config_args:
    local_dim = get_training_size(config_args['used_datasets'])
    training_args['epochs'] = int(
        int(config_args['num_upgrades_local_run']) * training_args['batch_size'] / local_dim
    )
elif 'epochs' in config_args:
    training_args['epochs'] = int(config_args['epochs'])

# Scaffold has its own aggregator; every other configuration uses FedAverage.
aggregator = Scaffold(server_lr=1) if configuration == 'scaffold' else FedAverage()
if configuration == 'FedProx':
    training_args['fedprox_mu'] = float(config_args['fedprox_mu'])

# SGD additionally takes a momentum (0. when not configured).
if config_args['optimizer_name'] == 'sgd':
    training_args['optimizer_args']['momentum'] = float(config_args.get('momentum', 0.))

num_rounds = get_num_rounds(configuration, config_args)
config_args['rounds'] = num_rounds

# NOTE(review): `eval` executes arbitrary code read from the configuration
# file; only use with trusted configuration files.
if 'dp' in config_args and eval(config_args['dp']) is True:
    sigma, clip = 4., 1.
    LDP = {'dp_args': {'type': 'local', 'sigma': sigma, 'clip': clip}}
    # NOTE(review): 'dp_args' is stored in `model_args` - confirm that this is
    # where the framework expects the differential privacy arguments.
    model_args['dp_args'] = LDP['dp_args']

In [None]:
# Build the experiment from the arguments prepared above.
exp = Experiment(
    tags=tags,
    model_args=model_args,
    training_plan_class=UNetTrainingPlan,
    training_args=training_args,
    round_limit=num_rounds,
    aggregator=aggregator,
    tensorboard=False,  # or True
    save_breakpoints=False,  # or True
)

# Adaptive configurations additionally get an aggregator-level optimizer;
# all other configurations keep the experiment as built.
if configuration == 'Yogi':
    _agg_module = YogiModule()
elif configuration == 'FedAdagrad':
    _agg_module = AdaGradModule()
elif configuration == 'FedAdam':
    _agg_module = AdamModule()
else:
    _agg_module = None

if _agg_module is not None:
    exp.set_agg_optimizer(Optimizer(lr=.9, modules=[_agg_module]))

In [None]:
# Run the experiment and record its wall-clock duration in seconds.
dt1 = datetime.now(tz=pytz.timezone('Europe/Rome'))
exp.run()
dt2 = datetime.now(tz=pytz.timezone('Europe/Rome'))
config_args['training_time'] = (dt2 - dt1).total_seconds()
config_args['exp.training_plan_file()'] = exp.training_plan_file(display=False)

# Retrieve the model: fall back to the training plan's model when the
# experiment has no `model_instance` attribute.
try:
    model = exp.model_instance()
except AttributeError:
    model = exp.training_plan().model()
except Exception as err:
    # Stop here with the real cause: merely printing a message would leave
    # `model` undefined and fail with an unrelated NameError on the next line.
    raise RuntimeError('Model not found') from err

# Load the parameters aggregated at the last round and count the parameters.
model.load_state_dict(exp.aggregated_params()[num_rounds - 1]['params'])
config_args['n_parameters'] = sum(param.numel() for param in model.parameters())
config_args['n_trainable_parameters'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Saving model...")

# NOTE(review): `eval` executes arbitrary code read from the configuration
# file; only use with trusted configuration files.
if "dp" in config_args and eval(config_args['dp']) is True:
    results_folder = f"{config_args['results_folder']}/{id_reference}_dp"
else:
    results_folder = f"{config_args['results_folder']}/{id_reference}"

folder_name = ""

# Reset the folder the model is actually written to. It is derived from
# `results_folder` so that the "_dp" variant gets its own directory and the
# results of the non-DP run are not wiped.
output_dir = os.path.join(results_folder, folder_name)
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

torch.save(model.unet.state_dict(), os.path.join(output_dir, 'unet'))
