# Transfer-learning tutorial using DenseNet-121 pre-trained model: example on MedNIST dataset


## Goal of this tutorial

This tutorial shows how to classify 2D images from the MedNIST dataset using a pre-trained PyTorch model.

The goal of this tutorial is to provide an example of transfer learning methods with Fed-BioMed for medical images classification.

### About the model

The model used is DenseNet-121 (“Densely Connected Convolutional Networks”), pre-trained on the ImageNet dataset and available as a PyTorch pre-trained model: [Densenet121](https://pytorch.org/vision/main/models/generated/torchvision.models.html). Here it is used to perform image classification on the MedNIST dataset, that is, to predict the class of `MedNIST` medical images.



### About MedNIST

MedNIST provides an artificial 2d classification dataset created by gathering different medical imaging datasets from TCIA, the RSNA Bone Age Challenge, and the NIH Chest X-ray dataset. The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) under the Creative Commons CC BY-SA 4.0 license.

The MedNIST dataset is downloaded from resources provided by the [MONAI](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz) project.

The MedNIST dataset contains 58954 images of size (3, 64, 64) split into 6 classes (10000 images per class, except BreastMRI, which has 8954 images): AbdomenCT, BreastMRI, CXR, ChestCT, Hand and HeadCT. It has the following structure:

    └── MedNIST/
        ├── AbdomenCT/
        ├── BreastMRI/
        ├── CXR/
        ├── ChestCT/
        ├── Hand/
        └── HeadCT/
   

## Transfer-learning
Transfer learning is a machine learning technique where a model trained on one task is repurposed or adapted for a second, related task. It starts from a neural network pre-trained on a large dataset; here, [Imagenet](https://www.image-net.org) was used to train the DenseNet model to classify a wide diversity of images.

The idea is that the knowledge gained from learning one task can help with learning another. Here, the knowledge of the DenseNet model trained on ImageNet is reused to classify medical images into 6 categories. This is particularly beneficial when the amount of labeled data for the target task is limited, as the pre-trained model has already learned useful features and representations from a large dataset.

Transfer learning is typically applied in one of two ways:

- (I) Feature Extraction: In this approach, the pre-trained model is used as a fixed feature extractor. The earlier layers of the neural network, which capture general features and patterns, are frozen, and only the later layers are replaced or retrained for the new task. 

- (II) Fine-tuning: In this approach, the pre-trained model is further trained or partially trained on the new task. This allows the model to adapt its learned representations to the specifics of the new task while retaining some of the knowledge gained from the original task.
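A minimal PyTorch sketch of the difference between the two approaches, using a hypothetical toy network in place of DenseNet (a small backbone followed by a linear head): feature extraction sets `requires_grad = False` on the backbone, while fine-tuning leaves every parameter trainable.

```python
import torch
import torch.nn as nn


def build_toy_model(num_classes: int = 6) -> nn.Sequential:
    """Hypothetical stand-in for a pre-trained network: backbone + classification head."""
    backbone = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
    )
    head = nn.Linear(8, num_classes)
    return nn.Sequential(backbone, head)


def count_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# (I) Feature extraction: freeze the backbone, only the head is trained.
feature_extractor = build_toy_model()
for param in feature_extractor[0].parameters():
    param.requires_grad = False

# (II) Fine-tuning: every parameter stays trainable.
fine_tuned = build_toy_model()

# The optimizer only needs the parameters that will actually be updated.
optimizer = torch.optim.Adam(
    [p for p in feature_extractor.parameters() if p.requires_grad], lr=1e-3
)

print(count_trainable(feature_extractor), count_trainable(fine_tuned))  # 54 278
```

The training plans in this tutorial pass all parameters to `optim.Adam`; frozen parameters never receive a gradient, so Adam simply skips them.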


In this example, each node loads a sampled subset of MedNIST (500 or 1000 images) to illustrate the effectiveness of transfer learning. The subset is a random selection of images with balanced classes, to avoid biasing the classifier.

We test the two approaches through two independent training plan experiments and compare them with `MyTrainingPlan0`, a baseline experiment run on a DenseNet model without pre-trained weights. To illustrate the effectiveness of these two methods, we load 500 images for the first experiment and 1000 images for the second: the more data you have, the more layers you can unfreeze for a transfer-learning task.

We focus on the loss value and the accuracy as metrics to evaluate the effectiveness of the transfer-learning methods.
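The sampling script `download_sample_of_mednist.py` is not reproduced in this tutorial. As a hedged illustration of what "random selection with balanced classes" means, the hypothetical helper `balanced_sample` below draws the same number of files from each class (spreading any remainder over the first classes) from a made-up file index:

```python
import random
from collections import Counter


def balanced_sample(files_by_class, n_total, seed=0):
    """Hypothetical helper: randomly pick n_total files, as evenly spread over classes as possible."""
    rng = random.Random(seed)
    base, extra = divmod(n_total, len(files_by_class))
    sample = []
    for i, (label, files) in enumerate(sorted(files_by_class.items())):
        k = base + (1 if i < extra else 0)  # the first `extra` classes get one more file
        if len(files) < k:
            raise ValueError(f"class {label!r} has only {len(files)} files, {k} requested")
        sample.extend((path, label) for path in rng.sample(files, k))
    rng.shuffle(sample)
    return sample


# Made-up index standing in for the six MedNIST class folders.
classes = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]
files_by_class = {c: [f"{c}/{i:06d}.jpeg" for i in range(100)] for c in classes}

sample = balanced_sample(files_by_class, n_total=500)
print(len(sample), Counter(label for _, label in sample))
```

With 500 images and 6 classes, two classes receive 84 images and the other four receive 83; the real script may handle the remainder differently.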

### 1. Load dataset or sampled dataset
- From the root directory of Fed-BioMed, run `source ./scripts/fedbiomed_environment node` to load the Node environment.
- If you have already run MedNIST nodes before, clean up the remaining MedNIST nodes: run `./scripts/fedbiomed_run node delete` or `source ./scripts/fedbiomed_environment clean`.
- In this new environment, run the Python script `python ./notebooks/transfer-learning/download_sample_of_mednist.py -n <number-of-nodes>`, where `<number-of-nodes>` is the number of Nodes you want to create (for more details about this script, run `python ./notebooks/transfer-learning/download_sample_of_mednist.py --help`).
- For each Node created, the script asks for the number of samples you want in its dataset. It then outputs a configuration file for each Node, with a configured database, using the naming convention `config_mednist_<i>_sampled.ini`, where `<i>` ranges from 1 to the `<number-of-nodes>` entered.
- Finally, launch your Nodes (one per terminal) by running `./scripts/fedbiomed_run node config config_mednist_<i>_sampled.ini start`, where `<i>` is the number of the Node. Wait until you see `Starting task manager`.

For example, to create 2 nodes (`<number-of-nodes>` equal to 2), run `python ./notebooks/transfer-learning/download_sample_of_mednist.py -n 2`. The script asks how many samples each dataset should contain (enter 500 and then 1000). Then launch, in separate terminals, `./scripts/fedbiomed_run node config config_mednist_1_sampled.ini start` and `./scripts/fedbiomed_run node config config_mednist_2_sampled.ini start`.



### 2. Launch the researcher 
- From the root directory of Fed-BioMed, run `./scripts/fedbiomed_run researcher start`.
- This opens the Jupyter notebook.

To make sure that the MedNIST dataset is loaded on the nodes, we can send a request to the network to list the datasets available on each node. The output should contain an entry for the MedNIST data.

 

In [1]:
from fedbiomed.researcher.requests import Requests
req  = Requests()
req.list()

2024-02-23 10:15:18,383 fedbiomed INFO - Starting researcher service...

2024-02-23 10:15:18,384 fedbiomed INFO - Waiting 3s for nodes to connect...

2024-02-23 10:15:18,853 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:15:18,904 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:15:21,396 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:15:21,398 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

{'config_mednist_2_sampled': [{'name': 'MedNIST_2_sampled',
   'data_type': 'mednist',
   'tags': ['#MEDNIST', '#dataset'],
   'description': 'MedNIST dataset for transfer learning',
   'shape': [500, 3, 64, 64],
   'dataset_id': 'dataset_1cf8c145-1dea-4867-935a-12d182be5c6d',
   'dataset_parameters': None}],
 'config_mednist_1_sampled': [{'name': 'MedNIST_1_sampled',
   'data_type': 'mednist',
   'tags': ['#MEDNIST', '#dataset'],
   'description': 'MedNIST dataset for transfer learning',
   'shape': [1000, 3, 64, 64],
   'dataset_id': 'dataset_c320ef69-75fa-4f3a-ae4a-7fbfee36d2b3',
   'dataset_parameters': None}]}
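The listing above is a dictionary keyed by node, each value being a list of dataset descriptions. Assuming `req.list()` returns that dictionary, a small hypothetical helper can report which nodes hold a dataset with a given tag and how many samples each one has. Here it runs on a copy of the relevant fields from the output above:

```python
def datasets_with_tag(listing, tag):
    """Hypothetical helper: {node: [(dataset name, number of samples)]} for datasets carrying `tag`."""
    found = {}
    for node, entries in listing.items():
        matches = [(entry["name"], entry["shape"][0]) for entry in entries if tag in entry["tags"]]
        if matches:
            found[node] = matches
    return found


# Relevant fields copied from the listing printed above.
listing = {
    "config_mednist_2_sampled": [
        {"name": "MedNIST_2_sampled", "tags": ["#MEDNIST", "#dataset"], "shape": [500, 3, 64, 64]}
    ],
    "config_mednist_1_sampled": [
        {"name": "MedNIST_1_sampled", "tags": ["#MEDNIST", "#dataset"], "shape": [1000, 3, 64, 64]}
    ],
}

print(datasets_with_tag(listing, "#MEDNIST"))
```

Calling `datasets_with_tag(req.list(), "#MEDNIST")` would then summarize the live network in the same way.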

## Import of libraries

In [2]:
import torch
import torch.nn as nn
from fedbiomed.common.training_plans import TorchTrainingPlan

from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage


## Run an experiment for image classification without transfer learning

As a first experiment, we run `MyTrainingPlan0` with a DenseNet model that is not pre-trained. We will then compare its loss value with those of the two experiments that use transfer-learning methods.

Since the pre-trained weights are not used here, it is important to adapt the learning rate. We suggest starting with `lr=1e-4` and then adjusting the learning rate according to the evaluation metric.

In [12]:
class MyTrainingPlan0(TorchTrainingPlan):

    def init_model(self, model_args):
       
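        # Build the DenseNet-121 architecture WITHOUT pre-trained weights (baseline).
        # `weights=None` replaces the deprecated `pretrained=False` (torchvision >= 0.13).
        model = models.densenet121(weights=None)

        # Freeze every layer of the feature extractor except its last module (norm5);
        # the classification layer is replaced below.
        for param in model.features[:-1].parameters():
            param.requires_grad = False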
            
        # add the classifier 
        num_classes = model_args['num_classes'] 
        num_ftrs = model.classifier.in_features
        model.classifier= nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
      
        return model


    def init_dependencies(self):
        return [
            "from torchvision import datasets, transforms, models",
            "import torch.optim as optim",
            "from torchvision.models import densenet121"
        ]


    def init_optimizer(self, optimizer_args):        
        return optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])

    
    # training data
    
    def training_data(self):
        

        # Preprocess images (uncomment the lines below to enable data augmentation)
        preprocess = transforms.Compose([
                transforms.Resize((224,224)),  
                #transforms.RandomHorizontalFlip(p=0.5),
                #transforms.RandomVerticalFlip(p=0.5),
                #transforms.RandomRotation(30),
                #transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
           ])
    
        train_data = datasets.ImageFolder(self.dataset_path,transform = preprocess)
        train_kwargs = { 'shuffle': True}
        return DataManager(dataset=train_data, **train_kwargs)

    def training_step(self, data, target):
        output = self.model().forward(data)
        loss_func = nn.CrossEntropyLoss()
        loss   = loss_func(output, target)
        return loss




In [13]:
training_args = {
    'loader_args': { 'batch_size': 32, }, 
    'optimizer_args': {'lr': 1e-4}, 
    'epochs': 10, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

model_args = {
    'num_classes': 6 # adapt this number to the number of classes in your dataset
}
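With these arguments, the number of iterations per epoch follows from the dataset size, the test ratio of 0.1 set on the experiment, and the batch size. A small sketch, assuming the test samples are set aside by rounding down and that `batch_maxnum` caps the number of batches per epoch (as the comment above suggests), reproduces the values logged for the 1000-image node: 900 training samples, 100 test samples and 29 iterations per epoch.

```python
import math


def epoch_plan(n_samples, test_ratio, batch_size, batch_maxnum):
    """Return (train samples, test samples, iterations per epoch) under the stated assumptions."""
    n_test = math.floor(n_samples * test_ratio)
    n_train = n_samples - n_test
    n_batches = math.ceil(n_train / batch_size)  # the last batch may be partial
    return n_train, n_test, min(n_batches, batch_maxnum)


print(epoch_plan(1000, 0.1, 32, 100))  # (900, 100, 29)
print(epoch_plan(500, 0.1, 32, 100))   # (450, 50, 15)
```

The last of the 29 batches holds only 900 - 28 * 32 = 4 samples, so its loss is noisier; `batch_maxnum=100` does not limit training here because 29 < 100.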

In [14]:
tags =  ['#MEDNIST', '#dataset']

rounds = 1  # adjust the number of rounds

exp = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan0,
                 model_args=model_args,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage())

# testing section 
from fedbiomed.common.metrics import MetricTypes
exp.set_test_ratio(.1) 
exp.set_test_on_local_updates(True)
exp.set_test_metric(MetricTypes.ACCURACY)

exp.set_tensorboard(True)

2024-02-23 10:27:19,742 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:27:19,745 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:27:19,748 fedbiomed INFO - Node selected for training -> config_mednist_2_sampled

2024-02-23 10:27:19,749 fedbiomed INFO - Node selected for training -> config_mednist_1_sampled

2024-02-23 10:27:19,752 fedbiomed DEBUG - Model file has been saved: /home/ebirgy/development/fedbiomed_github/fedbiomed/var/experiments/Experiment_0001/model_807f7176-ff1a-4096-ad0c-05a0f2458faf.py

Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_mode`` turned on.
The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.


2024-02-23 10:27:19,866 fedbiomed DEBUG - using native torch optimizer

2024-02-23 10:27:19,867 fedbiomed INFO - Removing tensorboard logs from previous experiment

2024-02-23 10:27:19,870 fedbiomed DEBUG - Experimentation training_args updated for `job`

2024-02-23 10:27:19,871 fedbiomed DEBUG - Experimentation training_args updated for `job`

2024-02-23 10:27:19,872 fedbiomed DEBUG - Experimentation training_args updated for `job`

True

In [15]:
exp.set_nodes(['config_mednist_1_sampled'])
exp.set_tags(['#MEDNIST', '#dataset'])
td = exp.training_data().data()
td.pop('config_mednist_2_sampled')
exp.set_training_data(td)

print(exp.training_data().data())

2024-02-23 10:27:24,391 fedbiomed DEBUG - Experimentation nodes filter changed, you may need to update `training_data`

2024-02-23 10:27:24,394 fedbiomed DEBUG - Experimentation tags changed, you may need to update `training_data`

2024-02-23 10:27:24,396 fedbiomed DEBUG - Training data changed, you may need to update `node_selection_strategy`

2024-02-23 10:27:24,398 fedbiomed DEBUG - Training data changed, you may need to update `job`

2024-02-23 10:27:24,399 fedbiomed DEBUG - Training data changed, you may need to update `aggregator`

{'config_mednist_1_sampled': {'name': 'MedNIST_1_sampled', 'data_type': 'mednist', 'tags': ['#MEDNIST', '#dataset'], 'description': 'MedNIST dataset for transfer learning', 'shape': [1000, 3, 64, 64], 'dataset_id': 'dataset_c320ef69-75fa-4f3a-ae4a-7fbfee36d2b3', 'dtypes': [], 'dataset_parameters': None}}


In [16]:
exp.run()

2024-02-23 10:27:25,510 fedbiomed INFO - Sampled nodes in round 0 ['config_mednist_1_sampled']

2024-02-23 10:27:25,519 fedbiomed INFO - Sending request
					 To: config_mednist_1_sampled
					 Request: TRAIN
 -----------------------------------------------------------------

2024-02-23 10:27:25,694 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:27:27,595 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 1 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 1.858775
					 ---------

2024-02-23 10:27:44,717 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 1 | Iteration: 10/29 (34%) | Samples: 320/928
					 Loss: 1.529183
					 ---------

2024-02-23 10:27:48,045 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:28:01,469 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 1 | Iteration: 20/29 (69%) | Samples: 640/928
					 Loss: 1.347260
					 ---------

2024-02-23 10:28:10,431 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:28:14,589 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 1 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 1.503261
					 ---------

2024-02-23 10:28:16,181 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 2 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 1.191252
					 ---------

2024-02-23 10:28:34,206 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 2 | Iteration: 11/29 (38%) | Samples: 352/928
					 Loss: 1.024164
					 ---------

2024-02-23 10:28:48,042 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:28:51,394 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 2 | Iteration: 21/29 (72%) | Samples: 672/928
					 Loss: 0.940107
					 ---------

2024-02-23 10:28:52,288 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:29:09,618 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 2 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 1.138714
					 ---------

2024-02-23 10:29:10,431 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:29:12,277 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 3 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.908263
					 ---------

2024-02-23 10:29:15,030 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 3 | Iteration: 2/29 (7%) | Samples: 64/928
					 Loss: 0.863130
					 ---------

2024-02-23 10:29:19,466 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:29:42,661 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 3 | Iteration: 12/29 (41%) | Samples: 384/928
					 Loss: 0.755644
					 ---------

2024-02-23 10:29:52,289 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:30:06,284 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:30:10,398 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 3 | Iteration: 22/29 (76%) | Samples: 704/928
					 Loss: 0.765437
					 ---------

2024-02-23 10:30:19,465 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:30:27,133 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 3 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 1.673585
					 ---------

2024-02-23 10:30:29,844 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 4 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.830858
					 ---------

2024-02-23 10:30:30,098 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:30:35,404 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 4 | Iteration: 3/29 (10%) | Samples: 96/928
					 Loss: 0.788744
					 ---------

2024-02-23 10:31:03,547 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 4 | Iteration: 13/29 (45%) | Samples: 416/928
					 Loss: 0.732760
					 ---------

2024-02-23 10:31:06,283 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:31:19,355 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:31:30,097 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:31:31,941 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 4 | Iteration: 23/29 (79%) | Samples: 736/928
					 Loss: 0.590424
					 ---------

2024-02-23 10:31:42,257 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:31:46,597 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 4 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 1.045113
					 ---------

2024-02-23 10:31:49,336 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 5 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.697756
					 ---------

2024-02-23 10:31:57,823 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 5 | Iteration: 4/29 (14%) | Samples: 128/928
					 Loss: 0.576436
					 ---------

2024-02-23 10:32:19,356 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:32:25,645 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 5 | Iteration: 14/29 (48%) | Samples: 448/928
					 Loss: 0.533323
					 ---------

2024-02-23 10:32:32,244 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:32:42,254 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:32:50,144 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:32:53,753 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 5 | Iteration: 24/29 (83%) | Samples: 768/928
					 Loss: 0.578839
					 ---------

2024-02-23 10:33:04,671 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 5 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 0.632976
					 ---------

2024-02-23 10:33:07,265 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 6 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.545928
					 ---------

2024-02-23 10:33:16,127 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 6 | Iteration: 5/29 (17%) | Samples: 160/928
					 Loss: 0.582466
					 ---------

2024-02-23 10:33:32,244 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:33:36,328 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 6 | Iteration: 15/29 (52%) | Samples: 480/928
					 Loss: 0.397152
					 ---------

2024-02-23 10:33:50,142 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:33:54,842 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 6 | Iteration: 25/29 (86%) | Samples: 800/928
					 Loss: 0.405308
					 ---------

2024-02-23 10:34:02,265 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:34:02,714 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 6 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 0.398396
					 ---------

2024-02-23 10:34:05,023 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 7 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.398975
					 ---------

2024-02-23 10:34:15,406 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 7 | Iteration: 6/29 (21%) | Samples: 192/928
					 Loss: 0.417327
					 ---------

2024-02-23 10:34:32,244 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:34:35,498 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 7 | Iteration: 16/29 (55%) | Samples: 512/928
					 Loss: 0.495801
					 ---------

2024-02-23 10:34:53,077 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 7 | Iteration: 26/29 (90%) | Samples: 832/928
					 Loss: 0.588867
					 ---------

2024-02-23 10:34:56,801 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 7 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 1.078944
					 ---------

2024-02-23 10:34:58,596 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 8 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.385518
					 ---------

2024-02-23 10:35:02,263 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:35:09,060 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 8 | Iteration: 7/29 (24%) | Samples: 224/928
					 Loss: 0.557657
					 ---------

2024-02-23 10:35:14,587 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:35:31,288 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 8 | Iteration: 17/29 (59%) | Samples: 544/928
					 Loss: 0.348339
					 ---------

2024-02-23 10:35:32,243 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:35:43,755 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:35:49,813 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 8 | Iteration: 27/29 (93%) | Samples: 864/928
					 Loss: 0.481828
					 ---------

2024-02-23 10:35:51,975 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 8 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 0.958883
					 ---------

2024-02-23 10:35:54,145 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 9 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.515115
					 ---------

2024-02-23 10:36:09,534 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 9 | Iteration: 8/29 (28%) | Samples: 256/928
					 Loss: 0.489626
					 ---------

2024-02-23 10:36:14,581 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:36:27,108 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 9 | Iteration: 18/29 (62%) | Samples: 576/928
					 Loss: 0.335865
					 ---------

2024-02-23 10:36:27,914 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:36:43,753 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:36:44,417 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 9 | Iteration: 28/29 (97%) | Samples: 896/928
					 Loss: 0.342275
					 ---------

2024-02-23 10:36:44,600 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 9 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 0.355475
					 ---------

2024-02-23 10:36:46,389 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 10 | Iteration: 1/29 (3%) | Samples: 32/928
					 Loss: 0.568909
					 ---------

2024-02-23 10:36:56,976 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:37:00,563 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 10 | Iteration: 9/29 (31%) | Samples: 288/928
					 Loss: 0.270554
					 ---------

2024-02-23 10:37:18,047 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 10 | Iteration: 19/29 (66%) | Samples: 608/928
					 Loss: 0.407931
					 ---------

2024-02-23 10:37:27,912 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:37:34,834 fedbiomed INFO - TRAINING
					 NODE_ID: config_mednist_1_sampled
					 Round 1 Epoch: 10 | Iteration: 29/29 (100%) | Samples: 900/900
					 Loss: 0.664922
					 ---------

2024-02-23 10:37:37,489 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:37:40,669 fedbiomed INFO - VALIDATION ON LOCAL UPDATES
					 NODE_ID: config_mednist_1_sampled
					 Round 1 | Iteration: 1/1 (100%) | Samples: 100/100
					 ACCURACY: 0.890000
					 ---------

2024-02-23 10:37:40,860 fedbiomed INFO - Nodes that successfully reply in round 0 ['config_mednist_1_sampled']

2024-02-23 10:37:40,966 fedbiomed INFO - Saved aggregated params for round 0 in /home/ebirgy/development/fedbiomed_github/fedbiomed/var/experiments/Experiment_0001/aggregated_params_55040fb3-54cb-4fdf-b354-9b26b73627f3.mpk

1

2024-02-23 10:37:56,976 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

2024-02-23 10:38:03,360 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

### Results without pre-trained weights

After 1 round of 10 epochs on the node holding 1000 images (`config_mednist_1_sampled`), without loading the pre-trained weights, the validation on local updates reaches an accuracy of 0.89 on 100 test samples.

In [17]:
# Save the trained model
exp.training_plan().export_model('./training_plan0_densenet_MedNIST')

In [18]:
from fedbiomed.researcher.environ import environ
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']

In [19]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
%tensorboard --logdir "$tensorboard_dir"

The loss decreases rapidly during the first iterations and then stops decreasing, levelling off at roughly 0.4.
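The losses are only logged every few iterations, so any trend read from them is approximate. A hypothetical helper that averages the logged losses of each epoch can make the plateau easier to see; it strips ANSI color codes first, since the raw console log may contain them. Here it runs on the epoch 1 and epoch 10 entries copied from the log above:

```python
import re
from collections import defaultdict
from statistics import mean

ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")      # terminal color codes
EPOCH_RE = re.compile(r"Epoch: (\d+) \|")
LOSS_RE = re.compile(r"Loss: ([0-9.]+)")


def mean_loss_per_epoch(log_text):
    """Hypothetical helper: average the logged training losses of each epoch."""
    losses = defaultdict(list)
    epoch = None
    for line in log_text.splitlines():
        line = ANSI_RE.sub("", line)
        if (m := EPOCH_RE.search(line)):
            epoch = int(m.group(1))
        elif (m := LOSS_RE.search(line)) and epoch is not None:
            losses[epoch].append(float(m.group(1)))
    return {e: round(mean(v), 3) for e, v in sorted(losses.items())}


log_excerpt = """
Round 1 Epoch: 1 | Iteration: 1/29 (3%) | Samples: 32/928
Loss: 1.858775
Round 1 Epoch: 1 | Iteration: 10/29 (34%) | Samples: 320/928
Loss: 1.529183
Round 1 Epoch: 1 | Iteration: 20/29 (69%) | Samples: 640/928
Loss: 1.347260
Round 1 Epoch: 1 | Iteration: 29/29 (100%) | Samples: 900/900
Loss: 1.503261
Round 1 Epoch: 10 | Iteration: 1/29 (3%) | Samples: 32/928
Loss: 0.568909
Round 1 Epoch: 10 | Iteration: 9/29 (31%) | Samples: 288/928
Loss: 0.270554
Round 1 Epoch: 10 | Iteration: 19/29 (66%) | Samples: 608/928
Loss: 0.407931
Round 1 Epoch: 10 | Iteration: 29/29 (100%) | Samples: 900/900
Loss: 0.664922
"""

print(mean_loss_per_epoch(log_excerpt))  # {1: 1.56, 10: 0.478}
```

The last iteration of each epoch is a small partial batch, which likely explains the occasional spikes (for example 1.67 at the end of epoch 3).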

## Run an experiment for image classification using transfer learning

### I- Adapt the last layer to your classification goal
The DenseNet-121 model pre-trained on ImageNet classifies images into 1000 classes. We adapt this classification task to the MedNIST dataset by replacing the last layer, `model.classifier`, with our own classifier. In the training plan, the new classifier outputs 6 classes; the number of classes is set through the `num_classes` value of the `model_args` argument.

### Data augmentation
You can perform data augmentation in the preprocessing step if needed. The training plan below includes random flips, rotation and crops as commented-out lines; as written, only `transforms.Resize`, `transforms.ToTensor` and `transforms.Normalize` are applied. Uncomment these lines to enable data augmentation.

### I-1. Define the training plan

In [3]:
class MyTrainingPlan1(TorchTrainingPlan):

    def init_model(self, model_args):
       
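        # Load DenseNet-121 with its ImageNet pre-trained weights. The weights enum
        # replaces the deprecated `pretrained=True` (torchvision >= 0.13).
        model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

        # Freeze every layer of the feature extractor except its last module (norm5),
        # so only norm5 and the new classifier below are trained (feature extraction).
        for param in model.features[:-1].parameters():
            param.requires_grad = False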
            
        # add the classifier 
        num_classes = model_args['num_classes'] 
        num_ftrs = model.classifier.in_features
        model.classifier= nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
      
        return model

    def init_dependencies(self):
        return [
            "from torchvision import datasets, transforms, models",
            "import torch.optim as optim",
            "from torchvision.models import densenet121"
        ]


    def init_optimizer(self, optimizer_args):        
        return optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])

    
    # training data
    
    def training_data(self):
        

        # Preprocess images (uncomment the lines below to enable data augmentation)
        preprocess = transforms.Compose([
                transforms.Resize((224,224)),  
                #transforms.RandomHorizontalFlip(p=0.5),
                #transforms.RandomVerticalFlip(p=0.5),
                #transforms.RandomRotation(30),
                #transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
           ])
    
        train_data = datasets.ImageFolder(self.dataset_path,transform = preprocess)
        train_kwargs = { 'shuffle': True}
        return DataManager(dataset=train_data, **train_kwargs)

    def training_step(self, data, target):
        output = self.model().forward(data)
        loss_func = nn.CrossEntropyLoss()
        loss   = loss_func(output, target)
        return loss




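### Downloading the pre-trained model's weights
As an alternative way to import the model, its weights can be downloaded through `torch.hub` and saved to a file `pretrained_model.pt` with the commands below. They are commented out, since `MyTrainingPlan1` already loads the ImageNet weights through torchvision.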

In [13]:
#model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True)
#torch.save(model.state_dict(), 'pretrained_model.pt')

Using cache found in /user/ebirgy/home/.cache/torch/hub/pytorch_vision_v0.10.0
The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=DenseNet121_Weights.IMAGENET1K_V1`. You can also use `weights=DenseNet121_Weights.DEFAULT` to get the most up-to-date weights.


In [4]:
training_args = {
    'loader_args': { 'batch_size': 32, }, 
    'optimizer_args': {'lr': 1e-3}, 
    'epochs': 5, 
    'dry_run': False,  
    'batch_maxnum': 100  # fast pass for development: only use (batch_maxnum * batch_size) samples
}

model_args = {
    'num_classes': 6 # adapt this number to the number of classes in your dataset
}

In [5]:
tags =  ['#MEDNIST', '#dataset']

rounds = 1  # adjust the number of rounds

exp = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan1,
                 model_args=model_args,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage())

# testing section 
from fedbiomed.common.metrics import MetricTypes
exp.set_test_ratio(.1) 
exp.set_test_on_local_updates(True)
exp.set_test_metric(MetricTypes.ACCURACY)

exp.set_tensorboard(True)

2024-02-23 10:45:08,526 fedbiomed INFO - Starting researcher service...

2024-02-23 10:45:08,535 fedbiomed INFO - Waiting 3s for nodes to connect...

2024-02-23 10:45:09,640 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:45:11,549 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:45:11,554 fedbiomed INFO - Node selected for training -> config_mednist_1_sampled

2024-02-23 10:45:11,561 fedbiomed DEBUG - Model file has been saved: /home/ebirgy/development/fedbiomed_github/fedbiomed/var/experiments/Experiment_0004/model_705f2409-4190-4f68-a120-ce71ddedc961.py

Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_mode`` turned on.
The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=DenseNet121_Weights.IMAGENET1K_V1`. You can also use `weights=DenseNet121_Weights.DEFAULT` to get the most up-to-date weights.


2024-02-23 10:45:11,809 fedbiomed DEBUG - using native torch optimizer

2024-02-23 10:45:11,809 fedbiomed INFO - Removing tensorboard logs from previous experiment

2024-02-23 10:45:11,811 fedbiomed DEBUG - Experimentation training_args updated for `job`

2024-02-23 10:45:11,811 fedbiomed DEBUG - Experimentation training_args updated for `job`

2024-02-23 10:45:11,812 fedbiomed DEBUG - Experimentation training_args updated for `job`

True

### I - 2. Define the dataset for your experiment 

We propose to run this first experiment on a single node holding a sub-sampled MedNIST dataset of 500 images (`MedNIST_sampled_1`), because this first method does not fine-tune the pretrained backbone: only the new classifier head and the last normalization layer (`norm5`) are trained.
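To see how `batch_size`, `batch_maxnum` and the test ratio interact on this 500-image dataset, here is a small arithmetic sketch. It assumes that the test split holds `round(test_ratio * n)` samples and that `batch_maxnum` caps the number of batches processed per epoch; the helper `iterations_per_epoch` is hypothetical, not part of Fed-BioMed. With 500 images, a test ratio of 0.1 and batches of 32, it gives 450 training samples and 15 iterations per epoch, consistent with the `Iteration: 15/15 | Samples: 450/450` lines in this notebook's training logs, so the cap of 100 batches is never reached here.

```python
import math


def iterations_per_epoch(n_samples, test_ratio, batch_size, batch_maxnum):
    """Hypothetical helper: training samples and batches processed per epoch.

    Assumes the test split holds round(test_ratio * n_samples) samples and
    that batch_maxnum caps the number of batches in each epoch.
    """
    n_test = round(test_ratio * n_samples)
    n_train = n_samples - n_test
    n_batches = math.ceil(n_train / batch_size)  # the last batch may be partial
    return n_train, min(n_batches, batch_maxnum)


print(iterations_per_epoch(500, 0.1, 32, 100))  # sub-sampled node of part I
print(iterations_per_epoch(500, 0.1, 32, 10))   # a smaller cap would truncate each epoch
```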

Here we show how to select one dataset among the connected datasets:

In [6]:

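# select the node that holds the sub-sampled MedNIST dataset
selected_node = 'config_mednist_2_sampled'
exp.set_nodes([selected_node])
exp.set_tags(['#MEDNIST', '#dataset'])

# keep only the training data of the selected node
td = exp.training_data().data()
td = {node_id: node_datasets for node_id, node_datasets in td.items() if node_id == selected_node}
exp.set_training_data(td)

print(exp.training_data().data())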

2024-02-23 10:45:11,949 fedbiomed DEBUG - Experimentation nodes filter changed, you may need to update `training_data`

2024-02-23 10:45:11,954 fedbiomed DEBUG - Experimentation tags changed, you may need to update `training_data`

2024-02-23 10:45:11,957 fedbiomed DEBUG - Training data changed, you may need to update `node_selection_strategy`

2024-02-23 10:45:11,960 fedbiomed DEBUG - Training data changed, you may need to update `job`

2024-02-23 10:45:11,961 fedbiomed DEBUG - Training data changed, you may need to update `aggregator`

{}
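The training data of the experiment is returned as a mapping from node id to datasets, and the printed `{}` means that no node is left to train on; running the experiment in that state is what produces the `FB401` aggregation error further down. A minimal pure-Python sketch of selecting one node's entry and failing early when the selection is empty, before calling `exp.run()`. The `select_node` helper and the placeholder metadata are hypothetical, not Fed-BioMed objects.

```python
def select_node(training_data, node_id):
    """Hypothetical helper: keep only the dataset entry of `node_id`.

    Raises instead of returning an empty selection, which would leave the
    experiment without any node to train on.
    """
    selected = {nid: ds for nid, ds in training_data.items() if nid == node_id}
    if not selected:
        raise ValueError(
            f"node {node_id!r} holds no dataset matching the tags; "
            f"available nodes: {sorted(training_data)}"
        )
    return selected


# Placeholder metadata: only node 1 is connected, as in the logs above
example_td = {"config_mednist_1_sampled": ["<dataset metadata placeholder>"]}

print(select_node(example_td, "config_mednist_1_sampled"))
try:
    select_node(example_td, "config_mednist_2_sampled")
except ValueError as err:
    print("selection failed:", err)
```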


In [30]:
exp.training_plan().import_model('pretrained_model.pt')





2024-02-23 10:41:27,248 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: config_mednist_2_sampled 
					 Round 1 Epoch: 4 | Iteration: 15/15 (100%) | Samples: 450/450
 					 Loss: [1m1.143529[0m 
					 ---------

2024-02-23 10:41:29,110 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: config_mednist_2_sampled 
					 Round 1 Epoch: 5 | Iteration: 1/15 (7%) | Samples: 32/480
 					 Loss: [1m0.105942[0m 
					 ---------

2024-02-23 10:41:32,636 fedbiomed DEBUG - Node: config_mednist_2_sampled polling for the tasks

### I - 3. Run your experiment 

In [7]:
exp.run()

2024-02-23 10:45:21,880 fedbiomed INFO - Sampled nodes in round 0 []

2024-02-23 10:45:21,891 fedbiomed INFO - Nodes that successfully reply in round 0 []


--------------------
Fed-BioMed researcher stopped due to exception:
FB401: aggregation crashes or returns an error. Aggregation aborted due to sum of the weights is equal to 0 {}. Sample sizes received from nodes might be corrupted.
--------------------


FedbiomedSilentTerminationError: 

2024-02-23 10:46:11,552 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:46:23,974 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:47:23,971 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:47:30,189 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:48:30,184 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:48:36,625 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:49:36,628 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:50:36,625 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-23 10:50:47,717 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks
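The `FB401` message above comes from the federated averaging step: as the error text indicates, each node's contribution is weighted by the sample sizes it reports, so when no node replies the sum of the weights is 0 and the weighted average is undefined. A minimal pure-Python sketch of sample-size weighting on scalar "parameters"; `fed_average` is a hypothetical helper illustrating the idea, not Fed-BioMed's `FedAverage` implementation, and the node sizes are made-up example values.

```python
def fed_average(params_per_node, samples_per_node):
    """Hypothetical helper: sample-size-weighted average of per-node parameters.

    params_per_node: {node_id: {param_name: float}}
    samples_per_node: {node_id: int}
    """
    total = sum(samples_per_node.get(node, 0) for node in params_per_node)
    if total == 0:
        # Same situation as FB401: no samples received, nothing to average
        raise ZeroDivisionError("sum of the weights is equal to 0")
    averaged = {}
    for node, params in params_per_node.items():
        weight = samples_per_node.get(node, 0) / total
        for name, value in params.items():
            averaged[name] = averaged.get(name, 0.0) + weight * value
    return averaged


# Two nodes with unequal dataset sizes: the larger node pulls the average
print(fed_average({"node_a": {"w": 1.0}, "node_b": {"w": 3.0}},
                  {"node_a": 450, "node_b": 1350}))
```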

###### For example, at the end of the training experiment, I obtained:

                      INFO - VALIDATION ON LOCAL UPDATES 
					 NODE_ID: NODE_41cd99c8-3571-4ab3-958e-6357ce31e91b 
					 Round 2 | Iteration: 1/1 (100%) | Samples: 100/100
 					 ACCURACY: 0.980000
					 -

As you can see, the accuracy has increased in comparison to the first `Experiment`.

### I - 4. Save your model 
You can save your model to reuse it later in a new TrainingPlan.
The saved file lets you import the model including your layer modifications and weight values.

In [None]:
#save the model 
exp.training_plan().export_model('./training_plan1_densenet_MedNIST')

### I - 5. Results in tensorboard 

In [None]:
from fedbiomed.researcher.environ import environ
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir "$tensorboard_dir"

## II - Partial fine-tuning: Use pretrained DenseNet and train specific layers with your data
You can use a second, larger dataset to run a second experiment that fine-tunes more layers of the network.

In this example, we run the second experiment with 1500 images (from both nodes).
As previously, the dataset is selected through the experiment tags, after the TrainingPlan is defined.

You can also import the model you saved in part I to start this second TrainingPlan experiment (see below).



In [None]:
from fedbiomed.researcher.requests import Requests
req  = Requests()
req.list()

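Here I leave the last three modules of `model.features` trainable (the third transition layer, the fourth dense block and the final normalization layer) instead of only the last one, since we have a bigger dataset than in the first part.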
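Slicing `model.features` selects modules by position, so it helps to list the children of torchvision's DenseNet-121 `features` block. The sketch below hard-codes those module names (taken from torchvision's standard DenseNet-121 layout; treat the list as an assumption if your torchvision version differs) and shows which modules each slice freezes: `[:-1]` in the first training plan leaves only `norm5` trainable, while `[:-3]` leaves `transition3`, `denseblock4` and `norm5` trainable. In both plans the new `classifier` stays trainable, because its layers are created after the freezing loop. With torchvision installed, the same list can be obtained with `[name for name, _ in model.features.named_children()]`.

```python
# Child modules of torchvision's densenet121().features, in order
# (assumption: torchvision's standard DenseNet-121 layout)
DENSENET121_FEATURES = [
    "conv0", "norm0", "relu0", "pool0",
    "denseblock1", "transition1",
    "denseblock2", "transition2",
    "denseblock3", "transition3",
    "denseblock4", "norm5",
]


def split_frozen(modules, n_trainable):
    """Mimic `features[:-n_trainable]`: return (frozen, trainable) module names."""
    return modules[:-n_trainable], modules[-n_trainable:]


for n in (1, 3):  # MyTrainingPlan1 uses [:-1], MyTrainingPlan2 uses [:-3]
    frozen, trainable = split_frozen(DENSENET121_FEATURES, n)
    print(f"features[:-{n}] freezes {len(frozen)} modules; trainable: {trainable}")
```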

In [None]:
from fedbiomed.common.training_plans import TorchTrainingPlan
class MyTrainingPlan2(TorchTrainingPlan):

    def init_model(self, model_args):

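        # Load the DenseNet-121 model pre-trained on ImageNet
        # (`weights=` replaces the deprecated `pretrained=True` since torchvision 0.13)
        model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

        # Freeze every feature module except the last three
        # (transition3, denseblock4 and norm5), which are fine-tuned
        for param in model.features[:-3].parameters():
            param.requires_grad = False

        # Replace the ImageNet classifier with a new head for `num_classes` classes
        num_ftrs = model.classifier.in_features
        num_classes = model_args['num_classes']
        model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )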
        
        return model

    def init_dependencies(self):
        return [
            "from torchvision import datasets, transforms, models",
            "import torch.optim as optim"
        ]


    def init_optimizer(self, optimizer_args):        
        return optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])

    def training_data(self):
        
        # Load MedNIST with torchvision's ImageFolder, resizing and normalizing images
        # for the ImageNet-pretrained DenseNet; uncomment the random transforms for data augmentation
        preprocess = transforms.Compose([
                transforms.Resize((224,224)),  
                #transforms.RandomHorizontalFlip(p=0.5),
                #transforms.RandomVerticalFlip(p=0.5),
                #transforms.RandomRotation(30),
                #transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
        ])
        train_data = datasets.ImageFolder(self.dataset_path, transform=preprocess)
        train_kwargs = { 'shuffle': True}
        return DataManager(dataset=train_data, **train_kwargs)



    def training_step(self, data, target):
        output = self.model().forward(data)
        loss_func = nn.CrossEntropyLoss()
        loss   = loss_func(output, target)
        return loss
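`nn.CrossEntropyLoss` applies a log-softmax to the logits and takes the negative log-probability of the target class (averaged over the batch with the default `reduction='mean'`). The pure-Python sketch below computes it for one sample with the 6 MedNIST classes; `cross_entropy` is a hypothetical helper, not the PyTorch implementation. With equal logits the loss is ln 6, about 1.79, which is a handy reference when reading the losses reported in the training logs: values clearly below it mean the head already does better than a uniform guess.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log softmax(logits)[target].

    Uses the log-sum-exp trick (subtracting the max logit) for numerical stability.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


num_classes = 6
uniform = cross_entropy([0.0] * num_classes, target=2)            # no preference between classes
confident = cross_entropy([0.0, 0.0, 5.0, 0.0, 0.0, 0.0], target=2)  # large logit on the true class
print(f"uniform logits:   {uniform:.4f}  (ln 6 = {math.log(num_classes):.4f})")
print(f"confident logits: {confident:.4f}")
```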




In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

training_args = {
    'loader_args': { 'batch_size': 32, }, 
    'optimizer_args': {'lr': 1e-4},  # you could decrease the learning rate
    'epochs': 5,  # you can increase the number of epochs (e.g. 10)
    'dry_run': False,
    'batch_maxnum': 100  # fast pass for development: only use (batch_maxnum * batch_size) samples
}
model_args={
    'num_classes': 6
}
tags =  ['#MEDNIST', '#dataset']
rounds = 2  # you can increase the number of rounds

exp = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan2,
                 model_args=model_args,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage())

from fedbiomed.common.metrics import MetricTypes
exp.set_test_ratio(.1)
exp.set_test_on_local_updates(True)
exp.set_test_metric(MetricTypes.ACCURACY)

exp.set_tensorboard(True)
    

### II - 1. (Optional) Import a "custom model" or continue with the original DenseNet model of the TrainingPlan 

In [None]:
exp.training_plan().import_model('./training_plan1_densenet_MedNIST') 

### II - 2. Run your experiment 

In [None]:
exp.run()

For example, at the end of the training experiment, I obtained:

                    fedbiomed INFO - VALIDATION ON LOCAL UPDATES 
					 NODE_ID: NODE_7842724a-cafa-49cc-862d-149288bbbb22 
					 Round 1 | Iteration: 1/1 (100%) | Samples: 100/100
 					 ACCURACY: 1.00000
					 ---------

In [None]:
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1]
for r in round_data.values():
    timing = r['timing']
    print(f"\t- {r['node_id']} :"
          f"\n\t\trtime_training={timing['rtime_training']:.2f} seconds"
          f"\n\t\tptime_training={timing['ptime_training']:.2f} seconds"
          f"\n\t\trtime_total={timing['rtime_total']:.2f} seconds")
print('\n')


### II - 3. Export your model

In [None]:
#save model 
exp.training_plan().export_model('./training_plan2_densenet_MedNIST')

### II - 4. Display losses on Tensorboard

In [None]:
%reload_ext tensorboard

In [None]:
%tensorboard --logdir "$tensorboard_dir"

### II - 5. Save and Import your model and parameters 

You could import your first model from TrainingPlan1 instead of loading the original DenseNet.
You can also save the whole trained model object and reload it later, or load the aggregated parameters into the model.

In [None]:
# save the whole model (all layers) from the training experiment
remote_model = exp.training_plan().model()  
torch.save(remote_model, './training_plan2_model')


In [None]:
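# import your model (a full model object: PyTorch 2.6+ defaults to weights_only=True, so disable it)
model = torch.load('./training_plan2_model', weights_only=False)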
model

### II - 6. Save model's features, parameters 

In [None]:
model_features = exp.training_plan().export_model('./training_plan2_model')
model_features

In [None]:
# reload the exported model from disk
model_features_ = torch.load('./training_plan2_model', weights_only=False)
model_features_

In [None]:
# alternatively, load the aggregated parameters of the last round into the model
remote_model = exp.training_plan().model()
remote_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])