## Multi-modal 3D brain tumor segmentation using the BraTS 2016-2017 dataset

### Install monai
MONAI is a PyTorch-based, open-source framework for deep learning in healthcare imaging.

In [None]:
# !pip install monai
# !pip install 'monai[all]'
# !python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
# !pip install nilearn
# !pip install git+https://github.com/miykael/gif_your_nifti

### Import relevant libraries and functionality

In [None]:
import os
import shutil
import glob
import monai
import matplotlib
import torch
import nilearn as nl
import nilearn.plotting as nlplt
import nibabel as nib
import gif_your_nifti.core as gif2nif
import matplotlib.pyplot as plt
%matplotlib inline
from skimage.io import imshow
from skimage.color import label2rgb
from skimage.util import montage 
from skimage.transform import rotate
from sklearn.model_selection import train_test_split
import numpy as np
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AsDiscrete,
    CenterSpatialCropd,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    ToTensord,
)
from monai.utils import first, set_determinism
import warnings
# Silence every Python warning so notebook output stays readable.
warnings.simplefilter("ignore")

# Print the MONAI configuration (installed package versions) for reproducibility.
print_config()

In [None]:
# Seed the global random number generators through MONAI so runs are repeatable.
SEED = 0
set_determinism(seed=SEED)

### Create helper functions to read, split and process the dataset

In [None]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert an integer label map into three binary channels: TC, WT and ET.

    Label values as used by this transform:
        0 -> background (not part of any output channel)
        1 -> counted in WT only
        2 -> counted in TC, WT and used as ET
        3 -> counted in TC and WT

    Output for every key (channels stacked on a new leading axis, float32):
        channel 0: TC (Tumor core)       = label 2 or 3
        channel 1: WT (Whole tumor)      = label 1, 2 or 3
        channel 2: ET (Enhancing tumor)  = label 2

    NOTE(review): ET is taken from label value 2. Confirm against the dataset's
    label definitions which value marks the enhancing tumor; if it is 3, change
    the `et` line below (TC and WT are unaffected either way).
    """
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            label = d[key]
            # TC: union of labels 2 and 3
            tc = np.logical_or(label == 2, label == 3)
            # WT: TC plus label 1
            wt = np.logical_or(tc, label == 1)
            # ET: label 2 (see NOTE(review) in the class docstring)
            et = label == 2
            d[key] = np.stack([tc, wt, et], axis=0).astype(np.float32)
        return d

In [None]:
def _transforms(roi_size=(128, 128, 64), pixdim=(1.5, 1.5, 2.0)):
    '''
    Create the transformations applied to the training and validation sets.

    Both pipelines start with the same preprocessing: load image and label,
    move the image channel axis first, turn the label into 3 binary channels
    (TC, WT, ET), resample to `pixdim` and reorient to RAS. Training then takes
    a random crop of size `roi_size` with random flips and intensity
    augmentation; validation takes a deterministic center crop of the same size.

    input:
        roi_size: spatial size of the crop fed to the network (default (128, 128, 64))
        pixdim: target voxel spacing passed to Spacingd (default (1.5, 1.5, 2.0))
    output:
        (train_transforms, val_transforms): monai Compose pipelines
    '''
    def _preprocessing():
        # Fresh transform instances per call so the two pipelines share no objects.
        return [
            # NOTE(review): AsChannelFirstd on "image" suggests each image file holds
            # all modalities with the channel axis last — confirm against the data.
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(
                keys=["image", "label"],
                pixdim=pixdim,
                # bilinear for intensities, nearest so label masks stay binary
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
        ]

    train_transforms = Compose(
        _preprocessing()
        + [
            RandSpatialCropd(
                keys=["image", "label"], roi_size=list(roi_size), random_size=False
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        _preprocessing()
        + [
            CenterSpatialCropd(keys=["image", "label"], roi_size=list(roi_size)),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )
    return train_transforms, val_transforms

In [None]:
def split_dataset(data_dir, test_size=0.2, random_state=None):
    '''
    Read the image/label file paths and split them into train and validation sets.

    Images are read from <data_dir>/imagesTr and labels from <data_dir>/labelsTr
    (files matching *.nii.gz). Images and labels are paired by position after
    sorting each list by path, so both folders must list matching files in the
    same sorted order.

    input:
        data_dir: the path of the directory where the data exists (type: str)
        test_size: fraction (or absolute number) of pairs placed in the validation
            set, passed to sklearn's train_test_split (default 0.2)
        random_state: seed passed to train_test_split; None draws from numpy's
            global RNG
    output:
        train_files, val_files: lists of {"image": image_path, "label": label_path}
            dictionaries
    raises:
        ValueError: if no images are found, or the image and label counts differ
    '''
    images_dir = os.path.join(data_dir, "imagesTr")
    labels_dir = os.path.join(data_dir, "labelsTr")
    train_images = sorted(glob.glob(os.path.join(images_dir, "*.nii.gz")))
    train_labels = sorted(glob.glob(os.path.join(labels_dir, "*.nii.gz")))
    if not train_images:
        raise ValueError(f"no '*.nii.gz' images found in {images_dir}")
    # zip() would silently drop unmatched files and mis-pair the rest
    if len(train_images) != len(train_labels):
        raise ValueError(
            f"found {len(train_images)} images in {images_dir} but "
            f"{len(train_labels)} labels in {labels_dir}"
        )
    data_dicts = [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(train_images, train_labels)
    ]
    # split indices, then pick the dictionaries they refer to
    train_idx, val_idx = train_test_split(
        np.arange(len(data_dicts)), test_size=test_size, random_state=random_state
    )
    train_files = [data_dicts[i] for i in train_idx]
    val_files = [data_dicts[i] for i in val_idx]
    return train_files, val_files

### Create the dataset and dataloaders using the helper functions

In [None]:
# Locate the image/label files, split them, and build the transform pipelines.
data_dir = 'data/'
train_files, val_files = split_dataset(data_dir)
train_transforms, val_transforms = _transforms()

# Wrap each file list in a MONAI Dataset that applies its transform on access.
train_ds, val_ds = (
    Dataset(data=train_files, transform=train_transforms),
    Dataset(data=val_files, transform=val_transforms),
)

# Only the training loader shuffles; both use batches of 2 loaded in the main process.
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=0)

### Visualizing the dataset

In [None]:
def visualise_images_and_labels(dataset, image_id, image_slice):
    '''
    Plot one slice of the 4 input modalities and of the 3 ground-truth masks.

    The sample is read from the dataset once, so every panel comes from the same
    transformed volume. Each dataset[...] access re-runs the transform pipeline;
    re-indexing per panel reloaded the files each time and, with random
    transforms, would show a different crop or flip in each panel.

    input:
        dataset: dataset whose items are dicts with an "image" tensor holding the
            4 modalities (T2 FLAIR, T1, T1Gd, T2) and a "label" tensor holding the
            3 masks (Tumor core, Whole tumor, Enhancing tumor), channel axis first
        image_id: index of the sample in the dataset
        image_slice: index of the slice along the last spatial axis
    output:
        None (prints the tensor shapes and shows two matplotlib figures)
    '''
    sample = dataset[image_id]
    image, label = sample["image"], sample["label"]

    plt.figure("image", (24, 6))
    print(f"image shape: {image.shape}")
    modes = ['T2 FLAIR', 'T1', 'T1Gd', 'T2']
    for i, mode in enumerate(modes):
        plt.subplot(1, 4, i + 1)
        plt.title(f"{mode}")
        plt.imshow(image[i, :, :, image_slice].detach().cpu(), cmap="gray")
    plt.show()

    # also visualize the 3-channel label corresponding to this image
    print(f"label shape: {label.shape}")
    plt.figure("label", (18, 6))
    labels = ['Tumor core', 'Whole tumor', 'Enhancing tumor']
    for i, label_name in enumerate(labels):
        plt.subplot(1, 3, i + 1)
        plt.title(f"{label_name}")
        plt.imshow(label[i, :, :, image_slice].detach().cpu(), cmap="gray")
    plt.show()

In [None]:
visualise_images_and_labels(val_ds, image_id=7, image_slice=30)

In [None]:
# Visualise all slices of T2 FLAIR (channel 0) of validation sample 7 as one
# montage; skimage's montage tiles the 2-D slices taken along the first axis.
flair_volume = val_ds[7]["image"][0, :, :].detach().cpu()
flair_tiles = montage(flair_volume)
fig, ax1 = plt.subplots(1, 1, figsize=(15, 15))
ax1.imshow(rotate(flair_tiles, 90, resize=True), cmap='gray')

In [None]:
def save_gif(dataset, image_id, out_location, modality=0):
    '''
    Save one modality of a sample (and its first label channel) as NIfTI, then a GIF.

    Writes <out_location>/test.nii (the chosen modality), <out_location>/test_lab.nii
    (label channel 0, the tumor-core mask) and then calls
    gif_your_nifti's write_gif_normal on test.nii, which presumably writes
    test.gif next to it.
    The output directory is created if it does not exist.

    input:
        dataset: the pytorch dataset object; items are dicts with "image" and
            "label" tensors, channel axis first
        image_id: the index of the image in the dataset
        out_location: the directory for the output files
        modality: the index of the image channel (modality) to save (default 0)

    output:
        an nii image of the modality selected and an nii image of the label
        a gif image of the modality selected
    '''
    # read the sample once: every dataset[...] access re-runs the transforms
    sample = dataset[image_id]
    img = sample["image"][modality].detach().cpu().numpy()
    lab = sample["label"][0].detach().cpu().numpy()

    os.makedirs(out_location, exist_ok=True)
    img_path = os.path.join(out_location, 'test.nii')
    lab_path = os.path.join(out_location, 'test_lab.nii')
    # identity affine: the scan's real voxel spacing/orientation are not stored
    nib.save(nib.Nifti1Image(img, np.eye(4)), img_path)
    nib.save(nib.Nifti1Image(lab, np.eye(4)), lab_path)
    gif2nif.write_gif_normal(img_path)

In [None]:
save_gif(val_ds, image_id=7, out_location='data/samples/')

![T2 FLAIR slices](data/samples/test.gif "T2 FLAIR slices of validation sample 7")

### Build and train the 3D U-Net model

In [None]:
# Build the multi-modal, multi-output 3D U-Net with MONAI:
# 4 input channels (the MRI modalities) -> 3 output channels (TC, WT, ET masks).
# Fall back to the CPU when no GPU is available instead of failing on "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(
    dimensions=3,
    in_channels=4,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# Labels are already 3 binary channels, so no one-hot conversion; sigmoid
# treats each output channel as an independent mask.
loss_function = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)

optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-4, weight_decay=1e-5, amsgrad=True
)

In [None]:
max_epochs = 200
val_interval = 2  # validate every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []
# Directory the best checkpoint is written to (the current working directory).
root_dir = "."

# (name, channel slice) pairs scored at each validation: the mean over all
# three output channels, then each channel on its own (TC, WT, ET).
val_channels = (
    ("mean", slice(None)),
    ("tc", slice(0, 1)),
    ("wt", slice(1, 2)),
    ("et", slice(2, 3)),
)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # len(train_loader) is the real batch count, including a final partial batch
        print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            dice_metric = DiceMetric(include_background=True, reduction="mean")
            post_trans = Compose(
                [Activations(sigmoid=True), AsDiscrete(threshold_values=True)]
            )
            # running sum of (dice * not_nans) and total not_nans per entry
            metric_sums = {name: 0.0 for name, _ in val_channels}
            metric_counts = {name: 0 for name, _ in val_channels}
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = post_trans(model(val_inputs))
                for name, channels in val_channels:
                    value, not_nans = dice_metric(
                        y_pred=val_outputs[:, channels], y=val_labels[:, channels]
                    )
                    not_nans = not_nans.item()
                    metric_counts[name] += not_nans
                    metric_sums[name] += value.item() * not_nans

            # Weighted mean over batches; NaN (rather than ZeroDivisionError)
            # when no batch produced a valid score for that entry.
            means = {
                name: (metric_sums[name] / metric_counts[name])
                if metric_counts[name]
                else float("nan")
                for name in metric_sums
            }
            metric = means["mean"]
            metric_tc, metric_wt, metric_et = means["tc"], means["wt"], means["et"]
            metric_values.append(metric)
            metric_values_tc.append(metric_tc)
            metric_values_wt.append(metric_wt)
            metric_values_et.append(metric_et)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(
                    model.state_dict(),
                    os.path.join(root_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )

### Visualize the loss and the dice score per epoch

In [None]:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y, color="red")
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y, color="green")
plt.show()

plt.figure("train", (18, 6))
plt.subplot(1, 3, 1)
plt.title("Val Mean Dice TC")
x = [val_interval * (i + 1) for i in range(len(metric_values_tc))]
y = metric_values_tc
plt.xlabel("epoch")
plt.plot(x, y, color="blue")
plt.subplot(1, 3, 2)
plt.title("Val Mean Dice WT")
x = [val_interval * (i + 1) for i in range(len(metric_values_wt))]
y = metric_values_wt
plt.xlabel("epoch")
plt.plot(x, y, color="brown")
plt.subplot(1, 3, 3)
plt.title("Val Mean Dice ET")
x = [val_interval * (i + 1) for i in range(len(metric_values_et))]
y = metric_values_et
plt.xlabel("epoch")
plt.plot(x, y, color="purple")
plt.show()

### Visualize the model results versus the ground truth 

In [None]:
def visualise_results(dataset, model_path, image_id, image_slice):
    '''
    Load a checkpoint into the global `model` and plot its prediction for one sample.

    Shows three rows for slice `image_slice` (index along the last spatial axis)
    of sample `image_id`: the 4 input modalities, the 3 ground-truth masks, and
    the 3 masks predicted by the model (sigmoid followed by thresholding).

    input:
        dataset: dataset whose items are dicts with an "image" tensor (4 modalities)
            and a "label" tensor (3 masks), channel axis first
        model_path: path of the saved state_dict to load
        image_id: index of the sample in the dataset
        image_slice: index of the slice along the last spatial axis
    output:
        None (prints a caption per row and shows three matplotlib figures).
        Uses the module-level `model` and `device`; the model's weights are
        replaced by the checkpoint.
    '''
    modes = ['T2 FLAIR', 'T1', 'T1Gd', 'T2']
    labels = ['Tumor core', 'Whole tumor', 'Enhancing tumor']

    def _show_row(fig_name, figsize, caption, volume, panel_titles):
        # one subplot per channel of `volume`, all at the same slice
        plt.figure(fig_name, figsize)
        print(caption)
        for i, panel_title in enumerate(panel_titles):
            plt.subplot(1, len(panel_titles), i + 1)
            plt.title(f"{panel_title}")
            plt.imshow(volume[i, :, :, image_slice].detach().cpu(), cmap="gray")
        plt.show()

    # map_location lets a checkpoint saved on a GPU be loaded on the current device
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    post_trans = Compose(
        [Activations(sigmoid=True), AsDiscrete(threshold_values=True)]
    )
    # read the sample once: every dataset[...] access re-runs the transforms
    sample = dataset[image_id]
    with torch.no_grad():
        val_output = post_trans(model(sample["image"].unsqueeze(0).to(device)))

    _show_row("image", (24, 6), 'Image modalities', sample["image"], modes)
    _show_row("label", (18, 6), 'Ground truth masks', sample["label"], labels)
    _show_row("output", (18, 6), 'Predicted masks', val_output[0], labels)

In [None]:
# Relative checkpoint path, resolved against the current working directory.
model_path = "best_metric_model.pth"
visualise_results(val_ds, model_path, image_id=3, image_slice=20)