In [1]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import gzip
import os
import cv2
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
import torch
import torch.nn.functional as F

import os
import sys

# for relative imports to work in notebooks
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from model.model  import BraTS2021BaseUnetModel

Import pre-trained model

In [2]:
from utils import load_model
# best model is located at: python test.py -r saved/models/BraTS2021_Base_Unet/0523_152701/model_best.pth
# model = load_model(path = '../saved/models/BraTS2021_Base_Unet/0523_152701/model_best.pth')

# print(model)

# # create network object from saved\models\BraTS2021_Base_Unet\0523_152701
# config = ConfigParser.from_args(args)
# model = config.init_obj('arch', module_arch)

# # unet_model = BraTS2021BaseUnetModel()

# # execute from pre-trained model
# # unet_model.load_state_dict(torch.load('../model/unet_1epoch.pth'))

# # checkpoint = torch.load('../model/unet_1epoch.pth')
# # state_dict = checkpoint['state_dict']
# # unet_model = torch.nn.DataParallel(unet_model)
# # unet_model.load_state_dict(state_dict)

# print(unet_model)

Define dice coefficient for class total

In [3]:


def dice_coeff_enhancing_tumor(output, target):
    """ Dice coeffecient for enhancing tumor(class 3)
        See: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#Dice-Loss
    Args:
        output (torch.Tensor): model output probalities between [0-1], shape (N, C, H, W)
        target (torch.Tensor): target one-hot encoded, shape (N, C, H, W)
        eps (int): prevent division by 0
    Returns:
        torch.Tensor: dice coeff. Shape: (1,)
    """
    with torch.no_grad():
        # take tumor class 3
        target = target[:, [3]]
        output = output[:, [3]]

        dice_coeff_en = dice_coeff(output, target, include_background=False)
    return dice_coeff_en


def dice_coeff_tumor_core(output, target):
    """ Dice coeffecient for tumor core (union of classes 1+3)
        See: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#Dice-Loss
    Args:
        output (torch.Tensor): model output probalities between [0-1], shape (N, C, H, W)
        target (torch.Tensor): target one-hot encoded, shape (N, C, H, W)
        eps (int): prevent division by 0
    Returns:
        torch.Tensor: dice coeff. Shape: (1,)
    """
    with torch.no_grad():
        # take tumor classes 1 and 3
        target = target[:, [1, 3]]
        output = output[:, [1, 3]]

        dice_coeff_tc = dice_coeff(output, target, include_background=False)
    return dice_coeff_tc



def dice_coeff_healthy(output, target):
    """ Dice coeffecient for background/healthy class (0)
        See: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#Dice-Loss
    Args:
        output (torch.Tensor): model output probalities between [0-1], shape (N, C, H, W)
        target (torch.Tensor): target one-hot encoded, shape (N, C, H, W)
        eps (int): prevent division by 0
    Returns:
        torch.Tensor: dice coeff. Shape: (1,)
    """
    with torch.no_grad():
        # take only the background/healthy class (0)
        target = target[:, :1]
        output = output[:, :1]

        dice_coef_healthy = dice_coeff(output, target)
    return dice_coef_healthy


def dice_coeff(output, target, reduce_class=True, reduce_batch=True, 
               eps=1e-7, smooth_nr=0, smooth_dr=0, include_background=True, backprop=False):
    """ Dice coeffecient
        See: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#Dice-Loss  

    Args:
        output (torch.Tensor): model output probalities between [0-1], shape (N, C, H, W)
        target (torch.Tensor): target one-hot encoded, shape (N, C, H, W)
        reduce_class (bool): if True, sum over class dimension
        reduce_batch (bool): if True, sum over batch dimension
        eps (int): prevent division by 0
        smooth_nr (float): smooth factor, numerator
        smooth_dr (float): smooth factor, denominator
        include_background (bool): if True, include background class

    Returns:
        torch.Tensor: dice coeff. Shape: (1,) if reduce_class&reduce_batch; 
                                         (C,) if reduce_batch; 
                                         (N,) if reduce_class; 
                                         (N, C) otherwise (default) .requires_grad_(False)
    """
    dim = [2, 3] 

    if reduce_class:
        dim = [1] + dim

    if reduce_batch:
        dim = [0] + dim

    if not include_background:
        target = target[:, 1:]
        output = output[:, 1:]

    # contiguous() ensures that the memory is not reallocated during the operation
    output = output.contiguous()
    target = target.contiguous()   

    if not backprop:
        with torch.no_grad():
            intersection = torch.sum(torch.mul(output, target), dim=dim)
            abs_area = torch.sum(output, dim=dim) + torch.sum(target, dim=dim) + eps # abs_area = abs_output + abs_target

            # dice = 2 * (A U B) / (|A| + |B|)
            dice_coeff = (2. * intersection + smooth_nr) / (abs_area + smooth_dr)
    
    else:
        intersection = torch.sum(torch.mul(output, target), dim=dim)
        abs_area = torch.sum(output, dim=dim) + torch.sum(target, dim=dim) + eps # abs_area = abs_output + abs_target

        # dice = 2 * (A U B) / (|A| + |B|)
        dice_coeff = (2. * intersection + smooth_nr) / (abs_area + smooth_dr)        

    return dice_coeff  

Generate dummy data

In [5]:
# random data with shape (N, C, H, W) =  torch.Size([N, 4, 240, 240])
C = 4
N = 8
H = 240
W = 240

output = torch.randint(0, 3, (N, C, H, W), dtype=torch.float32)
output = F.softmax(output, dim=1)

# target
target = torch.randint(0, C, (N, H, W), dtype=torch.int64)
target = F.one_hot(target, num_classes=C)
target = torch.transpose(target, 1, 3)

# test output
test_output = torch.randint(0, C, (N, H, W), dtype=torch.int64)
test_output = F.one_hot(test_output, num_classes=C)
test_output = torch.transpose(test_output, 1, 3)

print("target.shape: ", target.shape)
print("output.shape: ", output.shape)

# Semantics:
#     Whole tumor: 		WT: 	Union of all tumor labels (1) + (2) + (4)
#     Tumor core: 		TC: 	Gross tumor core outline --> Union of labels ET(4) + necrosis(1)
#     Enhancing tumor: 	ET: 	Only ET(4)      

def dice_coeff_whole_tumor(output, target):
    """ Dice coeffecient for whole tumor (union of classes 1-3)
        See: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#Dice-Loss
    Args:
        output (torch.Tensor): model output probalities between [0-1], shape (N, C, H, W)
        target (torch.Tensor): target one-hot encoded, shape (N, C, H, W)
        eps (int): prevent division by 0
    Returns:
        torch.Tensor: dice coeff. Shape: (1,)
    """
    with torch.no_grad():
        # take all the tumor classes
        target = target[:, [1, 3]]
        output = output[:, [1, 3]]
        print("Target shape: ", target.shape)

        target_union = torch.sum(target, dim=1).unsqueeze(1).clip(0, 1)
        output_union = torch.argmax(output, dim=1).unsqueeze(1).clip(0, 1)



        # dice_coeff_wt = dice_coeff(target_union, target_union, include_background=False, reduce_class=True)
        dice_coeff_wt = dice_coeff(target_union, target_union, include_background=True, reduce_class=True)
    return dice_coeff_wt

dice_coef_wt = dice_coeff_whole_tumor(output, output)
print("dice_coef_wt: ", dice_coef_wt)
print("output mean: ", torch.mean(output))


target.shape:  torch.Size([8, 4, 240, 240])
output.shape:  torch.Size([8, 4, 240, 240])
Target shape:  torch.Size([8, 2, 240, 240])
dice_coef_wt:  tensor(0.5691)
