## Brain Tumor MRI - Instance Segmentation - UNET 3D Slices - PyTorch

In [None]:
import os

import albumentations
import glob
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import pathlib
from pathlib import Path
from PIL import Image
import random
from skimage.io import imread
from skimage.transform import resize
import tensorboard
from tqdm.notebook import tqdm

import torch
from torch.utils import data
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as fn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms 

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython.display import display
from IPython.display import clear_output

#from other files
from model import UNet
from trainer import Trainer
from transformations import (
    normalize_01,
    re_normalize, 
    transforms, 
    crop_sample, 
    pad_sample, 
    resize_sample, 
    normalize_volume
)
from utils import (
    predict, 
    preprocess, 
    preprocess_images,
    postprocess, 
    draw_segmentation_map,
    save_gif, 
    segmentation_target, 
    segmentation_pred
)

In [None]:
# Load the TensorBoard notebook extension so the `%tensorboard` line magic used
# further down in this notebook is available
%load_ext tensorboard

## Dataset

In [None]:
class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation.

    The constructor walks ``images_dir`` and treats every directory that
    contains ``.tif`` files as one patient. Files whose name contains
    ``"mask"`` are read as mask slices, all others as image slices. Slices are
    ordered by the integer in the fifth ``_``-separated field of the file
    name, and the first and last slice of every patient are dropped.

    Every patient volume is then cropped, padded, resized and normalized with
    the helpers imported from ``transformations``.

    Args:
        images_dir: Root directory that is walked recursively.
        transform: Optional callable applied to an ``(image, mask)`` tuple and
            returning the transformed ``(image, mask)`` tuple.
        image_size: Target size forwarded to ``resize_sample``.
        subset: One of ``"all"``, ``"train"`` or ``"validation"``.
        random_sampling: When true, ``__getitem__`` ignores the requested
            index and draws a random patient plus a mask-weighted random slice.
        validation_cases: Number of patients held out for validation.
        seed: Seed for the train/validation patient split.

    Attributes set by the constructor: ``patients``, ``volumes``,
    ``slice_weights``, ``patient_slice_index``, ``random_sampling`` and
    ``transform``.
    """

    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        validation_cases=0,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        for dirpath, _dirnames, filenames in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            for filename in sorted(
                filter(lambda f: ".tif" in f, filenames),
                # slice number: fifth "_"-separated field of the name stem
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ):
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename:
                    mask_slices.append(imread(filepath, as_gray=True))
                else:
                    image_slices.append(imread(filepath))
            if len(image_slices) > 0:
                # basename/normpath instead of split("/"): independent of the
                # path separator and of a trailing slash on the directory
                patient_id = os.path.basename(os.path.normpath(dirpath))
                # the first and the last slice of each patient are discarded
                volumes[patient_id] = np.array(image_slices[1:-1])
                masks[patient_id] = np.array(mask_slices[1:-1])

        self.patients = sorted(volumes)
        print("self.patients", self.patients)

        # select cases to subset
        if not subset == "all":
            # seeding makes the split identical for the "train" and the
            # "validation" instance as long as both use the same seed
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=validation_cases)
            print("validation_patients", validation_patients)
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )
            print("self.patients", self.patients)

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        print("cropping {} volumes...".format(subset))
        # crop to smallest enclosing volume
        self.volumes = [crop_sample(v) for v in self.volumes]

        print("padding {} volumes...".format(subset))
        # pad to square
        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]

        # probabilities for sampling slices based on masks
        self.slice_weights = [
            self._slice_sampling_weights(m) for _, m in self.volumes
        ]

        # add channel dimension to masks
        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        self.patient_slice_index = [
            (p_idx, s_idx)
            for p_idx, (v, _) in enumerate(self.volumes)
            for s_idx in range(v.shape[0])
        ]

        self.random_sampling = random_sampling

        self.transform = transform

    @staticmethod
    def _slice_sampling_weights(mask_volume):
        """Return one sampling probability per slice of ``mask_volume``.

        The last two axes are summed, giving one mask total per entry of the
        first axis. Each slice receives a probability proportional to its
        total plus a small uniform share, so the result sums to 1. When the
        volume has no mask pixels at all, the weights are uniform; dividing
        by the zero total would otherwise yield NaN probabilities.
        """
        per_slice = mask_volume.sum(axis=-1).sum(axis=-1).astype(np.float64)
        n_slices = len(per_slice)
        total = per_slice.sum()
        if n_slices == 0 or total == 0:
            return np.full(n_slices, 1.0 / max(n_slices, 1))
        return (per_slice + (total * 0.1 / n_slices)) / (total * 1.1)

    def __len__(self):
        """Return the total number of slices over all selected patients."""
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for one slice.

        With ``random_sampling`` enabled, ``idx`` is ignored and a random
        patient and a mask-weighted random slice are drawn instead.

        Returns:
            A ``float32`` image tensor with the last array axis moved to the
            front and values rescaled to ``[0, 1]``, and an ``int64`` mask
            tensor with the channel axis removed.
        """
        patient, slice_n = self.patient_slice_index[idx]

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            slice_n = np.random.choice(
                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
            )

        v, m = self.volumes[patient]
        image = v[slice_n]
        mask = m[slice_n]

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # fix dimensions (C, H, W)
        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1).squeeze(0)

        # min-max rescale; a constant slice has zero range and is mapped to
        # all zeros instead of being divided by zero (which produced NaNs)
        image_range = np.ptp(image)
        image = image - np.min(image)
        if image_range != 0:
            image = image / image_range

        # a constant mask (e.g. no abnormality in the slice) is left untouched
        mask_range = np.ptp(mask)
        if mask_range != 0:
            mask = (mask - np.min(mask)) / mask_range

        image_tensor = torch.from_numpy(image.astype(np.float32))
        # NOTE(review): the int64 cast truncates, so fractional mask values
        # below 1 become 0 — confirm the transforms keep masks binary
        mask_tensor = torch.from_numpy(mask.astype(np.int64))

        # return tensors
        return image_tensor, mask_tensor

## Data Loading

In [None]:
images_dir = './Data3D/train/'

# Training split: augmented slices, drawn with the dataset's default
# mask-weighted random sampling
dataset_train = BrainSegmentationDataset(
        images_dir=images_dir,
        subset="train",
        image_size=128,
        transform=transforms(scale=0.05, angle=15, flip_prob=0.5),
        validation_cases=30,
    )

# Validation split: same seed and validation_cases as above, so it holds exactly
# the patients excluded from training. No augmentation and no random sampling,
# so every held-out slice is evaluated once per epoch and the validation loss
# (which drives the LR scheduler) is comparable between epochs.
dataset_valid = BrainSegmentationDataset(
        images_dir=images_dir,
        subset="validation",
        image_size=128,
        transform=None,
        random_sampling=False,
        validation_cases=30,
    )

# dataloader training
dataloader_training = DataLoader(dataset=dataset_train, batch_size=4, shuffle=True)

# dataloader validation (fixed order: shuffling adds nothing to an evaluation pass)
dataloader_validation = DataLoader(dataset=dataset_valid, batch_size=4, shuffle=False)

## Training 

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU
# (hard-code "cpu" here to force CPU execution)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TensorBoard event files for this run are written below runs/
writer = SummaryWriter(log_dir='runs/unet3D')

# 2D U-Net: 3 input channels, 2 output classes
model = UNet(
    dim=2,
    in_channels=3,
    out_channels=2,
    n_blocks=4,
    start_filters=32,
    conv_mode="same",
    normalization="batch",
    activation="relu",
)
model = model.to(device)


# Initial learning rate shared by the optimizers defined in the next cell
lr = 1e-3

#### Optimizers

In [None]:
# Active optimizer: Adam over all model parameters, starting at `lr`.
# Drop-in alternatives with the same (params, lr) call signature:
#   torch.optim.SGD, torch.optim.Adamax, torch.optim.RMSprop
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

#### Loss Function

In [None]:
# Import from the public `monai.losses` namespace: `monai.losses.dice` only
# exposes `FocalLoss` through its own internal import, which is not a stable API
from monai.losses import DiceLoss, FocalLoss

# Candidate criteria; the Trainer cell below selects which one is actually used.

# Cross Entropy Loss :
cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Dice Loss :
# NOTE(review): MONAI's `to_onehot_y=True` expects the target to carry an
# explicit channel axis (B, 1, H, W) — confirm the target shape before using
# this or the focal loss as the training criterion
dice_loss = DiceLoss(reduction='mean', to_onehot_y=True, sigmoid=True)

# Focal Loss :
focal_loss = FocalLoss(reduction='mean', to_onehot_y=True)

#### Learning rate scheduler

In [None]:
# Candidate schedulers, all bound to the same optimizer; the Trainer cell
# below receives only one of them.

# ReduceLROnPlateau: lower the LR once the monitored value has stopped
# decreasing for more than `patience` epochs
plateau_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=2
)

# StepLR: multiply the LR by `gamma` every `step_size` epochs
step_lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=5, gamma=0.1
)

# MultiplicativeLR: multiply the current LR by a constant factor each epoch
multiplicative_lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
    optimizer, lr_lambda=lambda epoch: 0.95
)

# LambdaLR: LR = initial LR * 0.95 ** epoch
lambda_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: 0.95 ** epoch
)

# CosineAnnealingLR: cosine decay towards `eta_min` with period parameter `T_max`
cosine_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=10, eta_min=0
)

In [None]:
# Wire model, data, loss, optimizer and scheduler into the project Trainer
trainer = Trainer(
    model=model,
    device=device,
    training_dataloader=dataloader_training,
    validation_dataloader=dataloader_validation,
    criterion=cross_entropy_loss,
    optimizer=optimizer,
    lr_scheduler=plateau_lr_scheduler,
    epoch=0,
    epochs=50,
    notebook=True,
    writer=writer,
)

# Run the training loop and keep what it returns for later inspection
training_losses, validation_losses, lr_rates = trainer.run_trainer()

# Remove the progress output of the training run from the cell
clear_output()

# Learning rate finder

In [None]:
# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# model
model = UNet(
    in_channels=3,
    out_channels=2,
    n_blocks=4,
    start_filters=32,
    activation="relu",
    normalization="batch",
    conv_mode="same",
    dim=2,
).to(device)

# criterion
criterion = torch.nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [None]:
from lr_rate_finder import LearningRateFinder

lrf = LearningRateFinder(model, criterion, optimizer, device)
lrf.fit(dataloader_training, steps=1000)

clear_output()

In [None]:
# Plot what the sweep run by `lrf.fit` recorded
# (presumably loss against learning rate — confirm in lr_rate_finder)
lrf.plot()


# Plot results

In [None]:
# Open TensorBoard inline on the `runs` directory, the parent of the
# 'runs/unet3D' log directory the SummaryWriter was created with
%tensorboard --logdir=runs

## Testing

In [None]:
# Import the function and the class explicitly. `from IPython import display`
# would rebind the `display` function imported at the top of the notebook to
# the module, and the alias keeps PIL's `Image` usable.
from IPython.display import Image as IPythonImage, display

root = './Data3D/test'

# Each sub-directory of `root` is one test case with an `Input` and a `Target` folder
test_directories = sorted(os.listdir(root))

for directory in test_directories:
    directory_path = os.path.join(root, directory)
    if not os.path.isdir(directory_path):
        # skip stray files (e.g. OS metadata) that are not test cases
        continue
    inputs_test = os.path.join(directory_path, 'Input')
    targets_test = os.path.join(directory_path, 'Target')

    images_res, targets_res = preprocess_images(inputs_test, targets_test)
    # predict the segmentation maps
    output = [predict(img, model, preprocess, postprocess, device) for img in images_res]

    # Create a segmentation array for ground truth
    segmentation_target_path = segmentation_target(directory, directory_path, images_res, targets_res)

    # Create a segmentations array for predictions
    segmentation_pred_path = segmentation_pred(directory, directory_path, images_res, output)

    # Display GIF images.
    # An explicit display() call is required: expressions nested inside a
    # for/with body are never auto-rendered, whatever ast_node_interactivity is.
    # NOTE(review): format='png' is kept from the original although the files
    # are described as GIFs — confirm the actual file type.
    for result_path in (segmentation_target_path, segmentation_pred_path):
        with open(result_path, 'rb') as f:
            display(IPythonImage(data=f.read(), format='png'))

## Save the model

In [None]:
# Persist only the learned parameters (state_dict) into the current working directory
model_name = "brain_mri_unet3D.pt"
torch.save(obj=model.state_dict(), f=Path.cwd() / model_name)
