In [None]:
import os
import sys
import numpy as np
import scipy
import torch
import monai

from matplotlib import pyplot as plt

sys.path.append("/mnt/data/mranzini/Desktop/GIFT-Surg/FBS_Monai/basic_unet_monai/src")
from custom_losses import DiceLossExtended, DiceLoss_noSmooth

In [None]:
def compute_soft_dice(seg1, seg2, label=1.0):
    """Soft Dice coefficient between two (possibly non-binary) segmentations.

    Computes ``2 * sum(seg1 * seg2) / (sum(seg1) + sum(seg2) + 1e-10)`` over
    all elements. Values are used as-is (no thresholding), so probabilistic
    predictions give a "soft" Dice. Inputs of any shape are flattened first,
    so passing a whole batch pools every element into a single score.

    Args:
        seg1: Reference segmentation (array-like of numbers).
        seg2: Segmentation compared against ``seg1``; must contain the same
            number of elements.
        label: Unused. Kept only so existing calls that pass it keep working;
            the inputs are NOT compared against this value.

    Returns:
        The Dice score as a numpy scalar. Two all-zero inputs give 0.0, because
        the denominator only contains the tiny 1e-10 epsilon.

    Raises:
        ValueError: If the two inputs have a different number of elements.
    """
    yt = np.asarray(seg1).flatten()
    yp = np.asarray(seg2).flatten()

    if yt.size != yp.size:
        raise ValueError('The two segmentations have different dimensions - not comparable!')

    intersection = np.sum(yt * yp)
    # epsilon only guards against division by zero; it does not smooth the score
    return 2 * intersection / (np.sum(yt) + np.sum(yp) + 1e-10)



In [None]:
def generate_gt_pred_pairs(test_num, plot=True):
    """Build a synthetic 2D ground-truth mask and a matching fake network output.

    The ground truth is a 96x96 float array of zeros containing two filled
    circles set to 1.0: radius 32 centred at (row 74, col 62), and radius 8
    centred at (row 12, col 12). The fake output depends on ``test_num``:

    * 0 -- "Empty prediction": all zeros.
    * 1 -- "Perfect match": a copy of the ground truth.
    * 2 -- "Smooth big circle": only the big circle, blurred with a Gaussian
      (sigma=3), so values are soft in [0, 1].
    * 3 -- "Shifted small circle": only the small circle, shifted by
      +1 row and -1 column.

    Args:
        test_num: Scenario to generate (one of the keys above).
        plot: If True (default), display ground truth and output side by side.

    Returns:
        ``(ground_truth, net_out, tests_dict)``, where ``tests_dict`` maps each
        test number to its human-readable name.

    Raises:
        ValueError: If ``test_num`` is not one of the scenarios above.
    """
    # Import the submodule explicitly: on older SciPy releases a bare
    # `import scipy` does not expose `scipy.ndimage` as an attribute.
    from scipy import ndimage

    tests_dict = {0: "Empty prediction", 1: "Perfect match", 2: "Smooth big circle", 3: "Shifted small circle"}

    # image size: size_x rows by size_y columns
    size_x, size_y = 96, 96
    big_center, big_radius = (74, 62), 32
    small_center, small_radius = (12, 12), 8

    def disk(center_row, center_col, radius):
        # Boolean mask of a filled circle; ogrid yields (rows, 1) and (1, cols)
        # offset grids that broadcast to the full image.
        rows, cols = np.ogrid[-center_row:size_x - center_row, -center_col:size_y - center_col]
        return rows * rows + cols * cols <= radius * radius

    mask_big = disk(*big_center, big_radius)
    mask_small = disk(*small_center, small_radius)

    ground_truth = np.zeros([size_x, size_y])
    ground_truth[mask_big] = 1.0
    ground_truth[mask_small] = 1.0

    net_out = np.zeros([size_x, size_y])
    if test_num == 0:
        pass  # empty prediction
    elif test_num == 1:
        # copy so the two returned arrays do not alias each other
        net_out = ground_truth.copy()
    elif test_num == 2:
        net_out[mask_big] = 1.0
        net_out = ndimage.gaussian_filter(net_out, sigma=3)
    elif test_num == 3:
        shifted = disk(small_center[0] + 1, small_center[1] - 1, small_radius)
        net_out[shifted] = 1.0
    else:
        raise ValueError(f"Unidentified test number: {test_num!r}")

    if plot:
        fig, ax = plt.subplots(nrows=1, ncols=2)
        ax[0].imshow(ground_truth, vmin=0.0, vmax=1.0, interpolation='nearest')
        ax[0].set_title("Ground truth")
        ax[1].imshow(net_out, vmin=0.0, vmax=1.0, interpolation='nearest')
        ax[1].set_title(f"Prediction: {tests_dict[test_num]}")

    return ground_truth, net_out, tests_dict

## Validate DiceLossExtended vs monai.losses.DiceLoss

In [None]:
# Pick one synthetic scenario (keys of tests_dict) and build its ground truth / prediction.
test_n = 1
ground_truth1, net_out1, tests_dict = generate_gt_pred_pairs(test_num=test_n)

# Loss settings shared by every implementation compared below.
do_sigmoid, do_softmax = False, False
smooth_num = smooth_den = 1e-5

# Plain-numpy reference value.
dice_numpy = compute_soft_dice(ground_truth1, net_out1)

# Torch tensors with two leading singleton axes -> shape (1, 1, H, W).
gt_tensor = torch.as_tensor(np.ascontiguousarray(ground_truth1))[None, None]
out_tensor = torch.as_tensor(np.ascontiguousarray(net_out1))[None, None]

# With identical settings, DiceLossExtended (non-batch) is expected to agree
# with monai.losses.DiceLoss.
monai_dice_loss = monai.losses.DiceLoss(sigmoid=do_sigmoid, softmax=do_softmax)
monai_dice = monai_dice_loss(out_tensor, gt_tensor, smooth=smooth_num)

extended_dice_loss = DiceLossExtended(
    sigmoid=do_sigmoid,
    softmax=do_softmax,
    batch_version=False,
    smooth_num=smooth_num,
    smooth_den=smooth_den,
)
extended_dice = extended_dice_loss(out_tensor, gt_tensor)

print(f"*** Test = {tests_dict[test_n]} ***\n")
print(f"Dice loss from numpy: {1.0-dice_numpy}")
print(f"Dice loss from monai.losses.DiceLoss: {monai_dice}")
print(f"Dice loss from DiceLossExtended: {extended_dice}\n")

## Test Batch Dice implementation

In [None]:
# Two synthetic ground-truth / prediction pairs for the batch: test 0 and test 1.
gt1, out1, tests_dict = generate_gt_pred_pairs(0)
gt2, out2, _ = generate_gt_pred_pairs(1)

# Combine two equally-shaped 2D arrays into one 4D array [Batch x Channel x Height x Width].
def create_batch(arr1, arr2):
    """Stack ``arr1`` and ``arr2`` along a new leading batch axis, then add a singleton channel axis.

    For two (H, W) inputs the result has shape (2, 1, H, W), with ``arr1`` at
    batch index 0 and ``arr2`` at batch index 1.
    """
    stacked = np.stack((arr1, arr2), axis=0)
    return stacked[:, np.newaxis]

# 4D ground-truth and prediction batches (sample 0 = first pair, sample 1 = second pair).
full_gt, full_out = create_batch(gt1, gt2), create_batch(out1, out2)

In [None]:
do_sigmoid, do_softmax = False, False
smooth_num = 1e-5
smooth_den = 1e-5

# Test numbers used to build the batch in the previous cell (gt1/out1 from
# test 0, gt2/out2 from test 1). The printed label is derived from these
# instead of reusing `test_n` from the single-pair cell, which described a
# different experiment.
batch_test_nums = (0, 1)

# "Batch" Dice pools every element of the batch into one score, whereas the
# average computes Dice per pair and then averages the two scores.
batch_dice_numpy = compute_soft_dice(full_gt, full_out)
avg_dice_numpy = 0.5 * (compute_soft_dice(gt1, out1) + compute_soft_dice(gt2, out2))

# convert to tensors (already 4D: [Batch x Channel x Height x Width])
gt_tensor = torch.as_tensor(np.ascontiguousarray(full_gt))
out_tensor = torch.as_tensor(np.ascontiguousarray(full_out))

# monai.losses.DiceLoss is compared against the numpy per-pair average, and
# DiceLossExtended(batch_version=True) against the numpy pooled batch Dice.
monai_dice_loss = monai.losses.DiceLoss(sigmoid=do_sigmoid, softmax=do_softmax)
monai_dice = monai_dice_loss(out_tensor, gt_tensor, smooth=smooth_num)

batch_dice_loss = DiceLossExtended(sigmoid=do_sigmoid, softmax=do_softmax, batch_version=True, smooth_num=smooth_num, smooth_den=smooth_den)
batch_dice = batch_dice_loss(out_tensor, gt_tensor)

batch_label = ", ".join(tests_dict[n] for n in batch_test_nums)
print(f"*** Test = batch of [{batch_label}] ***\n")
print(f"Avg Dice loss from numpy: {1.0-avg_dice_numpy}")
print(f"Avg Dice loss from monai.losses.DiceLoss: {monai_dice}")
print(f"Batch Dice loss from numpy: {1.0-batch_dice_numpy}")
print(f"Batch Dice loss from DiceLossExtended: {batch_dice}\n")