## Brain Tumor MRI - Instance Segmentation - V-Net - PyTorch

In [None]:
import os

import albumentations
import cv2
import glob
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import pathlib
from pathlib import Path
from PIL import Image
import random
from skimage.io import imread
from skimage.transform import resize
import tensorboard
import torch
import torch.nn as nn
from torch.utils import data
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as fn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms 
from tqdm.notebook import tqdm


from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython.display import display
from IPython.display import clear_output


from model import VNet
from trainer import Trainer
from transformations import (
    normalize_01,
    re_normalize, 
    transforms,
    ComposeDouble,
    FunctionWrapperDouble,
    create_dense_target,
    AlbuSeg3d,
)
from utils import (
    get_filenames_of_path,
    postprocess, 
    draw_segmentation_map, 
    segmentation_target, 
    segmentation_pred, 
    save_gif
)

In [None]:
# Load the TensorBoard notebook extension so the `%tensorboard` magic can be used later.
%load_ext tensorboard

## Dataset

In [None]:
class SegmentationDataSet3(data.Dataset):
    """Image segmentation dataset with optional caching and pre-transforms.

    Every sample is an ``(input, target)`` pair of image files read with
    ``imread``.  An optional ``pre_transform`` is applied once right after
    reading, and an optional ``transform`` is applied on every access.

    Args:
        inputs: Paths of the input images.
        targets: Paths of the target images, aligned index-by-index with
            ``inputs``.
        transform: Optional callable ``(x, y) -> (x, y)`` applied each time a
            sample is requested.
        use_cache: If true, all pairs are read (and pre-transformed) once at
            construction time in a multiprocessing pool and kept in memory.
        pre_transform: Optional callable ``(x, y) -> (x, y)`` applied directly
            after reading, with or without caching.

    Raises:
        ValueError: If ``inputs`` and ``targets`` differ in length.
    """

    def __init__(
        self,
        inputs: list,
        targets: list,
        transform=None,
        use_cache: bool = False,
        pre_transform=None,
    ):
        super().__init__()

        # A length mismatch would otherwise only surface later, as an
        # IndexError or as silently misaligned input/target pairs.
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets must have the same length, "
                f"got {len(inputs)} and {len(targets)}"
            )

        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long
        self.use_cache = use_cache
        self.pre_transform = pre_transform

        if self.use_cache:
            from itertools import repeat
            from multiprocessing import Pool

            # Read (and pre-transform) every pair once, in parallel.
            with Pool() as pool:
                self.cached_data = pool.starmap(
                    self.read_images, zip(inputs, targets, repeat(self.pre_transform))
                )

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.inputs)

    def __getitem__(self, index: int):
        """Return sample ``index`` as a ``(float32 tensor, long tensor)`` pair."""
        if self.use_cache:
            x, y = self.cached_data[index]
        else:
            # Go through the same reader as the cached path so that
            # ``pre_transform`` is honoured regardless of ``use_cache``.
            x, y = self.read_images(
                self.inputs[index], self.targets[index], self.pre_transform
            )

        # Preprocessing
        if self.transform is not None:
            x, y = self.transform(x, y)

        # Typecasting. ``torch.from_numpy`` rejects arrays with negative
        # strides (e.g. views produced by flips), so make them contiguous
        # first; already-contiguous arrays are passed through without a copy.
        x = torch.from_numpy(np.ascontiguousarray(x)).type(self.inputs_dtype)
        y = torch.from_numpy(np.ascontiguousarray(y)).type(self.targets_dtype)
        return x, y

    @staticmethod
    def read_images(inp, tar, pre_transform):
        """Read one input/target pair and apply ``pre_transform`` if given."""
        inp, tar = imread(str(inp)), imread(str(tar))
        if pre_transform is not None:
            inp, tar = pre_transform(inp, tar)
        return inp, tar

## Transformations

In [None]:
# Output shape every input volume and target mask is resized to.
_TARGET_SHAPE = (16, 128, 128)

# Training pipeline: resize, random augmentations, then target/input formatting.
transforms_training = ComposeDouble([
    # Resize the input with skimage's default interpolation settings.
    FunctionWrapperDouble(
        resize, input=True, target=False, output_shape=_TARGET_SHAPE
    ),
    # Resize the target with order=0 and no anti-aliasing so label values
    # are not blended; preserve_range keeps the original value range.
    FunctionWrapperDouble(
        resize,
        input=False,
        target=True,
        output_shape=_TARGET_SHAPE,
        order=0,
        anti_aliasing=False,
        preserve_range=True,
    ),
    # Random augmentations, each applied with probability 0.5.
    AlbuSeg3d(albumentations.HorizontalFlip(p=0.5)),
    AlbuSeg3d(albumentations.VerticalFlip(p=0.5)),
    AlbuSeg3d(albumentations.Rotate(p=0.5)),
    AlbuSeg3d(albumentations.RandomRotate90(p=0.5)),
    # Target only: convert to a dense target.
    FunctionWrapperDouble(create_dense_target, input=False, target=True),
    # Add a leading axis (wrapper defaults decide input/target — presumably
    # input only; TODO confirm against FunctionWrapperDouble).
    FunctionWrapperDouble(np.expand_dims, axis=0),
    # NOTE: a RandomFlip(ndim_spatial=3) step was tried here and is disabled.
    FunctionWrapperDouble(normalize_01),
])

# Evaluation pipeline: identical to training minus the random augmentations
# (HorizontalFlip, VerticalFlip, Rotate, RandomRotate90 are intentionally absent).
transforms_testing = ComposeDouble([
    FunctionWrapperDouble(
        resize, input=True, target=False, output_shape=_TARGET_SHAPE
    ),
    FunctionWrapperDouble(
        resize,
        input=False,
        target=True,
        output_shape=_TARGET_SHAPE,
        order=0,
        anti_aliasing=False,
        preserve_range=True,
    ),
    FunctionWrapperDouble(create_dense_target, input=False, target=True),
    FunctionWrapperDouble(np.expand_dims, axis=0),
    FunctionWrapperDouble(normalize_01),
])

## Data Loading

In [None]:
# Root folders of the training and validation splits, relative to the
# current working directory.
root_train, root_val = (
    pathlib.Path.cwd() / "Data3D" / split for split in ("train", "val")
)

# Input and target files of the training split
inputs_train, targets_train = (
    get_filenames_of_path(root_train / folder) for folder in ("Input", "Target")
)

# Input and target files of the validation split
inputs_val, targets_val = (
    get_filenames_of_path(root_val / folder) for folder in ("Input", "Target")
)

In [None]:
# Training dataset: uses the augmenting pipeline.
dataset_train = SegmentationDataSet3(
    inputs=inputs_train,
    targets=targets_train,
    transform=transforms_training,
    use_cache=False,
    pre_transform=None,
)

# Validation dataset: uses the deterministic (augmentation-free) pipeline so
# that validation losses are comparable between epochs. Previously this used
# `transforms_training`, which randomly flipped/rotated validation samples.
dataset_val = SegmentationDataSet3(
    inputs=inputs_val,
    targets=targets_val,
    transform=transforms_testing,
    use_cache=False,
    pre_transform=None,
)

# Training dataloader: reshuffled every epoch.
dataloader_training = DataLoader(
    dataset=dataset_train,
    batch_size=1,
    # Original note: batch_size of 2 won't work because the depth dimension is
    # different between the 2 samples.
    # NOTE(review): the transforms resize every volume to a fixed shape, so
    # this restriction may no longer apply — confirm before raising it.
    shuffle=True,
)

# Validation dataloader: fixed order.
dataloader_validation = DataLoader(
    dataset=dataset_val,
    batch_size=1,
    # Same batch-size note as for the training dataloader.
    shuffle=False,
)

## Training 

In [None]:
import monai

# Learning rate used by the optimizer cells below.
lr = 1e-4

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TensorBoard summary writer; events go to runs/vnet.
writer = SummaryWriter(log_dir='runs/vnet')

# V-Net from the local `model` module.
# Alternative that was tried and is currently disabled: MONAI's implementation,
#   monai.networks.nets.VNet(spatial_dims=3, in_channels=1, out_channels=3,
#                            act=('elu', {'inplace': True}), dropout_prob=0.5,
#                            dropout_dim=3, bias=False).to(device)
model = VNet(elu=True, in_channels=1, classes=3)
model = model.to(device)

#### Loss Functions

In [None]:
from monai.losses.dice import DiceLoss

# Cross-entropy criterion (PyTorch defaults).
cross_entropy_loss = nn.CrossEntropyLoss()

# Dice criterion: targets are one-hot encoded inside the loss, a sigmoid is
# applied to the predictions, and the result is averaged.
# NOTE(review): for mutually exclusive classes `softmax=True` is the more
# common choice — confirm `sigmoid=True` is intended.
dice_loss = DiceLoss(
    to_onehot_y=True,
    sigmoid=True,
    reduction='mean',
)

#### Optimizers

In [None]:
# Adam over all model parameters, starting from the global learning rate `lr`.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

#### Learning rate scheduler

In [None]:
# A selection of schedulers, all bound to the same optimizer; the Trainer cell
# below picks one of them.

# ReduceLROnPlateau: lowers the LR when the monitored value stops decreasing.
plateau_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min'
)

# StepLR: multiplies the LR by 0.1 every 5 scheduler steps.
step_lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=5, gamma=0.1
)

# MultiplicativeLR: multiplies the current LR by a constant 0.95 each step.
multiplicative_lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
    optimizer, lr_lambda=lambda _step: 0.95
)

# LambdaLR: sets the LR to the initial LR times 0.95 ** step.
lambda_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.95 ** step
)

# CosineAnnealingLR: cosine schedule with T_max=10, annealing towards 0.
cosine_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=10, eta_min=0
)

In [None]:
# Everything the Trainer needs, collected in one place.
trainer_kwargs = {
    "model": model,
    "device": device,
    "criterion": dice_loss,
    "optimizer": optimizer,
    "training_dataloader": dataloader_training,
    "validation_dataloader": dataloader_validation,
    "lr_scheduler": plateau_lr_scheduler,
    "epochs": 500,
    "epoch": 0,
    "writer": writer,
    "notebook": True,
}
trainer = Trainer(**trainer_kwargs)

# Run the training loop and keep the per-epoch histories for plotting.
histories = trainer.run_trainer()
training_losses, validation_losses, lr_rates = histories

# Remove the progress output from the cell once training has finished.
clear_output()

# Plot results

In [None]:
# Open TensorBoard inline, pointed at the `runs` log directory.
%tensorboard --logdir=runs

In [None]:
from visual import plot_training

# Plot the recorded loss and learning-rate histories; the keyword options
# below are forwarded to `plot_training` unchanged.
plot_options = {"gaussian": True, "sigma": 1, "figsize": (10, 4)}
fig = plot_training(training_losses, validation_losses, lr_rates, **plot_options)

# Learning rate finder

In [None]:
# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Vnet model
model = VNet(elu=True, in_channels=1, classes=4).to(device)

# criterion
criterion = torch.nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
from lr_rate_finder import LearningRateFinder

lrf = LearningRateFinder(model, criterion, optimizer, device)
lrf.fit(dataloader_training, steps=1000)

clear_output()

In [None]:
# Plot the outcome of the learning-rate search (presumably loss vs. learning rate — see LearningRateFinder.plot).
lrf.plot()


## Testing

In [None]:
# Root folder of the test split, relative to the current working directory.
root_test = pathlib.Path.cwd() / "Data3D" / "test"

# Input and target files of the test split
inputs_test, targets_test = (
    get_filenames_of_path(root_test / folder) for folder in ("Input", "Target")
)

# Test dataset: augmentation-free pipeline, no caching.
dataset_test = SegmentationDataSet3(
    inputs_test,
    targets_test,
    transform=transforms_testing,
    use_cache=False,
    pre_transform=None,
)

# Test dataloader: one volume per batch, fixed order.
# (Original note: batch_size of 2 won't work because the depth dimension is
# different between the 2 samples.)
dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False)


In [None]:
# Imported under an alias so neither the `display` function nor PIL's `Image`
# imported at the top of the notebook is rebound.
from IPython import display as ipy_display

# Make sure the folder the GIFs are written to exists.
(root_test / "Segmentation").mkdir(parents=True, exist_ok=True)

model.eval()
for indice, (image, mask) in enumerate(dataloader_test):
    input_image, input_mask = image.to(device), mask.to(device)
    with torch.no_grad():
        output = model(input_image)  # send through model/network
        output = postprocess(output)  # postprocess the prediction

        filepath_out_target = f'{root_test}/Segmentation/segmentation_target_{indice}.gif'
        filepath_out_pred = f'{root_test}/Segmentation/segmentation_prediction_{indice}.gif'

        segmentation_target_path = segmentation_target(input_image, input_mask, filepath_out_target)
        segmentation_pred_path = segmentation_pred(input_image, output, filepath_out_pred)

        # Show ground truth first, then the prediction. A bare `Image(...)`
        # expression inside a loop is never rendered, so call `display`
        # explicitly. The files are written with a .gif suffix, hence
        # format='gif' (assumes the helpers return the path of that GIF).
        for gif_path in (segmentation_target_path, segmentation_pred_path):
            with open(gif_path, 'rb') as f:
                ipy_display.display(ipy_display.Image(data=f.read(), format='gif'))

## Save the model

In [None]:
# Persist the model's state_dict in the current working directory.
model_name = "brain_mri_vnet.pt"
checkpoint_path = pathlib.Path.cwd() / model_name
torch.save(obj=model.state_dict(), f=checkpoint_path)
