In [24]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import gzip
import os
import cv2
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
import torch

import os
# NOTE(review): presumably the Intel OpenMP workaround that lets the process
# continue when two OpenMP runtimes get loaded (e.g. one via torch, one via
# numpy/MKL). It silences the duplicate-runtime check rather than removing the
# cause. It is set after `import torch` above — confirm it still takes effect there.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import sys

# Put the parent of the current working directory on the import path so the
# project packages (`model`, `transformations`) can be imported from this notebook.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)
    
from model.model  import BraTS2021BaseUnetModel, BraTS2021AttentionUnetModel_V4


Select which BraTS2021 case to visualize

In [25]:
# zero-padded BraTS2021 case identifier; used to build the file paths and figure labels
sample_id = "00166"

Load the three pre-trained models from their saved checkpoints

In [26]:
# Checkpoint locations of the three trained models being compared.
unet_base_model_path        = '../saved/models/base_unet_dice_no_augm/model_best.pth'
unet_base_aug_model_path    = '../saved/models/base_unet_dice/model_best.pth'
unet_attention_model_path   = '../saved/models/attention_unet_dice_v4_all_aug_30_124/model_best.pth'


def _load_checkpoint_weights(model, checkpoint_path):
    """Load the ``'state_dict'`` entry of a saved checkpoint into ``model``.

    The checkpoint tensors are mapped onto the CPU, so weights saved from a GPU
    run can also be restored on a CPU-only machine. The models are moved to the
    inference device later. Returns the result of ``model.load_state_dict``.
    """
    # NOTE(review): torch>=2.6 defaults to torch.load(weights_only=True); if these
    # checkpoints also pickle non-tensor objects (e.g. a config), loading may need
    # weights_only=False — only do that for trusted files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    return model.load_state_dict(checkpoint['state_dict'])


base_model      = BraTS2021BaseUnetModel()          # base U-Net ('no_augm' checkpoint)
base_aug_model  = BraTS2021BaseUnetModel()          # base U-Net ('base_unet_dice' checkpoint)
att_model       = BraTS2021AttentionUnetModel_V4()  # attention U-Net (v4)

_load_checkpoint_weights(base_model, unet_base_model_path)
_load_checkpoint_weights(base_aug_model, unet_base_aug_model_path)
_load_checkpoint_weights(att_model, unet_attention_model_path)


<All keys matched successfully>

Load all four MRI modalities and the ground-truth segmentation mask for one case

In [27]:
# define path
base_path = '../data/BRaTS2021/BRaTS2021_raw/' 

Flair       = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_flair.nii.gz').get_fdata()
seg_target  = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_seg.nii.gz').get_fdata()
T1          = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_t1.nii.gz').get_fdata()
T1ce        = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_t1ce.nii.gz').get_fdata()
T2          = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_t2.nii.gz').get_fdata()

# convert from 0, 1, 2, 4 --> 0, 1, 2, 3
seg_target[seg_target == 4] = 3  

# print number of slices
imgshape = Flair.shape
print(f"Image resolution: {imgshape[0]}x{imgshape[1]}")
print(f"Number of slices: {imgshape[2]}")

Image resolution: 240x240
Number of slices: 155


Create the validation transform that is applied to every slice

In [28]:
from transformations.transformations import brats_validation_transform

# The four MRI modalities are the image keys; 'seg' appears only in all_keys,
# presumably so the mask is processed alongside the images without being
# treated as one — confirm against brats_validation_transform.
brats_transform = brats_validation_transform(
    image_keys=['t1', 't1ce', 't2', 'flair'],
    all_keys=['t1', 't1ce', 't2', 'flair', 'seg'],
)

Split the case into per-slice dictionaries and apply the transform to each slice

In [29]:
# MRI modalities
modalities = ['t1', 't1ce', 't2', 'flair'] 

# placeholders 
slice_dict = {}

# loop over all 155 slices
for slice_id in range(155):
    # create dictionary for each slice
    slice_dict[slice_id] = {}

    # get all 4 modalities
    t1 = T1[:, :, slice_id]
    t1ce = T1ce[:, :, slice_id]
    t2 = T2[:, :, slice_id]
    flair = Flair[:, :, slice_id]
    seg_tar = seg_target[:, :, slice_id]

    # create dictionary for each slice 
    slice_dict[slice_id]['t1'] = t1
    slice_dict[slice_id]['t1ce'] = t1ce
    slice_dict[slice_id]['t2'] = t2
    slice_dict[slice_id]['flair'] = flair
    slice_dict[slice_id]['seg'] = seg_tar

    # apply transformations
    slice_dict[slice_id] = brats_transform(slice_dict[slice_id])

Run all three models on every slice of the case to get their predictions

In [30]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Per-slice outputs of each model, collected on the CPU.
seg_preds_base = []
seg_preds_base_aug = []
seg_preds_att = []

# Each model paired with the list that receives its outputs; this order is the
# order in which the models are run on every slice.
models_and_outputs = (
    (base_model, seg_preds_base),
    (base_aug_model, seg_preds_base_aug),
    (att_model, seg_preds_att),
)

# move every model to the inference device and switch it to eval mode
for net, _ in models_and_outputs:
    net.to(device)
    net.eval()

with torch.no_grad():
    # use the number of prepared slices rather than a hard-coded 155
    for slice_idx in range(len(slice_dict)):
        modality_dict = slice_dict[slice_idx]

        # Concatenate the modalities along dim 0 in the order flair, t1, t1ce, t2
        # and add a batch dimension of size 1.
        # NOTE(review): this channel order must match the one used in training — confirm.
        x = torch.cat((modality_dict['flair'],
                       modality_dict['t1'],
                       modality_dict['t1ce'],
                       modality_dict['t2']), dim=0).unsqueeze(0).to(device)

        for net, outputs in models_and_outputs:
            # store outputs on the CPU so they do not accumulate in GPU memory
            outputs.append(net(x).cpu())

print("shape of seg_preds_base: ", len(seg_preds_base))

shape of seg_preds_base:  155


Stack the per-slice outputs into a single tensor of shape (155, 4, 240, 240)

In [31]:
# vstack all slices
seg_preds_base_list = torch.vstack(seg_preds_base)
seg_preds_base_aug_list = torch.vstack(seg_preds_base_aug)
seg_preds_att_list = torch.vstack(seg_preds_att)

print(len(seg_preds_base))

155


Convert the model outputs to label maps (argmax over the class axis) and move the slice axis last

In [32]:
# Collapse the class axis (dim 1) with argmax, then move the slice axis from
# first to last with transpose(0, 2) so the label volumes are indexed like the
# loaded NIfTI arrays, i.e. [:, :, slice].
# NOTE(review): transpose(0, 2) maps (slice, d1, d2) -> (d2, d1, slice), i.e. it
# also swaps the two in-plane axes. That is only correct if brats_validation_transform
# itself swaps them; otherwise permute(1, 2, 0) is needed to line the predictions
# up with seg_target. Confirm visually against the target panel.
# (.cpu() and .detach() are effectively no-ops: the outputs were already moved to
# the CPU and computed under torch.no_grad().)
seg_preds_base, seg_preds_base_aug, seg_preds_att = (
    torch.argmax(stacked, dim=1).transpose(0, 2).cpu().detach().numpy()
    for stacked in (seg_preds_base_list, seg_preds_base_aug_list, seg_preds_att_list)
)

print(seg_preds_att.shape)

(240, 240, 155)


Visualize segmentation predictions vs. ground truth (overlaid on the T2 image)

In [33]:
# label value -> legend text describing its overlay colour
classes_dict = dict(enumerate([
    'B/W = healthy',      # label 0
    'Red = necrotic',     # label 1
    'Green = edematous',  # label 2
    'Blue = enhancing',   # label 3 (label 4 in the raw files, remapped on load)
]))
print(np.shape(seg_target))

# RGB colour for each label. The raw label 4 was already remapped to 3 when the
# mask was loaded. Label 0 (healthy tissue) keeps the zero-initialised black.
label_colors = {
    1: [255, 0, 0],   # Red   - necrotic tumour core
    2: [0, 255, 0],   # Green - peritumoral edematous/invaded tissue
    3: [0, 0, 255],   # Blue  - enhancing tumour
}


def _colorize_labels(label_volume):
    """Map a label volume of shape S to a uint8 RGB array of shape S + (3,) using ``label_colors``."""
    colored = np.zeros(np.shape(label_volume) + (3,), dtype=np.uint8)
    for label, rgb in label_colors.items():
        colored[label_volume == label] = rgb
    return colored


# output shapes follow the data instead of a hard-coded (240, 240, 155, 3)
color_segmentation               = _colorize_labels(seg_target)
color_segmentation_pred_base     = _colorize_labels(seg_preds_base)
color_segmentation_pred_base_aug = _colorize_labels(seg_preds_base_aug)
color_segmentation_pred_att      = _colorize_labels(seg_preds_att)

# a: opacity of the colour masks drawn over the grayscale background;
# x_slice: x pixel position of the slice-number label drawn on each panel
a, x_slice = 0.3, 40

def create_seg_figure(background, color_seg_tar, color_segmentation_pred_base_slice, color_segmentation_pred_base_aug_slice, color_segmentation_pred_att_slice, slice_idx):
    """Draw one 1x4 row of panels: the ground truth followed by three model predictions.

    Each panel shows ``background`` in grayscale with the matching RGB mask blended
    on top (opacity ``a``). It also shows the case id near the bottom and the
    zero-padded ``slice_idx`` near the top.

    Args:
        background: 2-D image shown under every mask.
        color_seg_tar: RGB mask of the ground truth ("target" panel).
        color_segmentation_pred_base_slice: RGB mask for the "U-net" panel.
        color_segmentation_pred_base_aug_slice: RGB mask for the "U-net+augmentations" panel.
        color_segmentation_pred_att_slice: RGB mask for the "Attention U-net" panel.
        slice_idx: slice index, printed zero-padded to three digits.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can call ``.show()`` on it.
    """
    panels = (
        ("target", color_seg_tar),
        ("U-net", color_segmentation_pred_base_slice),
        ("U-net+augmentations", color_segmentation_pred_base_aug_slice),
        ("Attention U-net", color_segmentation_pred_att_slice),
    )

    subtext = f"BRaTS2021_{sample_id}"
    slice_txt = f"{slice_idx:03d}"

    plt.figure(figsize=(20, 20))
    for position, (title, color_mask) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(background, cmap='gray')
        # no cmap here: matplotlib ignores it for RGB input, the colours come from the mask
        plt.imshow(color_mask, alpha=a)
        plt.title(title, fontsize=20)
        plt.axis('off')
        plt.text(45, 230, subtext, fontsize=20, color='white')        # case id
        plt.text(x_slice, 30, slice_txt, fontsize=30, color='white')  # slice index

    # compute the layout once, after every panel exists
    plt.tight_layout()

    return plt


def visualize_3d_labels(layer):
    """Show slice ``layer`` of the ground truth and of the three model predictions.

    Prints the names of the classes present in the ground-truth mask for that
    slice, draws the comparison figure over the T2 slice, and returns ``layer``.
    """
    color_volumes = (
        color_segmentation,
        color_segmentation_pred_base,
        color_segmentation_pred_base_aug,
        color_segmentation_pred_att,
    )
    target_slice, base_slice, base_aug_slice, att_slice = (
        volume[:, :, layer, :] for volume in color_volumes
    )

    # class names present in the ground-truth mask of this slice
    print([classes_dict[int(label)] for label in np.unique(seg_target[:, :, layer])])

    # background image: the T2 volume at this slice
    background = T2[:, :, layer]
    figure = create_seg_figure(background, target_slice, base_slice, base_aug_slice, att_slice, layer)
    figure.show()

    return layer


# interact() returns the function it wraps; the trailing ';' keeps that
# function's repr out of the cell output.
interact(visualize_3d_labels, layer=(0, Flair.shape[2] - 1));

(240, 240, 155)


interactive(children=(IntSlider(value=77, description='layer', max=154), Output()), _dom_classes=('widget-inte…

<function __main__.visualize_3d_labels(layer)>