In [1]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import gzip
import os
import cv2
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
import torch

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import sys

# for relative imports to work in notebooks
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from model.model  import BraTS2021BaseUnetModel, BraTS2021AttentionUnetModel_V4

from monai.metrics import compute_generalized_dice, compute_meandice


Select which BRaTS2021 case we want

In [2]:
sample_id = '00349'

Import pre-trained model

In [3]:
# model is located at: 
unet_base_model_path = '../saved/models/base_unet_dice/model_best.pth'
unet_base_aug_model_path = '../saved/models/BraTS2021_Base_Unet/0616_220253/model_best.pth'
unet_attention_model_path = '../saved/models/attention_unet_dice_v4/model_best.pth'


model_path = unet_attention_model_path

# load base unet model
# model = BraTS2021BaseUnetModel()
model = BraTS2021AttentionUnetModel_V4()

checkpoint = torch.load(model_path)
state_dict = checkpoint['state_dict']
# base_unet_model.load_state_dict(state_dict)
model.load_state_dict(state_dict)
print("type of model: ", type(model))

# print model summary if you want to
print(model)



type of model:  <class 'model.model.BraTS2021AttentionUnetModel_V4'>
BraTS2021AttentionUnetModel_V4(
  (Maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv1): conv_block(
    (conv): Sequential(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (Conv2): conv_block(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1

Load all modalities and segmentation groundtruth masks for one case

In [4]:
# define path
base_path = '../data/BRaTS2021/BRaTS2021_raw/' 

Flair       = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_flair.nii.gz').get_fdata()
seg_target  = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_seg.nii.gz').get_fdata()
T1          = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_t1.nii.gz').get_fdata()
T1ce        = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_t1ce.nii.gz').get_fdata()
T2          = nib.load(base_path  + 'BraTS2021_' + sample_id  + '/BraTS2021_' + sample_id + '_t2.nii.gz').get_fdata()

# convert from 0, 1, 2, 4 --> 0, 1, 2, 3
seg_target[seg_target == 4] = 3  

# print number of slices
imgshape = Flair.shape
print(f"Image resolution: {imgshape[0]}x{imgshape[1]}")
print(f"Number of slices: {imgshape[2]}")

Image resolution: 240x240
Number of slices: 155


Create transformations for data augmentation

In [5]:
from transformations.transformations import brats_validation_transform

brats_transform = brats_validation_transform(image_keys=['t1', 't1ce', 't2', 'flair'], 
                                             all_keys=['t1', 't1ce', 't2', 'flair', 'seg'],)

Load the specific case we want

In [6]:
# MRI modalities
modalities = ['t1', 't1ce', 't2', 'flair'] 

# placeholders 
slice_dict = {}

# loop over all 155 slices
for slice_id in range(155):
    # create dictionary for each slice
    slice_dict[slice_id] = {}

    # get all 4 modalities
    t1 = T1[:, :, slice_id]
    t1ce = T1ce[:, :, slice_id]
    t2 = T2[:, :, slice_id]
    flair = Flair[:, :, slice_id]
    seg_tar = seg_target[:, :, slice_id]

    # create dictionary for each slice 
    slice_dict[slice_id]['t1'] = t1
    slice_dict[slice_id]['t1ce'] = t1ce
    slice_dict[slice_id]['t2'] = t2
    slice_dict[slice_id]['flair'] = flair
    slice_dict[slice_id]['seg'] = seg_tar

    # apply transformations
    slice_dict[slice_id] = brats_transform(slice_dict[slice_id])

Test the loaded case to get the prediction

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
model.eval()

seg_preds = []

n_cases = 2

with torch.no_grad():
    for slice_idx in range(155):
        modality_dict = slice_dict[slice_idx]

        x = torch.cat((modality_dict['flair'], 
                       modality_dict['t1'], 
                       modality_dict['t1ce'], 
                       modality_dict['t2']), dim=0)

        # add batch dimension 1
        x = x.unsqueeze(0)

        # inference
        x = x.to(device)
        out = model(x)
        seg_preds.append(out)

Stack them into shape (155, 4, 240, 240)

In [8]:
# select case
start_slice = 0
end_slice = 155

pred = torch.vstack(seg_preds)
pred = pred[start_slice:end_slice].cpu()
print(np.shape(pred))

torch.Size([155, 4, 240, 240])


Detach from GPU

In [9]:
#seg_output = pred.cpu().detach().numpy()
seg_output = pred
seg_output = torch.argmax(seg_output, dim=1)
seg_output = torch.transpose(seg_output, 0, 2)
seg_output = seg_output.cpu().detach().numpy()
print (seg_output.shape)

(240, 240, 155)


Visualize segmentation prediction vs groundtruth (overlay on top of Flair)

In [10]:
classes_dict = {
    0 : 'B/W = healthy',
    1 : 'Red = necrotic',
    2 : 'Green = edematous',
    3 : 'Blue = enhancing'
}
print(np.shape(seg_target))

# change colours of segmentation result  
color_segmentation = np.zeros((240, 240, 155, 3), dtype=np.uint8)
color_segmentation_pred = np.zeros((240, 240, 155, 3), dtype=np.uint8)


# replace 4 with 3
seg_target[seg_target == 4] = 3
                                                            # Black (healthy tissue) = 0
color_segmentation[seg_target == 1] = [255, 0, 0]    # Red (necrotic tumor core) = 1
color_segmentation[seg_target == 2] = [0, 255, 0]    # Green (peritumoral edematous/invaded tissue) = 2
color_segmentation[seg_target == 3] = [0, 0, 255]    # Blue (enhancing tumor) = 4

                                                                    # Black (healthy tissue) = 0
color_segmentation_pred[seg_output == 1] = [255, 0, 0]  # Red (necrotic tumor core) = 1
color_segmentation_pred[seg_output == 2] = [0, 255, 0]  # Green (peritumoral edematous/invaded tissue) = 2
color_segmentation_pred[seg_output == 3] = [0, 0, 255]  # Blue (enhancing tumor) = 4


def create_seg_figure(background, color_seg, color_seg_pred, slice_idx):

    plt.figure(figsize=(10, 5))

    # prediction
    plt.subplot(1, 2, 1)
    plt.imshow(background, cmap='gray')
    plt.imshow(color_seg_pred, cmap='bone', alpha=0.6)
    plt.title("prediction", fontsize=20)
    plt.axis('off')
    plt.tight_layout()

    # case and slice id
    subtext = f"BRaTS2021_{sample_id}"
    slice_txt = f"{slice_idx:03d}"
    plt.text(45, 230, subtext, fontsize=20, color='white')
    plt.text(10, 30, slice_txt, fontsize=30, color='white')
    
    # target
    plt.subplot(1, 2, 2)
    plt.imshow(background, cmap='gray')
    plt.imshow(color_seg, cmap='bone', alpha=0.6)
    plt.title("target", fontsize=20)
    plt.axis('off')
    plt.tight_layout()
    plt.text(45, 230, subtext, fontsize=20, color='white')
    plt.text(10, 30, slice_txt, fontsize=30, color='white')

    return plt


def visualize_3d_labels(layer):
    color_seg = color_segmentation[:, :, layer, :]
    color_seg_pred = color_segmentation_pred[:, :, layer, :]

    # print segmentation result
    print([classes_dict[int(result)] for result in np.unique(seg_target[:, :, layer])])
 
    background = Flair[:, :, layer]
    plot = create_seg_figure(background, color_seg, color_seg_pred, layer)
    plot.show()

    return layer

interact(visualize_3d_labels, layer=(0, Flair.shape[2] - 1))

(240, 240, 155)


interactive(children=(IntSlider(value=77, description='layer', max=154), Output()), _dom_classes=('widget-inte…

<function __main__.visualize_3d_labels(layer)>

Create WT, TC, ET masks

In [11]:
def whole_tumor_mask(output, target):
    """ Dice coeffecient for whole tumor (union of classes 1-3)
        See: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#Dice-Loss
    Args:
        output (torch.Tensor): model output probalities between [0-1], shape (N, C, H, W)
        target (torch.Tensor): target one-hot encoded, shape (N, C, H, W)
        eps (int): prevent division by 0
    Returns:
        torch.Tensor: dice coeff. Shape: (1,)
    """
    with torch.no_grad():
        print("output shape: ", output.shape)
        print("target shape: ", target.shape)

        # for the target we want to take the union of all the tumor classes
        # target_union = torch.sum(target[:, 1:], dim=1).clip(0, 1).unsqueeze(1).cpu()
        # print("target_union", target_union)
        target_union = target.clip(0, 1).cpu()

        # for the output we want to take the argmax of all the classes
        # output_union = torch.argmax(output, dim=1).clip(0, 1).unsqueeze(1).cpu()
        # print("output_union ", output_union)
        output_union = output.clip(0, 1).cpu()

    return output_union, target_union

Visualize masks

In [12]:
# print(seg_target.shape)

wt_out, wt_target = whole_tumor_mask(torch.from_numpy(seg_output), torch.from_numpy(seg_target))

print("wt_out: ", wt_out.shape)
print("wt_target: ", wt_target.shape)

def visualize_masks(layer):
    color_seg_wt = wt_target[:, :, layer]
    color_seg_pred_wt = wt_out[:, :, layer]
    print(color_seg_wt.shape)
    print(color_seg_pred_wt.shape)

    # print segmentation result
    print([classes_dict[int(result)] for result in np.unique(seg_target[:, :, layer])])
 
    background = Flair[:, :, layer]
    plot = create_seg_figure(background, color_seg_wt, color_seg_pred_wt, layer)
    plot.show()

    return layer

interact(visualize_masks, layer=(0, Flair.shape[2] - 1))

output shape:  torch.Size([240, 240, 155])
target shape:  torch.Size([240, 240, 155])
wt_out:  torch.Size([240, 240, 155])
wt_target:  torch.Size([240, 240, 155])


interactive(children=(IntSlider(value=77, description='layer', max=154), Output()), _dom_classes=('widget-inte…

<function __main__.visualize_masks(layer)>

Compute dice coefficients

In [13]:
from monai.metrics import compute_generalized_dice, compute_meandice

wt_out = torch.transpose(wt_out, 0, 2)
print("wt_out: ", wt_out.shape)
wt_target = torch.transpose(wt_target, 0, 2)
print("wt_target: ", wt_target.shape)


wt_out:  torch.Size([155, 240, 240])
wt_target:  torch.Size([155, 240, 240])


In [50]:
wt_out_orig = wt_out.clone()
wt_target_orig = wt_target.clone()

# remove all slices with no tumor
print("wt_out: ", wt_out.shape)
print("wt_target: ", wt_target.shape)

# take only slices that have tumor in prediction or target
wt_out_indices = torch.where(wt_out.sum(dim=(1,2)) > 0)
wt_target_indices = torch.where(wt_target.sum(dim=(1,2)) > 0)

min_slice = min(wt_out_indices[0].min(), wt_target_indices[0].min())
max_slice = max(wt_out_indices[0].max(), wt_target_indices[0].max())

print("min_slice:max_slice = {}:{}".format(min_slice, max_slice))

# remove all slices with no tumor
wt_out_sliced = wt_out[min_slice:max_slice + 1]
wt_target_sliced = wt_target[min_slice:max_slice + 1]

print("wt_out_sliced: ", wt_out_sliced.shape)
print("wt_target_sliced: ", wt_target_sliced.shape)

# unsqueeze two times, one for batchsize=1 and one for channel=1
wt_out_sliced = wt_out_sliced.unsqueeze(0).unsqueeze(0)
wt_target_sliced = wt_target_sliced.unsqueeze(0).unsqueeze(0)

print("wt_out_sliced: ", wt_out_sliced.shape)
print("wt_target_sliced: ", wt_target_sliced.shape)

dice_score_wt = torch.as_tensor(compute_meandice(wt_out_sliced, wt_target_sliced))

# print("dice_score_wt shape: ", dice_score_wt.shape)
print("dice_score_wt: ", torch.mean(dice_score_wt))
# print("dice_score_wt: ", dice_score_wt)

wt_out:  torch.Size([155, 240, 240])
wt_target:  torch.Size([155, 240, 240])
min_slice:max_slice = 32:90
wt_out_sliced:  torch.Size([59, 240, 240])
wt_target_sliced:  torch.Size([59, 240, 240])
wt_out_sliced:  torch.Size([1, 1, 59, 240, 240])
wt_target_sliced:  torch.Size([1, 1, 59, 240, 240])
dice_score_wt:  tensor(0.6678)
