## Install dependencies

If you haven't done so already, install the additional dependencies required by the flamby datasets/features that you intend on using. 

You may check out which dependencies are needed by each dataset directly from Flamby's `setup.py` [file](https://github.com/owkin/FLamby/blob/main/setup.py#L42). In our case we'll be using the federated IXI and federated heart disease datasets, hence we'll need wget, monai and nibabel. 

You need to download the FLamby dataset that we will use. For licensing reasons, these are not included directly in the FLamby installation.

To download the fed_heart dataset in `${FEDBIOMED_DIR}/data` (where `${FEDBIOMED_DIR}` is the base directory of Fed-BioMed): 

1. `source ${FEDBIOMED_DIR}/scripts/fedbiomed_environment node`<br>
2. `pip install wget`<br>
3. `python ${FEDBIOMED_DIR}docs/tutorials/sec-agg/fed-ixi/download_fead_ixi.py --output-folder ${FEDBIOMED_DIR}/data`

In [1]:
! pip install wget nibabel  # monai comes already packaged within fed-biomed



In [2]:
from fedbiomed.common.training_plans import TorchTrainingPlan
from flamby.datasets.fed_ixi import Baseline, BaselineLoss, Optimizer
from fedbiomed.common.data import FlambyDataset, DataManager
from torch.utils.data import DataLoader
import numpy as np

import os
FEDBIOMED_DIR = os.getenv('FEDBIOMED_DIR')


from flamby.datasets.fed_ixi import FedIXITiny
DATASET_TEST_PATH = f"{FEDBIOMED_DIR}/data"


class MyTrainingPlan(TorchTrainingPlan):
    def init_model(self, model_args):
        return Baseline()

    def init_optimizer(self, optimizer_args):
        return Optimizer(self.model().parameters(), lr=optimizer_args["lr"])

    def init_dependencies(self):
        return ["from flamby.datasets.fed_ixi import Baseline, BaselineLoss, Optimizer",
                "from fedbiomed.common.data import FlambyDataset, DataManager"]

    def training_step(self, data, target):
        output = self.model().forward(data)
        return BaselineLoss().forward(output, target)

    def training_data(self):
        dataset = FlambyDataset()
        loader_arguments = { 'shuffle': True}
        return DataManager(dataset, **loader_arguments)

In [3]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
count_parameters(Baseline())

246156

In [4]:
batch_size = 2
num_updates = 10
num_rounds = 75

In [5]:
model_args = {}

training_args = {
    'loader_args': { 'batch_size': batch_size, },
    'optimizer_args': {
        "lr" : 1e-3
    },
    'num_updates': num_updates,
    'dry_run': False,
    'batch_maxnum': 2, # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
    'random_seed':42,
    'log_interval': 10,
    'share_persistent_buffers': False,
}

In [6]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage
from fedbiomed.researcher.secagg import SecureAggregation

tags =  ['flixi']

In [7]:
exp_no_sec_agg = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan,
                 training_args=training_args,
                 model_args=model_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage())
exp_no_sec_agg.set_retain_full_history(True)

2024-05-27 15:35:03,550 fedbiomed INFO - Starting researcher service...

2024-05-27 15:35:03,552 fedbiomed INFO - Waiting 3s for nodes to connect...

2024-05-27 15:35:03,846 fedbiomed DEBUG - Node: NODE_7849b47e-12fc-4062-a85a-706aa967e395 polling for the tasks

2024-05-27 15:35:04,730 fedbiomed DEBUG - Node: NODE_1ab3470a-b961-4de6-a454-237155e67b4c polling for the tasks

2024-05-27 15:35:04,731 fedbiomed DEBUG - Node: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 polling for the tasks

2024-05-27 15:35:06,559 fedbiomed INFO - Updating training data. This action will update FederatedDataset, and the nodes that will participate to the experiment.

2024-05-27 15:35:06,572 fedbiomed DEBUG - Node: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 polling for the tasks

2024-05-27 15:35:06,573 fedbiomed DEBUG - Node: NODE_7849b47e-12fc-4062-a85a-706aa967e395 polling for the tasks

2024-05-27 15:35:06,574 fedbiomed DEBUG - Node: NODE_1ab3470a-b961-4de6-a454-237155e67b4c polling for the tasks

2024-05-27 15:35:06,579 fedbiomed INFO - Node selected for training -> NODE_7849b47e-12fc-4062-a85a-706aa967e395

2024-05-27 15:35:06,580 fedbiomed INFO - Node selected for training -> NODE_1ab3470a-b961-4de6-a454-237155e67b4c

2024-05-27 15:35:06,581 fedbiomed INFO - Node selected for training -> NODE_efae9061-c6e3-4c90-87eb-ba8775edd565

2024-05-27 15:35:06,584 fedbiomed DEBUG - Model file has been saved: /workspaces/Projects/fedbiomed/var/experiments/Experiment_0082/model_dc67a4ff-2565-4902-84a5-f64aef8378f8.py

Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_mode`` turned on.


True

In [None]:
exp_no_sec_agg.run()

2024-05-27 15:35:06,608 fedbiomed INFO - Sampled nodes in round 0 ['NODE_7849b47e-12fc-4062-a85a-706aa967e395', 'NODE_1ab3470a-b961-4de6-a454-237155e67b4c', 'NODE_efae9061-c6e3-4c90-87eb-ba8775edd565']

2024-05-27 15:35:06,611 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:35:06,611 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:35:06,613 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:35:06,636 fedbiomed DEBUG - Node: NODE_7849b47e-12fc-4062-a85a-706aa967e395 polling for the tasks

2024-05-27 15:35:06,640 fedbiomed DEBUG - Node: NODE_1ab3470a-b961-4de6-a454-237155e67b4c polling for the tasks

2024-05-27 15:35:06,645 fedbiomed DEBUG - Node: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 polling for the tasks

					[1m NODE[0m NODE_efae9061-c6e3-4c90-87eb-ba8775edd565
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

					[1m NODE[0m NODE_1ab3470a-b961-4de6-a454-237155e67b4c
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

					[1m NODE[0m NODE_7849b47e-12fc-4062-a85a-706aa967e395
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

2024-05-27 15:35:13,409 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					 Round 1 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.558466[0m 
					 ---------

2024-05-27 15:35:14,205 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					 Round 1 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.544607[0m 
					 ---------

2024-05-27 15:35:16,279 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					 Round 1 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.540522[0m 
					 ---------

2024-05-27 15:36:10,361 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					 Round 1 | Iteration: 10/10 (100%) | Samples: 20/20
 					 Loss: [1m0.246937[0m 
					 ---------

2024-05-27 15:36:12,715 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					 Round 1 | Iteration: 10/10 (100%) | Samples: 20/20
 					 Loss: [1m0.255100[0m 
					 ---------

2024-05-27 15:36:12,944 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					 Round 1 | Iteration: 10/10 (100%) | Samples: 20/20
 					 Loss: [1m0.254484[0m 
					 ---------

2024-05-27 15:36:12,990 fedbiomed INFO - Nodes that successfully reply in round 0 ['NODE_7849b47e-12fc-4062-a85a-706aa967e395', 'NODE_1ab3470a-b961-4de6-a454-237155e67b4c', 'NODE_efae9061-c6e3-4c90-87eb-ba8775edd565']

2024-05-27 15:36:13,004 fedbiomed INFO - Sampled nodes in round 1 ['NODE_7849b47e-12fc-4062-a85a-706aa967e395', 'NODE_1ab3470a-b961-4de6-a454-237155e67b4c', 'NODE_efae9061-c6e3-4c90-87eb-ba8775edd565']

2024-05-27 15:36:13,007 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:36:13,008 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:36:13,009 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:36:13,038 fedbiomed DEBUG - Node: NODE_7849b47e-12fc-4062-a85a-706aa967e395 polling for the tasks

2024-05-27 15:36:13,044 fedbiomed DEBUG - Node: NODE_1ab3470a-b961-4de6-a454-237155e67b4c polling for the tasks

2024-05-27 15:36:13,048 fedbiomed DEBUG - Node: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 polling for the tasks

					[1m NODE[0m NODE_7849b47e-12fc-4062-a85a-706aa967e395
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

					[1m NODE[0m NODE_efae9061-c6e3-4c90-87eb-ba8775edd565
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

					[1m NODE[0m NODE_1ab3470a-b961-4de6-a454-237155e67b4c
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

2024-05-27 15:36:21,421 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					 Round 2 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.261658[0m 
					 ---------

2024-05-27 15:36:21,935 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					 Round 2 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.256751[0m 
					 ---------

2024-05-27 15:36:23,564 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					 Round 2 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.256497[0m 
					 ---------

2024-05-27 15:37:16,250 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					 Round 2 | Iteration: 10/10 (100%) | Samples: 20/20
 					 Loss: [1m0.227862[0m 
					 ---------

2024-05-27 15:37:19,306 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					 Round 2 | Iteration: 10/10 (100%) | Samples: 20/20
 					 Loss: [1m0.239552[0m 
					 ---------

2024-05-27 15:37:19,404 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					 Round 2 | Iteration: 10/10 (100%) | Samples: 20/20
 					 Loss: [1m0.238835[0m 
					 ---------

2024-05-27 15:37:19,465 fedbiomed INFO - Nodes that successfully reply in round 1 ['NODE_7849b47e-12fc-4062-a85a-706aa967e395', 'NODE_1ab3470a-b961-4de6-a454-237155e67b4c', 'NODE_efae9061-c6e3-4c90-87eb-ba8775edd565']

2024-05-27 15:37:19,477 fedbiomed INFO - Sampled nodes in round 2 ['NODE_7849b47e-12fc-4062-a85a-706aa967e395', 'NODE_1ab3470a-b961-4de6-a454-237155e67b4c', 'NODE_efae9061-c6e3-4c90-87eb-ba8775edd565']

2024-05-27 15:37:19,479 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:37:19,480 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:37:19,481 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-05-27 15:37:19,512 fedbiomed DEBUG - Node: NODE_7849b47e-12fc-4062-a85a-706aa967e395 polling for the tasks

2024-05-27 15:37:19,516 fedbiomed DEBUG - Node: NODE_1ab3470a-b961-4de6-a454-237155e67b4c polling for the tasks

2024-05-27 15:37:19,520 fedbiomed DEBUG - Node: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 polling for the tasks

					[1m NODE[0m NODE_7849b47e-12fc-4062-a85a-706aa967e395
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

					[1m NODE[0m NODE_efae9061-c6e3-4c90-87eb-ba8775edd565
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

					[1m NODE[0m NODE_1ab3470a-b961-4de6-a454-237155e67b4c
					[1m MESSAGE:[0m Both batch_maxnum and num_updates specified. batch_maxnum will be ignored.[0m
-----------------------------------------------------------------

2024-05-27 15:37:28,294 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_7849b47e-12fc-4062-a85a-706aa967e395 
					 Round 3 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.240779[0m 
					 ---------

2024-05-27 15:37:28,839 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_1ab3470a-b961-4de6-a454-237155e67b4c 
					 Round 3 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.233208[0m 
					 ---------

2024-05-27 15:37:29,897 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_efae9061-c6e3-4c90-87eb-ba8775edd565 
					 Round 3 | Iteration: 1/10 (10%) | Samples: 2/20
 					 Loss: [1m0.242702[0m 
					 ---------

In [None]:
def print_training_rounds(exp, num_rounds):
    # List the training rounds
    print("\nList the training rounds:", exp.training_replies().keys())

    # Iterate over each training round and print details for each node
    print("\nList the nodes for the last training round and their timings:")
    for rnd in range(num_rounds):
        round_data = exp.training_replies()[rnd]
        for r in round_data.values():
            print("\t- {id} :\
            \n\t\trtime_training={rtraining:.2f} seconds\
            \n\t\tptime_training={ptraining:.2f} seconds\
            \n\t\ttime_encrypt={time_encrypt:.2f} seconds\
            \n\t\tptxt_model_size={ptxt_model_size:.2f} MB\
            \n\t\tctxt_model_size={ctxt_model_size:.2f} MB\
            \n\t\trtime_total={rtotal:.2f} seconds".format(
                id=r['node_id'],
                rtraining=r['timing']['rtime_training'],
                ptraining=r['timing']['ptime_training'],
                time_encrypt=r['timing']['time_encrypt'],
                ptxt_model_size=r['communication']['ptxt_model_size'],
                ctxt_model_size=r['communication']['ctxt_model_size'],
                rtotal=r['timing']['rtime_total']
            ))
        print('\n')

In [None]:
print_training_rounds(exp_no_sec_agg, num_rounds=40)

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader

def metric(y_true, y_pred):
    """
    Soft Dice coefficient
    """
    SPATIAL_DIMENSIONS = (2, 3, 4)
    intersection = (y_pred * y_true).sum(axis=SPATIAL_DIMENSIONS)
    union = (0.5 * (y_pred + y_true)).sum(axis=SPATIAL_DIMENSIONS)
    dice = intersection / (union + 1.0e-7)
    # If both inputs are empty the dice coefficient should be equal to 1
    dice[union == 0] = 1
    return np.mean(dice)

def test_soft_dice(net, test_loader):
    """Test the network: measure Soft Dice coefficient on the test set."""

    # Freeze normalization layers
    net.eval()

    all_y_pred = []
    all_targets = []

    # Iterate over the batches
    for data, target in test_loader:
        # Accumulate the ground truth labels
        all_targets.append(target.numpy())

        output = net(data).detach().numpy()
        all_y_pred.append(output)

    # Convert lists to numpy arrays for metric computation
    all_y_pred = np.concatenate(all_y_pred, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    # Compute the Soft Dice coefficient
    dice_coefficient = metric(all_targets, all_y_pred)
    print(f"Test Soft Dice coefficient: {dice_coefficient:.4f}")

In [None]:
test_dataset = FedIXITiny(center=0,pooled=True, train=False, data_path=DATASET_TEST_PATH)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
def evaluate_model(exp, num_rounds, test_dataloader):
    for r in range(num_rounds):
        fed_model = exp.training_plan().model()
        fed_model.load_state_dict(exp.aggregated_params()[r]['params'])
        test_soft_dice(fed_model, test_dataloader)



In [None]:
evaluate_model(exp_no_sec_agg, num_rounds, test_dataloader)

In [None]:
# CLIPPING_RANGE: int = 10
# TARGET_RANGE: int = 2**20
# WEIGHT_RANGE: int = 2**8
# There is some layer which has 10 has value, then max num of samples <2**8    
exp_sec_agg_lom = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan,
                 training_args=training_args,
                 model_args=model_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage(),
                 secagg=SecureAggregation(active=True, scheme='flamingo'))
exp_sec_agg_lom.set_retain_full_history(True)

In [None]:
exp_sec_agg_lom.run()

In [None]:
print_training_rounds(exp_sec_agg_lom, num_rounds=num_rounds)

In [None]:
evaluate_model(exp=exp_sec_agg_lom, num_rounds=num_rounds, test_dataloader=test_dataloader)

In [None]:

exp_sec_agg_jls = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan,
                 training_args=training_args,
                 model_args=model_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage(),
                 secagg=SecureAggregation(active=True, scheme='jls'))
exp_sec_agg_jls.set_retain_full_history(True)

In [None]:
exp_sec_agg_jls.run()

In [None]:
print_training_rounds(exp_sec_agg_jls, num_rounds=num_rounds)

In [None]:
evaluate_model(exp=exp_sec_agg_jls, num_rounds=num_rounds, test_dataloader=test_dataloader)