In [None]:
%load_ext autoreload
%autoreload 2

### Fed-IXI

In [None]:
#from fedbiomed.common.data import DataManager
from fedbiomed.common.training_plans import TorchTrainingPlan
import torch.nn as nn
from torch.optim import AdamW
from unet import UNet

class UNetTrainingPlan(TorchTrainingPlan):
    """Training plan wrapping a `unet.UNet` segmentation network.

    The raw network output is converted to per-channel probabilities with a
    softmax over axis 1, and both the training and the testing step report the
    mean soft Dice loss of those probabilities against the target.
    """

    def __init__(self, model_args: dict = None):
        """Build the UNet, declare the required imports and create the optimizer.

        Args:
            model_args: optional overrides for the `UNet` constructor arguments
                (keys such as 'in_channels', 'out_classes', 'dimensions', ...).
                Any missing key falls back to the default given below.
                `None` is treated as an empty dict.
        """
        # Use a fresh dict per call: a `{}` default is a single object shared
        # by every instantiation of the class.
        if model_args is None:
            model_args = {}
        super(UNetTrainingPlan, self).__init__(model_args)

        # Axis of the network output along which the softmax is taken.
        self.CHANNELS_DIMENSION = 1

        self.unet = UNet(
            in_channels=model_args.get('in_channels', 1),
            out_classes=model_args.get('out_classes', 2),
            dimensions=model_args.get('dimensions', 2),
            num_encoding_blocks=model_args.get('num_encoding_blocks', 5),
            out_channels_first_layer=model_args.get('out_channels_first_layer', 64),
            normalization=model_args.get('normalization', None),
            pooling_type=model_args.get('pooling_type', 'max'),
            upsampling_type=model_args.get('upsampling_type', 'conv'),
            preactivation=model_args.get('preactivation', False),
            residual=model_args.get('residual', False),
            padding=model_args.get('padding', 0),
            padding_mode=model_args.get('padding_mode', 'zeros'),
            activation=model_args.get('activation', 'ReLU'),
            initial_dilation=model_args.get('initial_dilation', None),
            dropout=model_args.get('dropout', 0),
            monte_carlo_dropout=model_args.get('monte_carlo_dropout', 0),
        )

        # Import statements registered with the training plan.
        # NOTE(review): presumably these are the imports made available to the
        # class code wherever the plan is executed remotely -- confirm against
        # the Fed-BioMed documentation.
        deps = ['import torch.nn as nn',
                'import torch.nn.functional as F',
                'from torch.optim import AdamW',
                'from unet import UNet']
        self.add_dependency(deps)

        self.optimizer = AdamW(self.parameters())

    def forward(self, x):
        """Run the UNet on `x` and return a softmax over the channel axis."""
        logits = self.unet(x)
        # `nn.functional` instead of the alias `F`: `nn` is imported next to
        # this class, whereas `F` only appears inside the `deps` strings and is
        # therefore undefined when the class is used directly in the notebook.
        return nn.functional.softmax(logits, dim=self.CHANNELS_DIMENSION)

    @staticmethod
    def get_dice_loss(output, target, epsilon=1e-9):
        """Soft Dice loss computed separately for each entry of axes 0 and 1.

        Args:
            output: predicted probabilities.
            target: reference tensor, multiplied element-wise with `output`.
                # assumes both are laid out as (batch, channel, *spatial) with
                # matching shapes -- TODO confirm against the dataset
            epsilon: small constant added to the denominator so that it is
                never exactly zero.

        Returns:
            `1 - dice`, where every axis from index 2 onwards has been summed
            out.
        """
        # All axes after the first two are reduced, so 2D (4-D tensors) and 3D
        # (5-D tensors) inputs are both handled; for 5-D tensors this is the
        # (2, 3, 4) tuple the loss was originally hard-coded to.
        spatial_dims = tuple(range(2, output.dim()))
        p0 = output
        g0 = target
        p1 = 1 - p0
        g1 = 1 - g0
        tp = (p0 * g0).sum(dim=spatial_dims)
        fp = (p0 * g1).sum(dim=spatial_dims)
        fn = (p1 * g0).sum(dim=spatial_dims)
        num = 2 * tp
        denom = 2 * tp + fp + fn + epsilon
        dice_score = num / denom
        return 1. - dice_score

    def _mean_dice_loss(self, data, target):
        """Forward `data` and return the Dice loss averaged over all entries."""
        prediction = self.forward(data)
        return UNetTrainingPlan.get_dice_loss(prediction, target).mean()

    def training_step(self, data, target):
        """Return the scalar loss for one training batch (to be back-propagated)."""
        return self._mean_dice_loss(data, target)

    def testing_step(self, data, target):
        """Return the Dice loss averaged over one testing batch."""
        return self._mean_dice_loss(data, target)

In [None]:
# Settings forwarded to the experiment as `training_args`.
training_args = {
    'batch_size': 16,
    'lr': 0.001,
    'epochs': 1,
    'dry_run': False,
    'log_interval': 2,
    # Testing is switched off: no test split, no test after either kind of update.
    'test_ratio': 0.0,
    'test_on_global_updates': False,
    'test_on_local_updates': False,
}

# Overrides for the UNet constructor defaults read by UNetTrainingPlan;
# keys that are not listed here keep the defaults coded in the training plan.
model_args = {
    'in_channels': 1,
    'out_classes': 2,
    'dimensions': 3,
    'num_encoding_blocks': 3,
    'out_channels_first_layer': 8,
    'normalization': 'batch',
    'upsampling_type': 'linear',
    'padding': True,
    'activation': 'PReLU',
}



The *train_transform_flamby* key in **training_args** can optionally be used to perform extra transformations on a flamby dataset.

As a reminder, flamby datasets already apply a transformation internally through their dataloader (this internal transform
is the one officially used for the flamby benchmark). Thus, one should check what is already performed on the flamby side before
adding further transforms from the researcher side.

*train_transform_flamby* has to be defined as a list containing two elements, both given as strings:
- the first is the import statement(s) needed to perform the transformation
- the second is the Compose expression that will be passed as the transform parameter of the flamby dataset federated class

Example:
```python
training_args = {
    ...,
    'train_transform_flamby':["from monai.transforms import (Compose, NormalizeIntensity, Resize,)",
                         "Compose([Resize((48,60,48)), NormalizeIntensity()])"]
}
```

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

# One round for the Fed-IXI experiment.
# NOTE(review): `tags` is presumably what selects the datasets taking part --
# confirm against the Experiment documentation.
num_rounds = 1
tags = ['ixi']

exp = Experiment(
    tags=tags,
    model_args=model_args,
    model_class=UNetTrainingPlan,
    training_args=training_args,
    round_limit=num_rounds,
    aggregator=FedAverage(),
)

In [None]:
# Launch the Fed-IXI experiment built in the previous cell.
exp.run()

### Fed-Heart-Disease

In [None]:
import torch
import torch.nn as nn
from fedbiomed.common.training_plans import TorchTrainingPlan
from torch.optim import Adam

class FedHeartTrainingPlan(TorchTrainingPlan):
    """Training plan for a single linear layer followed by a sigmoid.

    The model is trained with binary cross-entropy on the sigmoid output.
    """

    def __init__(self, model_args: dict = None):
        """Declare the required imports, build the model and create the optimizer.

        Args:
            model_args: optional settings. The only key read here is
                'in_features' (width of the linear layer input, default 13).
                `None` is treated as an empty dict.
        """
        # Use a fresh dict per call: a `{}` default is a single object shared
        # by every instantiation of the class.
        if model_args is None:
            model_args = {}
        super(FedHeartTrainingPlan, self).__init__(model_args)

        # Import statements registered with the training plan. Every module
        # referenced by this class is listed, including `torch`, which
        # `forward` uses for the sigmoid.
        # NOTE(review): presumably these are the imports made available to the
        # class code wherever the plan is executed remotely -- confirm against
        # the Fed-BioMed documentation.
        deps = ['import torch',
                'import torch.nn as nn',
                'from torch.optim import Adam']
        self.add_dependency(deps)

        # NOTE(review): 13 is presumably the number of input features of the
        # Fed-Heart-Disease dataset -- confirm against the dataset.
        self.model = nn.Linear(model_args.get('in_features', 13), 1)
        self.optimizer = Adam(self.parameters())

    def forward(self, x):
        """Return the sigmoid of the linear layer output, i.e. values in (0, 1)."""
        return torch.sigmoid(self.model(x))

    def training_step(self, data, target):
        """Return the binary cross-entropy loss for one batch (to be back-propagated).

        # assumes `target` has the same shape as the model output and holds
        # values in [0, 1], as nn.BCELoss requires -- TODO confirm
        """
        output = self.forward(data)
        bce = nn.BCELoss()
        loss = bce(output, target)
        return loss

In [None]:
# Settings forwarded to the Fed-Heart-Disease experiment as `training_args`.
training_args = {
    # Optimisation settings.
    'batch_size': 16,
    'lr': 0.001,
    'epochs': 5,
    # Run control and logging.
    'dry_run': False,
    'log_interval': 2,
    # Testing is switched off: no test split, no test after either kind of update.
    'test_ratio': 0.0,
    'test_on_global_updates': False,
    'test_on_local_updates': False,
}

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

# One round for the Fed-Heart-Disease experiment; no `model_args` is passed,
# so the training plan is built with its own defaults.
# NOTE(review): `tags` is presumably what selects the datasets taking part --
# confirm against the Experiment documentation.
num_rounds = 1
tags = ['dd']

exp = Experiment(
    tags=tags,
    model_class=FedHeartTrainingPlan,
    training_args=training_args,
    round_limit=num_rounds,
    aggregator=FedAverage(),
)

In [None]:
# Launch the Fed-Heart-Disease experiment built in the previous cell.
exp.run()