# Transfer-learning tutorial using DenseNet-121 pre-trained model:
# example on MedNIST dataset

## Introduction

This tutorial shows how to perform 2D image classification on the MedNIST dataset using the pretrained PyTorch DenseNet-121 model (https://pytorch.org/vision/main/generated/torchvision.models.densenet121.html).

## Goal of this tutorial

The goal of this tutorial is to provide an example of transfer-learning methods with Fed-BioMed for medical image classification.

## About the model

The model used is DenseNet-121 (“Densely Connected Convolutional Networks”), pretrained on the ImageNet dataset. We use the pretrained PyTorch implementation (https://pytorch.org/vision/main/generated/torchvision.models.densenet121.html) to perform image classification on the MedNIST dataset.
The goal of this DenseNet-121 model is to predict the image modality class of a given medical image.



### About MedNIST

MedNIST provides an artificial 2d classification dataset created by gathering different medical imaging datasets from TCIA, the RSNA Bone Age Challenge, and the NIH Chest X-ray dataset. The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) under the Creative Commons CC BY-SA 4.0 license.

The MedNIST dataset is downloaded from the resources provided by the MONAI project: https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz

The MedNIST dataset has 58954 images of size (3, 64, 64) distributed across 6 classes (10000 images per class, except for the BreastMRI class, which has 8954 images). The classes are AbdomenCT, BreastMRI, CXR, ChestCT, Hand and HeadCT. It has the following structure:

    MedNIST/
    ├── AbdomenCT/
    ├── BreastMRI/
    ├── CXR/
    ├── ChestCT/
    ├── Hand/
    └── HeadCT/
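
To check a local copy against the class counts given above, a minimal sketch such as the following can count the files in each class folder. The function name `count_images_per_class` is hypothetical (it is not part of Fed-BioMed or MONAI), and the sketch assumes that every file inside a class folder is an image.

```python
from pathlib import Path


def count_images_per_class(root):
    """Return {class_name: number_of_files} for an ImageFolder-style directory.

    Hypothetical helper: assumes `root` contains one sub-folder per class
    and that every regular file inside a class folder is an image.
    """
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return counts
```

Calling it on the extracted MedNIST folder should return one entry per class, which can then be compared with the figures quoted above.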
   

### Transfer-learning
Transfer learning is a machine learning technique where a model trained on one task is repurposed or adapted for a second, related task. It starts from a neural network pre-trained on a large dataset: here, DenseNet was trained on ImageNet to classify a wide variety of images.

The idea is that the knowledge gained from learning one task can be useful for learning another task (as we do here, with the classification of medical images into 6 categories). This is particularly beneficial when the amount of labeled data for the target task is limited, as the pre-trained model has already learned useful features and representations from a large dataset.

Transfer learning is typically applied in one of two ways:

- Feature Extraction: In this approach, the pre-trained model is used as a fixed feature extractor. The earlier layers of the neural network, which capture general features and patterns, are frozen, and only the later layers are replaced or retrained for the new task. 

- Fine-tuning: In this approach, the pre-trained model is further trained or partially trained on the new task. This allows the model to adapt its learned representations to the specifics of the new task while retaining some of the knowledge gained from the original task.
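
In PyTorch, the difference between the two approaches comes down to which parameters keep `requires_grad=True`. The sketch below uses a small toy network (not DenseNet; the layer sizes and helper names are illustrative assumptions) so that the effect on the number of trainable parameters is easy to verify.

```python
import torch.nn as nn


def make_toy_model():
    # Toy stand-in for a pretrained network: a two-layer "backbone" and a 6-class "head".
    backbone = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 4), nn.ReLU())
    head = nn.Linear(4, 6)
    return nn.Sequential(backbone, head)


def count_trainable(model):
    # Number of scalar parameters the optimizer would still update.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Feature extraction: freeze the whole backbone, train only the head.
feature_extractor = make_toy_model()
for param in feature_extractor[0].parameters():
    param.requires_grad = False

# Partial fine-tuning: freeze only the first backbone layer.
fine_tuned = make_toy_model()
for param in fine_tuned[0][0].parameters():
    param.requires_grad = False

print(count_trainable(feature_extractor), count_trainable(fine_tuned))
```

The fine-tuned variant leaves more parameters trainable, which is why it usually needs more data to train well.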


In this example, we load a sampled subset of MedNIST (500 or 1000 images) on the node to illustrate the effectiveness of transfer learning. The subset is drawn at random with balanced classes, to avoid classification bias.
We will test these two approaches through two independent TrainingPlan experiments.
To illustrate the effectiveness of the two methods, we load 500 images for the first experiment and 1000 images for the second. Because fine-tuning trains more layers, it is better suited to larger datasets.
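
Balanced sampling can be sketched as drawing the same number of images from every class. The code below only illustrates the idea; it is not the code used by the sampling script, and the function name and the choice to round the per-class count down are assumptions.

```python
import random


def balanced_sample(paths_by_class, n_total, seed=0):
    """Draw the same number of items from each class, at random.

    Hypothetical helper: `paths_by_class` maps a class name to a list of
    image paths. The per-class count is n_total // number_of_classes, so the
    result can hold slightly fewer than n_total items.
    """
    rng = random.Random(seed)
    per_class = n_total // len(paths_by_class)
    sample = {}
    for name, paths in paths_by_class.items():
        # random.Random.sample draws without replacement
        sample[name] = rng.sample(paths, per_class)
    return sample
```

With 6 classes and `n_total=500`, this rule keeps 83 images per class (498 in total).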

## Set up the node

- From the fedbiomed folder, execute the command ./scripts/fedbiomed_run node add
- Select option 3 (mednist) to add MedNIST to the node
- Confirm the mednist tags ['#MEDNIST', '#dataset'] by typing "y" and pressing ENTER
- Select the folder where MedNIST is downloaded (it will be downloaded if it is not found in the selected path). The data should now be added; if you get a warning saying that data must be unique, it has already been added.
- Enter the number of samples you want to use in your experiment.
- Check that your data has been added by executing ./scripts/fedbiomed_run node list
- Start the node using ./scripts/fedbiomed_run node start. Wait until you get Starting task manager.


## Start Fed-BioMed Researcher

We are now ready to start the researcher environment with the command ./scripts/fedbiomed_run researcher start, and open the Jupyter notebook.

To make sure that the MedNIST dataset is loaded on the node, we can send a request to the network to list the datasets available on the node. The list command should output an entry for the mednist data.


In [5]:
from fedbiomed.researcher.requests import Requests
req  = Requests()
req.list()

{}
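
An empty dictionary, as in the output above, contains no dataset entry at all. To check a non-empty listing for the MedNIST tags, a small helper such as the one below can be used. This is a sketch under a stated assumption: it expects the listing to be a dict mapping a node id to a list of dataset entries, each a dict with a 'tags' list. The helper name is hypothetical, and the keys may need adapting to your Fed-BioMed version.

```python
def nodes_with_tags(listing, tags):
    """Return the ids of nodes that expose a dataset carrying all `tags`.

    Assumption: `listing` maps a node id to a list of dataset entries,
    each entry being a dict with a 'tags' list.
    """
    wanted = set(tags)
    return [
        node_id
        for node_id, datasets in listing.items()
        if any(wanted.issubset(entry.get('tags', [])) for entry in datasets)
    ]
```

For example, `nodes_with_tags(req.list(), ['#MEDNIST', '#dataset'])` would return an empty list for the output shown above.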

## Import libraries

In [2]:
import torch
import torch.nn as nn
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms, models
from torchvision.models import densenet121
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage


## Run an experiment

### 1. Load dataset or sampled dataset
- From the root directory of Fed-BioMed, run: source ./scripts/fedbiomed_environment node in order to load the Node environment
- If you have already run MedNIST nodes before, clean up the remaining MedNIST nodes: run ./scripts/fedbiomed_run node delete or source ./scripts/fedbiomed_environment clean
- In this new environment, run the Python script: python ./notebooks/transfer-learning/download_sample_of_mednist.py -n <number-of-nodes>, where <number-of-nodes> is the number of Nodes you want to create (for more details about this script, run notebooks/transfer-learning/download_sample_of_mednist.py --help)
- For each Node created, the script will ask for the number of samples you want in your dataset. It then outputs a configuration file for each Node, with its database configured.
- Finally, launch your Nodes (one per terminal) by running: ./scripts/fedbiomed_run node config config_mednist_<i>_sampled.ini start, where <i> is the index of the Node created. Wait until you get Starting task manager.

### 2. Launch the researcher 
- From the root directory of Fed-BioMed, run : ./scripts/fedbiomed_run researcher start

 

## Classification using Transfer-learning 

### Adapt the last layer to your classification goal
The pretrained DenseNet model classifies images into the 1000 ImageNet classes.
We adapt it to the MedNIST dataset by replacing the last layer with our own classifier.
The new model.classifier classifies images into 6 classes, set through the num_classes value.

### Data augmentation
If needed, you can perform data augmentation in the preprocessing step. The code below shows random flips, rotation and crops (commented out).
Alternatively, you can preprocess the images with only transforms.Resize, transforms.ToTensor and transforms.Normalize, as in the code below.
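
As a worked example of what the last two transforms compute: ToTensor scales 8-bit pixel values to the range [0, 1], and Normalize then applies (x - mean) / std to each channel. The sketch below reproduces this arithmetic in plain Python for a single RGB pixel, using the ImageNet mean and std values that appear in the training plan; the function name is illustrative.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Mimic ToTensor followed by Normalize for one 8-bit RGB pixel."""
    scaled = [value / 255.0 for value in rgb]  # ToTensor: [0, 255] -> [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


print(normalize_pixel([128, 128, 128]))
```

A mid-grey pixel ends up close to zero in every channel, which is the input range the ImageNet-pretrained weights were trained on.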

In [3]:
class MyTrainingPlan1(TorchTrainingPlan):

    def init_model(self, model_args):
       
        # Load the pre-trained DenseNet model
        model = models.densenet121(pretrained=True)
        
        # Freeze the pretrained feature layers (all of model.features except its last module)
        for param in model.features[:-1].parameters():
            param.requires_grad = False
            
        # Replace the classifier with a new head that outputs num_classes scores
        num_classes = model_args['num_classes'] 
        num_ftrs = model.classifier.in_features
        model.classifier= nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
      
        return model

    def init_dependencies(self):
        return [
            "from torchvision import datasets, transforms, models",
            "import torch.optim as optim",
            "from torchvision.models import densenet121"
        ]


    def init_optimizer(self, optimizer_args):        
        return optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])

    
    # training data
    
    def training_data(self):
        
        # Custom torch Dataloader for MedNIST data
        print("dataset path",self.dataset_path)

        # Transform images and  do data augmentation 
        preprocess = transforms.Compose([
                transforms.Resize((224,224)),  
                #transforms.RandomHorizontalFlip(p=0.5),
                #transforms.RandomVerticalFlip(p=0.5),
                #transforms.RandomRotation(30),
                #transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
           ])
    
        train_data = datasets.ImageFolder(self.dataset_path,transform = preprocess)
        train_kwargs = { 'shuffle': True}
        return DataManager(dataset=train_data, **train_kwargs)

    def training_step(self, data, target):
        output = self.model().forward(data)
        loss_func = nn.CrossEntropyLoss()
        loss   = loss_func(output, target)
        return loss
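
nn.CrossEntropyLoss, used in training_step above, combines a softmax over the model outputs (logits) with the negative log-likelihood of the target class. For a single sample this is -log(softmax(logits)[target]); the plain-Python sketch below (function name illustrative) shows the arithmetic.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the max for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]
```

With six equal logits the loss is log(6), about 1.79, which is the value to expect from a 6-class classifier that guesses uniformly.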




In [4]:
training_args = {
    'loader_args': { 'batch_size': 32, }, 
    'optimizer_args': {'lr': 1e-3}, 
    'epochs': 2, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

model_args = {
    'num_classes': 6 # adapt this number to the number of classes in your dataset
}
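
The batch_maxnum comment can be made concrete with a small calculation. The sketch below (a hypothetical helper, not a Fed-BioMed function) gives an upper bound on the number of samples seen per epoch, assuming that batch_maxnum caps the number of batches per epoch, as the comment states.

```python
import math


def max_samples_per_epoch(n_train, batch_size, batch_maxnum):
    """Upper bound on the samples used in one epoch when batches are capped."""
    n_batches = math.ceil(n_train / batch_size)  # batches needed for a full pass
    n_batches = min(n_batches, batch_maxnum)     # cap applied by batch_maxnum
    return min(n_train, n_batches * batch_size)
```

With batch_size 32 and batch_maxnum 100 the cap is 3200 samples, so a node holding 500 images is not affected, whereas the full 58954-image dataset would be truncated.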

In [5]:
tags =  ['#MEDNIST', '#dataset']

rounds = 2 # adjust the number of rounds

exp = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan1,
                 model_args=model_args,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage())

# testing section 
from fedbiomed.common.metrics import MetricTypes
exp.set_test_ratio(.1) 
exp.set_test_on_local_updates(True)
exp.set_test_metric(MetricTypes.ACCURACY)

exp.set_tensorboard(True)

2024-02-05 10:31:40,858 fedbiomed DEBUG - Node: config_mednist_1_sampled polling for the tasks

2024-02-05 10:31:40,861 fedbiomed INFO - No available dataset has found in nodes with tags: ['#MEDNIST', '#dataset']

2024-02-05 10:31:40,866 fedbiomed DEBUG - Model file has been saved: /home/ebirgy/development/fedbiomed_github/fedbiomed/var/experiments/Experiment_0081/model_4915028e-f101-4b40-bcf4-1ab87310931d.py

Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_mode`` turned on.
The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=DenseNet121_Weights.IMAGENET1K_V1`. You can also use `weights=DenseNet121_Weights.DEFAULT` to get the most up-to-date weights.


2024-02-05 10:31:41,128 fedbiomed DEBUG - using native torch optimizer

2024-02-05 10:31:41,129 fedbiomed DEBUG - Experimentation training_args updated for `job`

2024-02-05 10:31:41,129 fedbiomed DEBUG - Experimentation training_args updated for `job`

2024-02-05 10:31:41,130 fedbiomed DEBUG - Experimentation training_args updated for `job`

True
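
The experiment is configured with the FedAverage aggregator, which combines the parameters returned by the nodes into a single model, weighting each node by its share of the training samples. The sketch below illustrates this weighted average on plain Python lists; it is not Fed-BioMed's implementation, and the function name and data layout are assumptions.

```python
def federated_average(node_params, node_sample_counts):
    """Weighted average of per-node parameter vectors.

    node_params: list of equally sized lists of floats (one per node).
    node_sample_counts: number of training samples on each node.
    """
    total = sum(node_sample_counts)
    if total == 0:
        raise ValueError("no training samples reported, cannot compute weights")
    weights = [count / total for count in node_sample_counts]
    n_values = len(node_params[0])
    return [
        sum(w * params[i] for w, params in zip(weights, node_params))
        for i in range(n_values)
    ]
```

For instance, two nodes holding 500 and 1000 samples get weights 1/3 and 2/3. If no node replies, there are no sample counts to build weights from, so the average is undefined.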

In [6]:
exp.run()

2024-02-05 10:31:49,929 fedbiomed INFO - Sampled nodes in round 0 []

2024-02-05 10:31:49,938 fedbiomed INFO - Nodes that successfully reply in round 0 []


--------------------
Fed-BioMed researcher stopped due to exception:
FB401: aggregation crashes or returns an error. Aggregation aborted due to sum of the weights is equal to 0 {}. Sample sizes received from nodes might be corrupted.
--------------------


FedbiomedSilentTerminationError: 



For example, at the end of the training experiment, I obtained:

                      INFO - VALIDATION ON LOCAL UPDATES 
					 NODE_ID: NODE_41cd99c8-3571-4ab3-958e-6357ce31e91b 
					 Round 2 | Iteration: 1/1 (100%) | Samples: 100/100
 					 ACCURACY: 0.960000
					 -

In [7]:
#save model 
exp.training_plan().export_model('./training_plan1_densenet_MedNIST')

In [None]:
from fedbiomed.researcher.environ import environ
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir "$tensorboard_dir"

## Upload more data and train the top layers

You can load more data on a new node for the second experiment and train the top layers of the DenseNet model.
To change the amount of data, first stop the previous node in the console by pressing Ctrl+C.

In this example, I run a second experiment with 1000 images.
Run another node with 1000 images (as described above).


In [None]:
from fedbiomed.researcher.requests import Requests
req  = Requests()
req.list()

## Partial fine-tuning: Use pretrained DenseNet and train top layers with your data
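
The two training plans differ only in the slice of model.features that is frozen: `[:-1]` in the first plan and `[:-3]` in the second. The sketch below lists the child modules of torchvision's DenseNet-121 feature extractor by name (written out by hand here, so that the example runs without torchvision) and shows which ones each slice leaves trainable. The helper name is illustrative.

```python
# Child modules of densenet121().features, in order (names as in torchvision).
FEATURE_MODULES = [
    'conv0', 'norm0', 'relu0', 'pool0',
    'denseblock1', 'transition1',
    'denseblock2', 'transition2',
    'denseblock3', 'transition3',
    'denseblock4', 'norm5',
]


def split_frozen_trainable(n_trainable):
    """Mimic model.features[:-n_trainable]: return (frozen, trainable) names."""
    return FEATURE_MODULES[:-n_trainable], FEATURE_MODULES[-n_trainable:]


for n in (1, 3):
    frozen, trainable = split_frozen_trainable(n)
    print(f"features[:-{n}] freezes {len(frozen)} modules, leaves trainable: {trainable}")
```

With `[:-3]`, the last transition layer, the last dense block and the final normalization layer are trained together with the new classifier.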

In [None]:
class MyTrainingPlan2(TorchTrainingPlan):

    def init_model(self, model_args):

        # Load the pre-trained DenseNet model
        model = models.densenet121(pretrained=True)
        
        # Freeze all feature layers except the last three modules (last transition,
        # last dense block and final norm), which remain trainable
        for param in model.features[:-3].parameters():
            param.requires_grad = False

        # add the classifier 
        num_ftrs = model.classifier.in_features
        num_classes = model_args['num_classes'] 
        model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)       
            )
        
        return model

    def init_dependencies(self):
        return [
            "from torchvision import datasets, transforms, models",
            "import torch.optim as optim"
        ]


    def init_optimizer(self, optimizer_args):        
        return optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])

    def training_data(self):
        
        # Custom torch Dataloader for MedNIST data and transform images and perform data augmentation 
        print("dataset path",self.dataset_path)
        preprocess = transforms.Compose([
                transforms.Resize((224,224)),  
                #transforms.RandomHorizontalFlip(p=0.5),
                #transforms.RandomVerticalFlip(p=0.5),
                #transforms.RandomRotation(30),
                #transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
           ])
        train_data = datasets.ImageFolder(self.dataset_path,transform = preprocess)
        train_kwargs = { 'shuffle': True}
        return DataManager(dataset=train_data, **train_kwargs)



    def training_step(self, data, target):
        output = self.model().forward(data)
        loss_func = nn.CrossEntropyLoss()
        loss   = loss_func(output, target)
        return loss




In [None]:
training_args = {
    'loader_args': { 'batch_size': 32, }, 
    'optimizer_args': {'lr': 1e-4}, # You could decrease the learning rate
    'epochs': 1, # you can increase the number of epochs (e.g. 10)
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
model_args={
    'num_classes': 6
}
tags =  ['#MEDNIST', '#dataset']
rounds = 1  # you can increase the number of rounds

exp = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan2,
                 model_args=model_args,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage())

from fedbiomed.common.metrics import MetricTypes
exp.set_test_ratio(.1)
exp.set_test_on_local_updates(True)
exp.set_test_metric(MetricTypes.ACCURACY)

exp.set_tensorboard(True)
    

In [None]:
exp.run()

For example, at the end of the training experiment, I obtained:

                    fedbiomed INFO - VALIDATION ON LOCAL UPDATES 
					 NODE_ID: NODE_7842724a-cafa-49cc-862d-149288bbbb22 
					 Round 1 | Iteration: 1/1 (100%) | Samples: 100/100
 					 ACCURACY: 0.990000 
					 ---------

In [None]:
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1]
for r in round_data.values():
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = r['node_id'],
        rtraining = r['timing']['rtime_training'],
        ptraining = r['timing']['ptime_training'],
        rtotal = r['timing']['rtime_total']))
print('\n')


In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir "$tensorboard_dir"

## Save and export your model 

You can save the TrainingPlan experiment and the fine-tuned model by executing the commands below:

In [None]:
exp.training_plan().export_model('./training_plan2_densenet_MedNIST')

In [None]:
# save your model (all layers of the model from the training experiment)
remote_model = exp.training_plan().model()
torch.save(remote_model, './training_plan2_model')
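
Saving the whole model object with torch.save, as above, pickles the model class along with the weights. A common alternative in PyTorch is to save only the state_dict and load it into a freshly built model of the same architecture. The sketch below shows that round trip on a toy layer, using an in-memory buffer instead of a file; it is a generic PyTorch pattern, not a Fed-BioMed API.

```python
import io

import torch
import torch.nn as nn

# Toy model standing in for the trained network.
trained = nn.Linear(4, 6)

# Save only the parameters (state_dict) to an in-memory buffer.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)

# Rebuild the same architecture, then load the saved parameters into it.
buffer.seek(0)
restored = nn.Linear(4, 6)
restored.load_state_dict(torch.load(buffer))
```

The restored layer then holds exactly the same weights and bias as the original one.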

In [None]:
#from torchvision import models

#torch.save(models.densenet121(pretrained=True).state_dict(), './model_training_plan_2')

## Import your model and parameters 

In [None]:
your_model = torch.load('./training_plan2_model')

In [None]:
# load your parameters (tensor values of your tuned model)
tuned_model= torch.load('./training_plan2_densenet_MedNIST')

In [None]:
# In a new TrainingPlan experiment, you could import your tuned model
exp.training_plan().import_model('./training_plan2_densenet_MedNIST')

### This part needs confirmation, tests (and agreement to load parameters?)

In [None]:
#remote_model = remote_experiment.training_plan().model()
tuned_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])
