# AI in Medicine I - Practical 3: Transfer Learning

Generating good labels for medical datasets is an expensive and time-consuming task, especially for tasks such as segmentation, where expert radiologists are often required.
Often we are faced with datasets that have scarce labels or none at all, and must rely on self-supervised pretraining methods to improve performance.
We will continue to use the brain MRI dataset from the previous practicals.
The Jupyter Notebook provided contains some preliminary code you can use and some function prototypes that you are expected to fill in.
The deliverables for the submission are an archive containing the completed code and a short report explaining your strategies and choices for each task in this practical.


**Make sure to select the correct runtime when working in Google Colab (GPU)**

### Read the text descriptions and code cells carefully and look out for the cells marked with 'TASK' and 'ADD YOUR CODE HERE'.

In [None]:
# Only run this cell when in Google Colab
! git init
! git remote add origin https://github.com/compai-lab/aim-practical-3-transfer-learning
! git fetch
! git checkout -t origin/main

## Downloading the Data

In [None]:
! wget https://www.dropbox.com/s/w9njau9t6rrheel/brainage-data.zip
! unzip brainage-data.zip
! wget https://www.dropbox.com/s/f5mt8p9pkszff3x/brainage-testdata.zip
! unzip brainage-testdata.zip

In [None]:
! pip install monai

## Imports

In [None]:
from argparse import Namespace

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor


from monai.networks.nets import UNet
from monai.losses import DiceCELoss
from monai.transforms import AsDiscrete
from monai.metrics import compute_dice


from data_utils import get_image_dataloaders
from utils import AvgMeter, seed_everything
%load_ext tensorboard
%load_ext autoreload
%autoreload 2


## Getting started and familiarising ourselves with the data

We provide the data of 652 subjects, of which we use 500 for training, 47 for validation, and the rest for testing your final model.
The following cells provide helper functions to load the data and provide an overview and visualization of the statistics over the total population of the 652 subjects.

In [None]:
train_df = pd.read_csv('./data/brain_age/meta/meta_data_regression_train.csv')
val_df = pd.read_csv('./data/brain_age/meta/meta_data_segmentation_train.csv')
test_df = pd.read_csv('./data/brain_age/meta/meta_data_regression_test.csv')
train_df['subject_id']
id_overlap = pd.merge(train_df, val_df, on='subject_id', how='inner')
assert len(id_overlap)==0
id_overlap = pd.merge(train_df, test_df, on='subject_id', how='inner')
assert len(id_overlap)==0
id_overlap = pd.merge(val_df, test_df, on='subject_id', how='inner')
assert len(id_overlap)==0

## Segmentation

We again wish to segment our brains using a deep neural network. The following code is a basic example of how to do so.

### Full Dataset Results

In [None]:
def train_segmentation(config, model, optimizer, train_loader, val_loader):
    model.train()
    step = 0
    checks = 0
    best_val_loss = float('Inf')
    avg_loss = AvgMeter()
    avg_dice = AvgMeter()
    
    criterion = DiceCELoss(include_background=False, softmax=True)
    postprocess = AsDiscrete(argmax=True, to_onehot=4)


    while True:
        for x, y in train_loader:
            x = x.to(config.device)
            y = y.to(config.device)

            # Training step
            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

            pred_binarized = []
            y_binarized = []
            for i in range(pred.shape[0]):
                pred_binarized.append(postprocess(pred[i]))
                y_binarized.append(postprocess(y[i]))
            pred_binarized = torch.stack(pred_binarized)
            y_binarized = torch.stack(y_binarized)

            dice = compute_dice(pred_binarized, y_binarized, include_background=False).mean()

            avg_loss.add(loss.detach().item())
            avg_dice.add(dice.detach().item())

            # Increment step
            step += 1

            if step % config.log_freq == 0 and not step % config.val_freq == 0:
                train_loss = avg_loss.compute()
                train_dice = avg_dice.compute()

            # Validate and log at validation frequency
            if step % config.val_freq == 0:
                # Reset avg_loss
                train_loss = avg_loss.compute()
                train_dice = avg_dice.compute()
                avg_loss = AvgMeter()
                avg_dice = AvgMeter()

                # Get validation results
                val_results = validate_segmentation(
                    model,
                    val_loader,
                    config,
                    criterion,
                    postprocess
                )

                # Print current performance
                print(f"Finished step {step} of {config.num_steps}. "
                      f"Train loss: {train_loss} - "
                      f"Train Dice: {train_dice} - "
                      f"val loss: {val_results['val/loss']:.4f} - "
                      f"val Dice: {val_results['val/Dice']:.4f} - "
                      f"val Dice_CSF: {val_results['val/Dice_CSF']:.4f} - "
                      f"val Dice_WM: {val_results['val/Dice_WM']:.4f} - "
                      f"val Dice_GM: {val_results['val/Dice_GM']:.4f} - ")

                # Check if the validation loss has stopped improving (early stopping)
                ### ADD YOUR CODE HERE ###

                ### END ###
            
            if step >= config.num_steps:
                print(f'\nFinished training after {step} steps\n')
                return model, step


def validate_segmentation(model, val_loader, config, criterion, postprocess):
    model.eval()
    avg_val_loss = AvgMeter()
    avg_val_dice = AvgMeter()
    avg_val_dice_1 = AvgMeter()
    avg_val_dice_2 = AvgMeter()
    avg_val_dice_3 = AvgMeter()
    for x, y in val_loader:
        x = x.to(config.device)
        y = y.to(config.device)

        with torch.no_grad():
            pred = model(x)    
        loss = criterion(pred, y)
        
        pred_binarized = []
        y_binarized = []
        for i in range(pred.shape[0]):
            pred_binarized.append(postprocess(pred[i]))
            y_binarized.append(postprocess(y[i]))
        pred_binarized = torch.stack(pred_binarized)
        y_binarized = torch.stack(y_binarized)

        dice = compute_dice(pred_binarized, y_binarized, include_background=False)
        mean_dice = dice.mean()
        mean_dice_per_class = dice.mean(dim=0)
        avg_val_loss.add(loss.item())
        avg_val_dice.add(mean_dice.item())
        avg_val_dice_1.add(mean_dice_per_class[0].item())
        avg_val_dice_2.add(mean_dice_per_class[1].item())
        avg_val_dice_3.add(mean_dice_per_class[2].item())

        
    model.train()
    return {
        'val/loss': avg_val_loss.compute(),
        'val/Dice': avg_val_dice.compute(),
        'val/Dice_CSF': avg_val_dice_1.compute(),
        'val/Dice_WM': avg_val_dice_2.compute(),
        'val/Dice_GM': avg_val_dice_3.compute()
    }


In [None]:
# Let's set some basic hyperparameters
config = Namespace()
config.img_size = 96
config.batch_size = 16
config.num_workers = 0

config.log_dir = './logs'
config.val_freq = 50
config.log_freq = 10

config.seed = 0
config.device = 'cuda'
config.autoencoder = False

config.lr = 1e-3
config.betas = (0.9, 0.999)

config.num_steps = 1500
config.patience = 5

seed_everything(config.seed)


In [None]:
import pandas as pd
all_train_df = pd.read_csv('./data/brain_age/meta/meta_data_regression_train.csv')
low_data_train_df = all_train_df.sample(n=200, random_state=12)
low_data_train_df.to_csv('./data/brain_age/meta/meta_data_regression_train_lowdata_200.csv', index=False)

In [None]:
# Load data
dataloaders_fulldata_segmentations = get_image_dataloaders(
    img_size=config.img_size,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    segmentations = True,
    low_data = '200'
)

#### TASK: MONAI UNET
For today's exercise we will be using a library called MONAI, which has some very useful tools for medical imaging. Our task is segmentation, so we will be using a classic UNet. Think about how many channels, strides and residual units you want to include. Remember that for smaller datasets, fewer parameters is often an advantage. Play around with different settings and see how the performance changes.
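
To get a feeling for how strongly the channel widths drive model size, here is a rough back-of-the-envelope sketch. It counts only one 3x3x3 convolution per level along the encoder path, so it is a relative comparison under simplifying assumptions, not the parameter count of MONAI's `UNet` (which also has a decoder, normalisation layers and optional residual units). The helper names `conv3d_params` and `encoder_conv_params` are made up for this illustration.

```python
def conv3d_params(in_channels: int, out_channels: int, kernel_size: int = 3) -> int:
    """Weights plus biases of a single 3D convolution."""
    return in_channels * out_channels * kernel_size ** 3 + out_channels


def encoder_conv_params(channels, in_channels: int = 1) -> int:
    """Sum the parameters of one conv per level along a chain of channel widths."""
    total = 0
    previous = in_channels
    for width in channels:
        total += conv3d_params(previous, width)
        previous = width
    return total


# Two hypothetical channel settings; halving every width cuts the count to roughly a quarter.
wide = encoder_conv_params((16, 32, 64, 128))
narrow = encoder_conv_params((8, 16, 32, 64))
print(f"wide: {wide:,} parameters, narrow: {narrow:,} parameters")
```

For the exact size of a model you have actually built, the usual PyTorch idiom is `sum(p.numel() for p in model.parameters())`.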

In [None]:
# Init model
model_unet = UNet(
  spatial_dims=3,
  out_channels=4,
  ### ADD YOUR CODE HERE ###

  ### END ###
).to(config.device)


# Init optimizers
optimizer = torch.optim.AdamW(
    model_unet.parameters(),
    lr=config.lr,
    betas=config.betas
)

model_unet, step = train_segmentation(
    config=config,
    model=model_unet,
    optimizer=optimizer,
    train_loader=dataloaders_fulldata_segmentations['train'],
    val_loader=dataloaders_fulldata_segmentations['val']
)

In [None]:
# Test
test_results = validate_segmentation(model_unet, dataloaders_fulldata_segmentations['test'], config, DiceCELoss(include_background=False, softmax=True), AsDiscrete(argmax=True, to_onehot=4))
print(f'Test loss: {test_results["val/loss"]:.4f}')
print(f'Test Mean Dice: {test_results["val/Dice"]:.4f}')
print(f'Test Dice_CSF: {test_results["val/Dice_CSF"]:.4f}')
print(f'Test Dice_WM: {test_results["val/Dice_WM"]:.4f}')
print(f'Test Dice_GM: {test_results["val/Dice_GM"]:.4f}')

### TASK: Low Label Simulation

Often we don't have labels for all of our samples, or have multiple datasets, where we might have labels for one dataset but not the other. 
Especially in medicine, where expert labels can be expensive to generate, we have to combine datasets to have enough data for training. 

Here we will simulate the scenario that our radiologists don't like us and have decided to only segment three brains.

Retrain the model and see what happens to our test results.

In [None]:
import pandas as pd
all_train_df = pd.read_csv('./data/brain_age/meta/meta_data_regression_train.csv')
low_data_train_df = all_train_df.sample(n=3, random_state=12)
low_data_train_df.to_csv('./data/brain_age/meta/meta_data_regression_train_lowdata_3.csv', index=False)

In [None]:
dataloaders_lowdata_segmentations = get_image_dataloaders(
    img_size=config.img_size,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    low_data= '3',
    segmentations = True,
    train_only = True
)

In [None]:
model_unet_lowdata = UNet(
  spatial_dims=3,
  out_channels=4,
### ADD YOUR CODE HERE ###

### END ###
).to(config.device)


# Init optimizers
optimizer = torch.optim.AdamW(
    model_unet_lowdata.parameters(),
    lr=config.lr,
    betas=config.betas
)

config.patience = 5
config.num_steps = 3000
model_unet_lowdata, step = train_segmentation(
    config=config,
    model=model_unet_lowdata,
    optimizer=optimizer,
    train_loader=dataloaders_lowdata_segmentations['train'],
    val_loader=dataloaders_fulldata_segmentations['val']
)

In [None]:
# Test
test_results = validate_segmentation(model_unet_lowdata, dataloaders_fulldata_segmentations['test'], config, DiceCELoss(include_background=False, softmax=True), AsDiscrete(argmax=True, to_onehot=4))
print(f'Test loss: {test_results["val/loss"]:.4f}')
print(f'Test Mean Dice: {test_results["val/Dice"]:.4f}')
print(f'Test Dice_CSF: {test_results["val/Dice_CSF"]:.4f}')
print(f'Test Dice_WM: {test_results["val/Dice_WM"]:.4f}')
print(f'Test Dice_GM: {test_results["val/Dice_GM"]:.4f}')

### TASK: What happens to our training dynamics and duration now that we have less data? 

Maybe setting a fixed number of steps is not the best idea...
Implement some form of early stopping in the training loop that stops training once the validation loss has not improved for N consecutive checks.
Play around with this patience parameter. How does it affect test accuracy? Why?

Retrain the model
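
As one possible way to think about the patience logic, here is a minimal, framework-free sketch. The class name `EarlyStopper` and the `min_delta` tolerance are assumptions made for this illustration and are not part of the provided code; in the training loop above, the same bookkeeping could be done with the existing `best_val_loss` and `checks` variables together with `config.patience`.

```python
class EarlyStopper:
    """Signal a stop once the monitored loss has not improved for `patience` checks."""

    def __init__(self, patience: int, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # hypothetical tolerance: smaller gains do not count
        self.best_loss = float("inf")
        self.checks_without_improvement = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            # New best value: remember it and reset the counter
            self.best_loss = val_loss
            self.checks_without_improvement = 0
        else:
            self.checks_without_improvement += 1
        return self.checks_without_improvement >= self.patience


# Toy validation curve (made-up numbers): improves, then plateaus
stopper = EarlyStopper(patience=3)
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.65]
stopped_at = None
for check, val_loss in enumerate(val_losses, start=1):
    if stopper.should_stop(val_loss):
        stopped_at = check
        break
print(f"Stopped at check {stopped_at}, best loss {stopper.best_loss}")
```

Note that a check which merely ties the best loss does not reset the counter here. Whether to also keep a copy of the best weights and restore them after stopping is a separate design choice worth discussing in your report.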

### TASK: Autoencoder

We can't seem to achieve our previous performance because we don't have enough labeled samples. We still have the rest of the data, just without labels. Maybe we can improve our performance by doing some initial unsupervised learning over the full, unlabeled dataset before fine-tuning on our three labeled samples.

In [None]:
def train_autoencoder(config, model, optimizer, train_loader, val_loader):
    model.train()
    step = 0
    avg_loss = AvgMeter()

    while True:
        for x, y in train_loader:
            x = x.to(config.device)
            y = y.to(config.device)
            

            # Training step
            optimizer.zero_grad()
            pred = model(x)
            loss = torch.pow((pred - x), 2).mean()
            loss.backward()
            optimizer.step()

            avg_loss.add(loss.detach().item())

            # Increment step
            step += 1

            # Validate and log at validation frequency
            if step % config.val_freq == 0:
                # Reset avg_loss
                train_loss = avg_loss.compute()
                avg_loss = AvgMeter()

                # Get validation results
                val_results = validate_autoencoder(
                    model,
                    val_loader,
                    config
                )

                # Print current performance
                print(f"Finished step {step} of {config.num_steps}. "
                      f"Train loss: {train_loss} - "
                      f"val loss: {val_results['val/loss']:.4f} - "
                      f"val MAE: {val_results['val/MAE']:.4f}")

            if step >= config.num_steps:
                print(f'\nFinished training after {step} steps\n')
                return model, step


def validate_autoencoder(model, val_loader, config, show_plot=False):
    model.eval()
    avg_val_loss = AvgMeter()
    preds = []
    targets = []
    for x, y in val_loader:
        x = x.to(config.device)
        y = y.to(config.device)

        with torch.no_grad():
            pred = model(x)
        loss = torch.pow((pred - x), 2).mean()
        avg_val_loss.add(loss.item())
        preds.append(pred.cpu())
        targets.append(x.cpu())

    preds = torch.cat(preds)
    targets = torch.cat(targets)
    mae = mean_absolute_error_image(preds, targets)
        
    model.train()
    return {
        'val/loss': avg_val_loss.compute(),
        'val/MAE': mae,
    }


def mean_absolute_error_image(preds: Tensor, targets: Tensor) -> float:
    """Compute the mean absolute error between predictions and targets"""
    return (preds - targets).abs().mean().item()


In [None]:
# Init model
# HINT: If you aren't getting the performance you expect, try to change the number of channels and the number of res units
model_ae = UNet(
  ### ADD YOUR CODE HERE ###

  ### END ###
).to(config.device)

# Init optimizers
optimizer = torch.optim.AdamW(
    model_ae.parameters(),
    lr=config.lr,
    betas=config.betas
)

config.num_steps = 1500
model_ae, step = train_autoencoder(
    config=config,
    model=model_ae,
    optimizer=optimizer,
    train_loader=dataloaders_fulldata_segmentations['train'],
    val_loader=dataloaders_fulldata_segmentations['val']
)

#### TASK: Visualize Results

Plot both the original brains and the reconstructed brains above one another to see how well the autoencoder performed.

In [None]:
orig_images = next(iter(dataloaders_fulldata_segmentations['val']))[0].to(config.device)

f, axarr = plt.subplots(2, 3)
orig_image = orig_images[0, 0].cpu().numpy()
H, W, D = orig_image.shape
axarr[0][0].imshow(orig_image[H // 2, :, :], cmap='gray')
axarr[0][1].imshow(orig_image[:, W // 2, :], cmap='gray')
axarr[0][2].imshow(orig_image[:, :, D // 2], cmap='gray')
### ADD YOUR CODE HERE ###

### END ###
plt.show()

Looks pretty good! 

#### TASK: Transfer Learning
Now let's see if it has learned any useful features that can help us in our segmentation task.
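
The general mechanics of partial weight transfer and freezing can be sketched on a toy model. Everything below is an illustration under assumptions: `ToyNet` and its `encoder` / `decoder` attribute names are invented here, and MONAI's `UNet` nests its layers differently, so for the real model you would first inspect `state_dict().keys()` to decide which entries belong to the encoder.

```python
import torch
from torch import nn


class ToyNet(nn.Module):
    """Tiny stand-in with an 'encoder' and a 'decoder' submodule (hypothetical names)."""

    def __init__(self, out_channels: int):
        super().__init__()
        self.encoder = nn.Conv3d(1, 4, kernel_size=3, padding=1)
        self.decoder = nn.Conv3d(4, out_channels, kernel_size=1)

    def forward(self, x):
        return self.decoder(torch.relu(self.encoder(x)))


pretrained = ToyNet(out_channels=1)  # plays the role of the autoencoder
segmenter = ToyNet(out_channels=4)   # plays the role of the segmentation model

# Keep only encoder weights; the decoder has a different output shape and task
encoder_weights = {
    name: tensor
    for name, tensor in pretrained.state_dict().items()
    if name.startswith("encoder.")
}

# strict=False tolerates the decoder keys that are absent from the filtered dict
result = segmenter.load_state_dict(encoder_weights, strict=False)
print("Missing keys (left at random init):", result.missing_keys)

# Freeze the encoder so that only the decoder receives gradient updates
for param in segmenter.encoder.parameters():
    param.requires_grad = False

trainable = [name for name, p in segmenter.named_parameters() if p.requires_grad]
print("Trainable parameters:", trainable)
```

Printing the missing keys is a cheap sanity check: if encoder entries show up there, the filter did not match the key names and nothing was actually transferred.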

In [None]:
# Init model
model_ae_segmenter = UNet(
  ### ADD YOUR CODE HERE ###
  
  ### END ###
).to(config.device)

# Init optimizers
optimizer = torch.optim.AdamW(
    model_ae_segmenter.parameters(),
    lr=config.lr,
    betas=config.betas
)

# Load the weights from the autoencoder using the state_dict
# HINT: Remove the weights from the decoder as we have a different number of channels now and a different task. 
# HINT: If you get errors that might be OK to ignore, take a look at the arguments of the loading function
### ADD YOUR CODE HERE ###

### END ###

# Sometimes after pretraining in very low data regimes it can help to freeze the encoder, to preserve the features learned over the large dataset.
# Write some code to freeze only the encoder here
### ADD YOUR CODE HERE ###

### END ###


In [None]:
# Train
model_ae_segmenter, step = train_segmentation(
    config=config,
    model=model_ae_segmenter,
    optimizer=optimizer,
    train_loader=dataloaders_lowdata_segmentations['train'],
    val_loader=dataloaders_fulldata_segmentations['val']
)

In [None]:
# Test
test_results = validate_segmentation(model_ae_segmenter, dataloaders_fulldata_segmentations['test'], config, DiceCELoss(include_background=False, softmax=True), AsDiscrete(argmax=True, to_onehot=4))
print(f'Test loss: {test_results["val/loss"]:.4f}')
print(f'Test Dice: {test_results["val/Dice"]:.4f}')
print(f'Test Dice_CSF: {test_results["val/Dice_CSF"]:.4f}')
print(f'Test Dice_WM: {test_results["val/Dice_WM"]:.4f}')
print(f'Test Dice_GM: {test_results["val/Dice_GM"]:.4f}')

Nice! It looks like we were able to boost our mean Dice score by 2-3 points.

#### TASK: Frozen vs Unfrozen Weights
Retrain the model above, this time with fully trainable weights. How does our performance change? Why do you think this is?

#### TASK: Autoencoder Performance VS Downstream Task
 
We can definitely improve our autoencoder reconstruction performance by training for longer. Maybe this isn't ideal for our segmentation task, but we can try it out. Try training for different numbers of steps and see how the reconstruction performance and the downstream segmentation performance change. What is the best number of steps?

# Bonus


## MedicalNet: Using Pretrained Weights from the Internet

It often helps to use weights trained by other people on larger datasets. In 3D medical imaging, for example, MedicalNet (https://github.com/Tencent/MedicalNet) is a collection of 3D ResNets that have been trained on 23 segmentation datasets. Although other organs are included, see how well the learned weights transfer to our task. Download and use MedicalNet for our segmentation task. Is this better than self-supervised training over our dataset?
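
One practical hurdle with downloaded checkpoints is that their key names may not match your model. A common case, assumed here rather than confirmed for MedicalNet's files, is a `module.` prefix left over from saving a model wrapped in `torch.nn.DataParallel`. The helper `strip_prefix` below is a hypothetical utility for this illustration; it works on any dictionary, so plain strings stand in for the tensors.

```python
def strip_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Stand-in checkpoint with made-up key names; real values would be tensors
checkpoint = {"module.conv1.weight": "w1", "module.layer1.0.conv1.weight": "w2"}
cleaned = strip_prefix(checkpoint)
print(sorted(cleaned))
```

Before loading anything, it is worth printing a few keys of both the checkpoint and your own model to see whether such a renaming step is needed at all.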