<a href="https://colab.research.google.com/github/katek28/Deep-Learning-projects/blob/main/Deep_learning_for_X_ray_Image_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AMV Assignment: Deep learning for X-ray Image Segmentation





---


Welcome to this notebook on image segmentation! In this assignment, we will be experimenting with deep learning models for segmentation. We will develop a model to segment the lungs on thorax X-rays, using state-of-the-art models and a public X-ray dataset.

# Getting started and setting things up
## Importing and installing modules
In this notebook we will use several third-party Python modules that have to be installed before we can use them. This is done in the code cell below. Please execute it and wait for the installation to complete (it should take no more than a couple of minutes).


In [None]:
#@title
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import numpy as np
import random
import cv2
import copy
from PIL import Image
from glob import glob

from sklearn.model_selection import train_test_split

from ipywidgets import interact
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

!pip install segmentation-models-pytorch
!pip install pytorch-lightning

import segmentation_models_pytorch as smp
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
#from pytorch_lightning.callbacks.lr_logger import LearningRateLogger

%load_ext tensorboard

print('Module installations and imports completed successfully!')

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.2.0-py3-none-any.whl (87 kB)
Collecting efficientnet-pytorch==0.6.3
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
Collecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)

### Mounting google drive to the notebook
This Colab notebook runs on a virtual machine (VM) hosted in one of Google's datacenters. By itself, it does not have any data associated with it (other than the defaults required to run Colab). To do our segmentation project, we have to give it access to the dataset.

The easiest way to do this is by mounting your Google Drive on the VM. That gives the notebook access to the data stored in your drive, so that it can be used for training and evaluating the models.

The code cell below starts the authentication procedure to mount your Google Drive.




In [None]:
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


The data on your Google Drive is now available to the notebook in the folder '/gdrive'. You can use the file explorer in the side menu (click on the folder icon) to explore the file structure.



## Definitions


### Defining the dataset

While the notebook now has access to Google Drive, we still need a way to tell it how to handle the images. This is done using a 'Dataset' class (a subclass of PyTorch's `Dataset`), which is used during training, validation and testing to load and prepare the images that are fed to the model. It is defined in the cell below.

Apart from loading the images, it also handles splitting the dataset into train / validation / test parts.
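
To make the split concrete, here is a minimal sketch of the arithmetic behind the two `train_test_split` calls in the cell below. It assumes scikit-learn's rounding for a fractional `test_size` (the test count is rounded up) and, as an assumption that is not stated in this notebook, a folder with 247 image/mask pairs. `split_sizes` is a hypothetical helper used only for illustration; it does not shuffle or load any data.

```python
import math

def split_sizes(n_pairs, test_fraction=0.35):
    """Return (n_train, n_valid, n_test) for the two-stage split.

    Hypothetical helper: mirrors the arithmetic only, it does not shuffle data.
    """
    # Stage 1: a fractional test_size is rounded up to a whole number of samples.
    n_test = math.ceil(test_fraction * n_pairs)
    n_remaining = n_pairs - n_test
    # Stage 2: the validation set gets an absolute count, half the test set size.
    n_valid = int(0.5 * n_test)
    n_train = n_remaining - n_valid
    return n_train, n_valid, n_test

# Assumption: 247 image/mask pairs in the dataset folder.
print(split_sizes(247))  # (117, 43, 87)
```

Note that the validation size is passed as an absolute count rather than a fraction, so it always equals half the test set size (rounded down).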

In [None]:
#@title
class JSRT_SCR(Dataset):
    '''
    Dataset Class for JSRT_SCR Thorax Radiograph Anatomy Semantic Segmentation dataset
    Dataset link - http://db.jsrt.or.jp/eng.php
    There is 1 class in the given labels.
    The `get_filenames` function retrieves the filenames of all images in the given `path` and
    saves the absolute path in a list.
    In the `__getitem__` method, images and masks are resized to the given `img_size`;
    the given `transform` (if any) is applied to the image only
    (masks do not usually require transforms, but they can be implemented in a similar way).
    '''

    def __init__(self, root_path, split, img_size=(512, 512), transform=None):
        self.img_size = img_size
        #labels = ['heart', 'left lung', 'left clavicle', 'right lung', 'right clavicle']
        labels = ['lung']
        self.class_map = dict(zip(labels, range(len(labels))))
        self.transform = transform
        self.split = split
        self.root = root_path

        self.img_path = os.path.join(self.root, 'images')
        self.mask_path = os.path.join(self.root, 'masks')
        self.img_list = self.get_filenames(self.img_path)
        self.mask_list = self.get_filenames(self.mask_path)
        self.img_list.sort()
        self.mask_list.sort()
        self.image_mask_pairs = list(zip(self.img_list, self.mask_list))

        # Split between train, valid and test set
        trainingset, test = train_test_split(self.image_mask_pairs, test_size=0.35, random_state=42)
        train, validation = train_test_split(trainingset, test_size=int(0.5*len(test)), random_state=42)

        datasets = {'train': train, 'valid': validation, 'test': test}
        self.img_list = [x[0] for x in datasets[self.split]]
        self.mask_list = [x[1] for x in datasets[self.split]]
   

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx])
        img = img.resize(self.img_size)
        img = np.array(img)

        mask = Image.open(self.mask_list[idx]).convert('L')
        mask = mask.resize(self.img_size)
        mask = np.array(mask)
        mask[mask!=0] = 1

        if self.transform:
            img = self.transform(img)

        return img, mask

    def give_data(self, idx):
        return self.__getitem__(idx)

    def get_filenames(self, path):
        '''
        Returns a list of absolute paths to images inside given `path`
        '''
        files_list = list()
        for filename in os.listdir(path):
            files_list.append(os.path.join(path, filename))
        return files_list

print('Dataset defined.')

Dataset defined.


#### Inspecting the data
A very important aspect of deep learning is making sure that your data is in a proper state before you pass it into your model. If something's wrong with the data, you will never be able to train a good model on it. Therefore, it is really important to visually inspect your data before you do anything with it.
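
Before looking at individual images, a quick structural check can catch obvious problems. The dataset class above pairs images and masks by sorted filename order, so a missing file in either folder would silently misalign the pairs. The sketch below is a hypothetical helper (`check_dataset_folder` is not part of the notebook) that only assumes the `images` and `masks` subfolders described above.

```python
import os

def check_dataset_folder(root_path):
    """Count files in the 'images' and 'masks' subfolders and check they match.

    Hypothetical helper, not part of the notebook's own code.
    """
    counts = {}
    for sub in ('images', 'masks'):
        folder = os.path.join(root_path, sub)
        if not os.path.isdir(folder):
            raise FileNotFoundError('Missing folder: {}'.format(folder))
        counts[sub] = len(os.listdir(folder))
    if counts['images'] != counts['masks']:
        raise ValueError('Found {images} images but {masks} masks'.format(**counts))
    return counts

# Example call (path taken from the cell below; adjust it to your own drive):
# check_dataset_folder('/gdrive/My Drive/JSRT_SCR dataset')
```

A matching count does not prove that every image has the right mask, so the visual inspection below is still needed.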



In [None]:
# You can change the value of 'dataset' to inspect the different splits
# Options are: 'train', 'valid' or 'test'
dataset = 'test'

# You have to set the value of 'root_path' such that it points to the
# folder where the images and the masks are located. 
# root_path = 'JSRT_SCR dataset'
root_path = '/gdrive/My Drive/JSRT_SCR dataset'


dsc = JSRT_SCR(root_path, split=dataset, transform=None)
@interact
def plot_image_and_mask(index=(0, len(dsc) - 1, 1)):
    img, mask = dsc.give_data(index)
    image = np.dstack([img]*3)
    image_masked = copy.deepcopy(image)

    image_masked[:,:,0][mask!=0] = image_masked[:,:,0][mask!=0]*1
    image_masked[:,:,1][mask!=0] = image_masked[:,:,1][mask!=0]*0
    image_masked[:,:,2][mask!=0] = image_masked[:,:,2][mask!=0]*0

    display_image = np.hstack((image, image_masked))
    fig = plt.figure(figsize=(8,4))
    plt.imshow(display_image, cmap='bone')
    plt.axis('off')


### Defining the loss functions and metrics

For this segmentation problem we will use the dice coefficient as a metric. It measures the overlap between two areas; in our case these are the labels and the predictions of the model. For perfect overlap, the dice coefficient is equal to 1. If there is no overlap, the dice is 0.
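
As a small worked example of these two extremes, the sketch below computes the dice coefficient on toy four-pixel "images" using plain Python lists. It follows the same formula as the `DiceMetric` function defined further down, but `dice_metric` and the pixel values here are made up for illustration.

```python
def dice_metric(prediction, target, smooth=1.0):
    """Dice coefficient for two flat sequences of values in [0, 1]."""
    intersection = sum(p * t for p, t in zip(prediction, target))
    return (2.0 * intersection + smooth) / (sum(prediction) + sum(target) + smooth)

target = [1, 1, 0, 0]

print(dice_metric([1, 1, 0, 0], target))            # perfect overlap: 1.0
print(dice_metric([0, 0, 1, 1], target, smooth=0))  # no overlap: 0.0
print(dice_metric([1, 0, 0, 0], target, smooth=0))  # half of the target found: 2/3
```

The `smooth` term avoids a division by zero when both masks are empty. On tiny arrays like these it noticeably shifts the score (no overlap gives 0.2 instead of 0 with `smooth=1`), but on full-size images its effect should be negligible.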

We will be experimenting with three different loss functions: the dice loss, the binary cross-entropy (bce) loss and the combo loss, which is simply the sum of the bce loss and the dice loss.

The dice loss and the combo loss are defined below. The bce loss is one of the default loss functions available in PyTorch, so we don't have to define it manually here.
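
To see how the three losses relate, here is a minimal pure-Python sketch on made-up four-pixel predictions. The function names (`dice_loss`, `bce_loss`, `combo_loss`) are hypothetical stand-ins for the PyTorch versions used in this notebook, and the bce here is the plain textbook mean without any numerical safeguards.

```python
import math

def dice_loss(prediction, target, smooth=1.0):
    """One minus the (smoothed) dice coefficient."""
    intersection = sum(p * t for p, t in zip(prediction, target))
    dice = (2.0 * intersection + smooth) / (sum(prediction) + sum(target) + smooth)
    return 1.0 - dice

def bce_loss(prediction, target):
    """Mean binary cross-entropy; predictions must lie strictly between 0 and 1."""
    terms = [t * math.log(p) + (1 - t) * math.log(1 - p)
             for p, t in zip(prediction, target)]
    return -sum(terms) / len(terms)

def combo_loss(prediction, target, smooth=1.0):
    """Sum of the bce loss and the dice loss."""
    return bce_loss(prediction, target) + dice_loss(prediction, target, smooth)

target = [1, 1, 0, 0]
confident = [0.9, 0.9, 0.1, 0.1]   # made-up probabilities, close to the target
unsure = [0.6, 0.6, 0.4, 0.4]      # made-up probabilities, barely on the right side

print(combo_loss(confident, target) < combo_loss(unsure, target))  # True
```

Both predictions would give the same mask after thresholding at 0.5, yet the combo loss still prefers the more confident one, because both of its terms work on the raw probabilities.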

In [None]:
def DiceMetric(inputs, targets, smooth=1): 
    #flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)

    intersection = (inputs * targets).sum()                            
    dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  

    return dice

def DiceLoss(inputs, targets, smooth=1):
    return 1 - DiceMetric(inputs, targets, smooth)

def ComboLoss(inputs, targets, smooth=1):
    dice_contribution = DiceLoss(inputs, targets, smooth)
    bce_contribution = F.binary_cross_entropy(inputs, targets)
    return bce_contribution + dice_contribution

print('Loss functions defined.')

Loss functions defined.


### Defining the model

In the following code cell the model itself is defined. Model definition is still a bit complex because it requires not only the details of the model itself, but also procedures for calculating the loss and metrics during training, validation and testing. Furthermore, it also holds the configuration of the optimizers and hyperparameters used during training (e.g. learning rate, number of epochs, batch size, etc.).

In [None]:
#@title
class SegModel(pl.LightningModule):
    '''
    Semantic Segmentation Module
    This is a basic semantic segmentation module implemented with PyTorch Lightning.
    It is specific to the JSRT_SCR dataset, i.e. the dataloaders are for the
    JSRT_SCR radiograph dataset; the only image transform applied is ToTensor.
    It uses a U-Net whose encoder backbone is set by hparams['model_backbone'].
    Adam optimizer is used.
    '''

    def __init__(self, hparams):
        super().__init__()
#        self.hparam = hparams
        self.root_path = hparams['root']
        self.batch_size = hparams['batch_size']
        self.epochs = hparams['epochs']
        self.learning_rate = hparams['lr']
        self.scheduler = hparams['lr_scheduler']
        self.loss_function = hparams['loss_function']

        decoder_channels = [256, 128, 64, 32, 16]
        self.net = smp.Unet(hparams['model_backbone'],
                            encoder_depth=hparams['encoder_depth'], 
                            encoder_weights = None,
                            classes=1, 
                            in_channels=1, 
                            activation='sigmoid', 
                            decoder_channels=decoder_channels[:hparams['encoder_depth']]) 
        
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            ])
        self.trainset = JSRT_SCR(self.root_path, split='train', transform=self.transform)
        self.validset = JSRT_SCR(self.root_path, split='valid', transform=self.transform)
        self.testset = JSRT_SCR(self.root_path, split='test', transform=self.transform)

        self.ntest = len(self.testset)
        self.nvalid = len(self.validset)
        self.ntrain = len(self.trainset)
        
        self.save_hyperparameters()
    
    def on_fit_start(self):
        metric_placeholder = {'test_dice': 0, 'val_dice': 0}
        self.logger.log_hyperparams(self.hparams, metrics=metric_placeholder)

    def forward(self, x):
        return self.net(x)

    def calculate_loss_and_dice(self, batch):
        img, mask = batch
        img = img.float()
        mask = mask.float().unsqueeze(1)
        out = self(img)
        if self.loss_function == 'dice':
            loss_val = DiceLoss(out, mask)
        elif self.loss_function == 'bce':
            loss_val = F.binary_cross_entropy(out, mask)
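        elif self.loss_function == 'combo':
            loss_val = ComboLoss(out, mask)
        else:
            # Fail early instead of hitting an undefined loss_val below
            raise ValueError("Unknown loss_function: {}".format(self.loss_function))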
        dice = DiceMetric(out, mask)
        return loss_val, dice

    def training_step(self, batch, batch_nb):
        loss_val, dice = self.calculate_loss_and_dice(batch)
        log_dict = {'train_loss': loss_val, 'train_dice': dice}
        self.log_dict(log_dict)
        return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}

    def validation_step(self, batch, batch_idx):
        loss_val, dice = self.calculate_loss_and_dice(batch)
        log_dict = {'val_loss': loss_val, 'val_dice': dice}
        self.log_dict(log_dict)
        return {'val_loss': loss_val, 'val_dice': dice}

    def test_step(self, batch, batch_idx):
        loss_val, dice = self.calculate_loss_and_dice(batch)
        log_dict = {'test_loss': loss_val, 'test_dice': dice}
        self.log_dict(log_dict)
        return {'test_dice': dice}

    def validation_epoch_end(self, outputs):
        loss_val = sum(output['val_loss'] for output in outputs) / len(outputs)
        dice_val = sum(output['val_dice'] for output in outputs) / len(outputs)
        log_dict = {'val_loss': loss_val, 'val_dice': dice_val}
        return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict, 'val_dice': log_dict['val_dice']}

    def test_epoch_end(self, outputs):
        dice_val = sum(output['test_dice'] for output in outputs) / len(outputs)
        log_dict = {'test_dice': dice_val}
        return {'log': log_dict, 'progress_bar': log_dict, 'test_dice': log_dict['test_dice']}

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        if self.scheduler == 'cosine':
            sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.epochs/2)
        else:
            lmbd = lambda epoch: 1
            sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lmbd)
        return [opt], [sch]

    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)

print('Segmentation model defined.')

Segmentation model defined.


# Model Training

Now that everything's set up, we get to the fun part: actually training and evaluating the model. 



## Specify hyperparameters
As we've seen during the lectures, models usually have a set of parameters that are not optimized during the training procedure, but that are nevertheless changeable and can have a dramatic effect on the results of the training. These parameters are known as hyperparameters.

The hyperparameters of a model are set before a training run. As you might remember, the whole purpose of the validation set is to find the set of hyperparameters that works best for this model on this dataset.

Finding the optimal hyperparameters usually means training the model several times, each time slightly changing one of the hyperparameters, in order to see which combination of hyperparameters leads to the best performance on the validation set.
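
One way to organize such a series of runs is to generate the configurations up front, each with its own model name so that earlier runs are not overwritten. The sketch below is only an illustration: `make_runs` is a hypothetical helper, and the candidate values in `space` are examples rather than recommendations.

```python
import itertools

def make_runs(base_hparams, search_space):
    """Yield (model_name, hparams) pairs, one per combination in search_space."""
    keys = sorted(search_space)
    combinations = itertools.product(*(search_space[k] for k in keys))
    for version, values in enumerate(combinations):
        hparams = dict(base_hparams)        # copy, so the base config stays untouched
        hparams.update(zip(keys, values))
        yield 'version_{}'.format(version), hparams

# Example values only; the keys match the hparams dictionary used in this notebook.
base = {'batch_size': 3, 'epochs': 25, 'lr': 1e-3, 'loss_function': 'combo'}
space = {'lr': [1e-3, 1e-2], 'loss_function': ['bce', 'dice', 'combo']}

for name, hp in make_runs(base, space):
    print(name, hp['lr'], hp['loss_function'])
```

Each configuration would then be trained separately and compared on the validation dice. With two learning rates and three loss functions this already means six training runs, so it pays to keep the grid small.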

When you are done figuring out what the best hyperparameters are for your dataset and model, you can evaluate the model on the test set to get a final, independent estimate of the true model performance.

**NOTE**: Please again make sure to set the 'root' path here to the correct folder. It should point to the same location as the 'root_path' you used before for the data inspection, i.e. the folder that contains the 'images' and 'masks' folders.

In [None]:
# Here, the hyperparameters for the model are defined. 

hparams = {
            'root': '/gdrive/My Drive/JSRT_SCR dataset',
#            'root':root_path,
            'batch_size': 3,
            'epochs': 25,
            'encoder_depth': 5,
            'lr': 1e-3,
            'lr_scheduler': 'constant',
            'loss_function': 'combo',
            'model_backbone': 'resnet34'
           }

# The name of the model that will be trained. Make sure to change this before 
# starting a next training run, to avoid overwriting your previously trained 
# models!
model_name = 'version_0'

## The training procedure

The cell below contains the code to run the model training. Once you run it, it will create folders to store the model weights. It will also show you how the training is progressing.

After each epoch, the model is evaluated on the validation set. If the dice coefficient on the validation set improved, a model 'checkpoint' is created, which just means that the current weights of the model are saved to disk. 
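
The "save only when the validation dice improves" rule can be sketched in a few lines of plain Python. This is a simplified illustration of the idea, not the actual PyTorch Lightning implementation; `best_checkpoint_epochs` and the dice values are made up.

```python
def best_checkpoint_epochs(val_dice_per_epoch):
    """Return the epochs at which a new best validation dice would trigger a save."""
    best = float('-inf')
    saved_at = []
    for epoch, val_dice in enumerate(val_dice_per_epoch):
        if val_dice > best:          # higher dice is better
            best = val_dice
            saved_at.append(epoch)   # the previous best checkpoint would be replaced
    return saved_at

# Made-up validation dice values, for illustration only.
history = [0.52, 0.61, 0.58, 0.70, 0.69]
print(best_checkpoint_epochs(history))  # [0, 1, 3]
```

In this example the weights from epoch 3 are the ones that remain on disk, even though training continued for one more epoch.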

After the training completes, i.e. the model has been trained for the specified number of epochs, the best model checkpoint (the one with the highest dice coefficient on the validation set) is loaded and used to make predictions on the test set. The dice coefficient on the test set is then calculated.

You don't have to change anything in this cell, but you can have a look at the code to see the different steps that are executed for the training.

In [None]:
# 1 INIT LIGHTNING MODEL
model = SegModel(hparams)

# 2 Create folder to save the models
checkpoint_path = os.path.join(os.getcwd(), 'pytorch_checkpoints')
if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path)

# 3 INIT TRAINER
trainer = pl.Trainer(
    gpus=1,
    max_epochs=hparams['epochs'],
    logger=TensorBoardLogger('lightning_logs/'),
    checkpoint_callback = True,
    callbacks = [ModelCheckpoint(dirpath=checkpoint_path, 
                filename='AMV-{epoch:02d}-{val_dice:.2f}',
                monitor='val_dice',
                save_top_k=1,
                every_n_epochs = 1,
                verbose=True,
                mode='max')]
    )

# 4 START TRAINING
trainer.fit(model)

# 5 Evaluate model on test set
trainer.test()

MisconfigurationException: ignored

## Monitoring the training

A great tool for monitoring model training is 'Tensorboard'. Among other things, it can show you the train / validation / test loss of the model after each epoch, and allows for comparison between different training runs. Execute the cell below to fire up tensorboard (wait for a minute for the app to load). 

Once it fires up, look for the slider called 'Smoothing' and set it to zero (smoothing is only useful for much more extensive experiments than what we'll do here). 

Explore the graphs that are shown in tensorboard for a bit, and try to understand what each graph shows. In particular, look at the 'train_loss' and 'val_loss' graphs: these show the evolution of the loss during training.




In [None]:
#@title
%tensorboard --logdir lightning_logs

# Model Evaluation

To evaluate model performance, performance metrics are used. For a segmentation task such as this one, a popular metric is the dice coefficient. Metrics are a very convenient way to measure performance because they can easily be averaged over the entire dataset (or the train / validation / test sets). They therefore allow you to summarize model performance with a single scalar.

However, in order to get a feeling for model performance it is equally important to visually inspect the predictions from the model, to see if they make sense. In this section, we will both visualize the predictions and calculate the mean dice coefficients for the different subsets.



## Visualizing the results

Execute the code cell below. This will bring up a widget in which you can select a model checkpoint and a dataset to evaluate. Use the dropdown boxes to select an appropriate checkpoint and dataset (for instance the test or validation set). 

Next, you can use the slider to walk through the images in the dataset and inspect the image (left), mask (middle) and model prediction (right) for each image.


In [None]:
#@title
# -----------------------------------
# HELPER FUNCTIONS FOR VISUALIZATION
# -----------------------------------

def predict_on_image(pytorch_model, index=0, dataset='test'):
  # Function to generate a prediction using pytorch_model,
  # on any image from the selected dataset split
  dataset_map = {'test':pytorch_model.testset, 
                 'validation': pytorch_model.validset, 
                 'train': pytorch_model.trainset}
  ds = dataset_map[dataset]
  img, mask = ds.give_data(index)
  pred = pytorch_model.eval()(img.float().cuda(device=0).unsqueeze(0))
  pred = pred.cpu().detach().numpy().squeeze()
  return img.cpu().squeeze().numpy(), mask, pred

def dice_coefficient(y_true, y_pred, empty_score=1.0, mode='hard'):
    # Function to calculate dice coefficient after thresholding
    if mode == 'hard':
      y_th = y_pred > 0.5
    elif mode == 'soft':
      y_th = y_pred
    else:
      raise ValueError('Invalid dice mode! Choose either "soft" or "hard"')
    im1 = y_true
    im2 = y_th
    if im1.shape != im2.shape:
        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")
    im_sum = im1.sum() + im2.sum()
    if im_sum == 0:
        return empty_score
    intersection = (im1*im2).sum()

    return (2. * intersection) / im_sum

def predict_and_plot(pytorch_model, index=0, dataset='test', dice_mode='soft'):
    img, mask, pred = predict_on_image(pytorch_model, index, dataset)
    dice = dice_coefficient(mask, pred, mode=dice_mode)
    print('Dice coefficient for {} image {} is: {:.3f}'.format(dataset, index, dice))
    image = np.dstack([img]*3)
    image_masked = copy.deepcopy(image)
    image_predicted = copy.deepcopy(image)
    prediction = pred > 0.5

    image_masked[:,:,0][mask!=0] = image_masked[:,:,0][mask!=0]*1
    image_masked[:,:,1][mask!=0] = image_masked[:,:,1][mask!=0]*0
    image_masked[:,:,2][mask!=0] = image_masked[:,:,2][mask!=0]*0

    image_predicted[:,:,0][prediction!=0] = image_predicted[:,:,0][prediction!=0]*0
    image_predicted[:,:,1][prediction!=0] = image_predicted[:,:,1][prediction!=0]*0
    image_predicted[:,:,2][prediction!=0] = image_predicted[:,:,2][prediction!=0]*1

    display_image = np.hstack((image, image_masked, image_predicted))
    fig = plt.figure(figsize=(12,4))
    plt.imshow(display_image, cmap='bone')
    plt.axis('off')
    return dice, fig 

cps = glob(os.path.join(checkpoint_path, '*.ckpt'))
datasets = ['train', 'validation', 'test']

@interact(checkpoint=cps, dataset=datasets)
def select_model_checkpoint(checkpoint=cps[0], dataset='validation'):
  global eval_model, ds, model_checkpoint
  model_checkpoint = checkpoint
  eval_model = SegModel.load_from_checkpoint(checkpoint)
  eval_model.cuda(device=0)
  ds = dataset

dsmap = {'test':eval_model.ntest, 'validation':eval_model.nvalid, 'train':eval_model.ntrain}
@interact
def plot_sample(index=(0,dsmap[ds]-1,1)):
  print('Evaluating model from checkpoint: {}'.format(model_checkpoint))
  predict_and_plot(pytorch_model=eval_model, index=index, dataset=ds, dice_mode='soft')

## The mean dice coefficient
The code below calculates the dice coefficient for the model (loaded from the checkpoint you selected with the widget above) for the training, validation and test datasets. To do so, predictions are generated for all images in the datasets, so the code takes some time to execute. 
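
The evaluation helper above supports a 'soft' mode (raw probabilities) and a 'hard' mode (probabilities thresholded at 0.5), and the two can give noticeably different means. The sketch below shows this on made-up four-pixel examples with plain Python lists; it mirrors the logic of `dice_coefficient` but is not the notebook's own function.

```python
def dice_coefficient(mask, prediction, mode='hard', empty_score=1.0):
    """Dice between a binary mask and per-pixel probabilities (flat lists)."""
    if mode == 'hard':
        prediction = [1.0 if p > 0.5 else 0.0 for p in prediction]
    elif mode != 'soft':
        raise ValueError('mode must be "soft" or "hard"')
    total = sum(mask) + sum(prediction)
    if total == 0:
        return empty_score
    intersection = sum(m * p for m, p in zip(mask, prediction))
    return 2.0 * intersection / total

# Made-up toy "images" of four pixels each: (mask, predicted probabilities).
samples = [
    ([1, 1, 0, 0], [0.9, 0.6, 0.2, 0.1]),
    ([0, 1, 1, 0], [0.1, 0.8, 0.4, 0.3]),
]
for mode in ('soft', 'hard'):
    scores = [dice_coefficient(m, p, mode) for m, p in samples]
    print(mode, sum(scores) / len(scores))
```

The soft score rewards confident probabilities, while the hard score measures the overlap of the final binary masks, which is what you see in the visualization.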


In [None]:
#@title
print('Evaluating model from checkpoint: {}'.format(model_checkpoint))
for dsname, dslen in dsmap.items():
  dice_mean = 0
  for index in range(dslen):
    _, mask, pred = predict_on_image(eval_model, index, dataset=dsname)
    dice_mean += dice_coefficient(mask, pred, mode='soft')

  dice_mean /= dslen
  print('Mean dice coefficient on the {} set is {:.3f}'.format(dsname, dice_mean))

# Assignment

As you have probably noticed, the results from the first training run do not look great. Also, the dice is not very high. This can be a result of a poor choice of hyperparameters. **The task you have in this assignment is to improve the model, by experimentally determining better choices for the hyperparameters.** 

---


**Hyperparameter interpretation**

The parameters you can change are:

*   Number of epochs
*   Learning rate
*   Loss function (valid choices are 'bce', 'dice' or 'combo')
*   Model backbone (valid choices are 'resnet18' or 'resnet34')

*HINT*: If you look at the graph for 'val_loss' in tensorboard, you probably see that the validation loss is still decreasing, even for the last epoch. This can indicate two things:

1.   We stopped the training too soon
2.   Our learning rate is too low

Of course, points 1 and 2 are related: We can either try training for more epochs, or increasing the learning rate. The learning rate I picked to start with is rather low: I'd recommend increasing it a bit as a first attempt to improve the model.
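
The link between learning rate and number of epochs can be illustrated on a toy problem. The sketch below runs plain gradient descent on the loss L(w) = w², which is an assumption chosen for simplicity and has nothing to do with the segmentation model itself; `loss_curve` and the learning rates are made up for illustration.

```python
def loss_curve(learning_rate, epochs, start=5.0):
    """Gradient descent on the toy loss L(w) = w**2; returns the loss per epoch."""
    w = start
    losses = []
    for _ in range(epochs):
        losses.append(w ** 2)
        w -= learning_rate * 2 * w   # the gradient of w**2 is 2*w
    return losses

slow = loss_curve(learning_rate=0.01, epochs=25)
fast = loss_curve(learning_rate=0.1, epochs=25)
print('final loss, lr=0.01: {:.3f}'.format(slow[-1]))
print('final loss, lr=0.1:  {:.3f}'.format(fast[-1]))
```

With the smaller learning rate the loss is still clearly decreasing after 25 epochs, while the larger one has essentially converged. The effect is not monotonic, though: in this toy problem a learning rate above 1.0 makes the loss grow instead of shrink, so increase it in moderate steps.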


The 'loss function' hyperparameter controls the optimization objective of our training. With 'bce' (binary cross-entropy) the model minimizes a per-pixel classification loss, but since we actually care about maximizing the dice coefficient, it might make sense to use 'dice' as the objective instead. The 'combo' loss, which is the value set in the hyperparameter cell above, is the sum of both bce and dice.


The 'model backbone' hyperparameter controls the structure of the model itself: ResNet18 is a simple, relatively shallow architecture that consists of 18 layers, while ResNet34 is deeper, consisting of 34 layers. Deeper models can lead to better results, but are also more prone to overfitting.