In [1]:
import glob
import numpy as np
import nibabel as nib

In [32]:
def assert_shape(test, reference):

    assert test.shape == reference.shape, "Shape mismatch: {} and {}".format(
        test.shape, reference.shape)
    
    
class ConfusionMatrix:

    def __init__(self, test=None, reference=None):

        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.reference_empty = None
        self.reference_full = None
        self.test_empty = None
        self.test_full = None
        self.set_reference(reference)
        self.set_test(test)

    def set_test(self, test):

        self.test = test
        self.reset()

    def set_reference(self, reference):

        self.reference = reference
        self.reset()

    def reset(self):

        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.test_empty = None
        self.test_full = None
        self.reference_empty = None
        self.reference_full = None

    def compute(self):

        if self.test is None or self.reference is None:
            raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.")

        assert_shape(self.test, self.reference)

        self.tp = int(((self.test != 0) * (self.reference != 0)).sum())
        self.fp = int(((self.test != 0) * (self.reference == 0)).sum())
        self.tn = int(((self.test == 0) * (self.reference == 0)).sum())
        self.fn = int(((self.test == 0) * (self.reference != 0)).sum())
        self.size = int(np.prod(self.reference.shape, dtype=np.int64))
        self.test_empty = not np.any(self.test)
        self.test_full = np.all(self.test)
        self.reference_empty = not np.any(self.reference)
        self.reference_full = np.all(self.reference)

    def get_matrix(self):

        for entry in (self.tp, self.fp, self.tn, self.fn):
            if entry is None:
                self.compute()
                break

        return self.tp, self.fp, self.tn, self.fn

    def get_size(self):

        if self.size is None:
            self.compute()
        return self.size

    def get_existence(self):

        for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full):
            if case is None:
                self.compute()
                break

        return self.test_empty, self.test_full, self.reference_empty, self.reference_full


def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """2TP / (2TP + FP + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty and reference_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0.

    return float(2. * tp / (2 * tp + fp + fn))

In [6]:
# test for one subject
import SimpleITK as sitk
base = '/Users/mona/Library/CloudStorage/Box-Box/Animals/Cinnemon/Cinnemon_D6/registration/'
gt_path = f'{base}/POSTCONT1_endo_epi.nii'

MOCO_stage1_path = f'{base}/register_1_endo_epi.nii'
MOCO_stage2_path = f'{base}/register_2_endo_epi.nii'

gt = sitk.GetArrayFromImage(sitk.ReadImage(gt_path)).transpose(0, 2, 1)
MOCO_stage1 = sitk.GetArrayFromImage(sitk.ReadImage(MOCO_stage1_path)).transpose(0, 2, 1)
MOCO_stage2 = sitk.GetArrayFromImage(sitk.ReadImage(MOCO_stage2_path)).transpose(0, 2, 1)

print(f"Ground truth shape: {gt.shape}, MOCO_stage1 shape: {MOCO_stage1.shape}, MOCO_stage2 shape: {MOCO_stage2.shape}")

Ground truth shape: (96, 96, 17), MOCO_stage1 shape: (96, 96, 17), MOCO_stage2 shape: (96, 96, 17)


In [7]:
sitk.Show(sitk.GetImageFromArray(gt.transpose(2, 0, 1)))
sitk.Show(sitk.GetImageFromArray(MOCO_stage1.transpose(2, 0, 1)))
sitk.Show(sitk.GetImageFromArray(MOCO_stage2.transpose(2, 0, 1)))

In [31]:
# def dice_score(y_true, y_pred):
#     print(y_pred.shape, y_true.shape)
#     ndims = len(list(y_pred.shape)) - 2
#     vol_axes = list(range(2, ndims + 2))
#     top = np.sum(2 * (y_true * y_pred), keepdims=vol_axes)
#     # top = 2 * (y_true * y_pred).sum(dim=vol_axes)
#     bottom = np.clip(np.sum(y_true + y_pred, keepdims=vol_axes), min=1e-5)
#     # bottom = torch.clamp((y_true + y_pred).sum(dim=vol_axes), min=1e-5)
#     # dice = torch.mean(top / bottom)
#     dice = np.mean(top / bottom)
#     return -dice


def dice_score_slice(data):
    num_slices = data.shape[-1]
    sum_dice = 0
    template = round(num_slices/2)
    for i in range(num_slices):
        if i != template:
            sum_dice += dice(test=data[:,:,i], reference=data[:,:,template])
    return sum_dice / num_slices

In [17]:
sitk.Show(sitk.GetImageFromArray(LV_myo_gt.transpose(2, 0, 1)))

In [36]:
import copy
LV_myo_gt = copy.deepcopy(gt)
LV_myo_gt[LV_myo_gt==1] = 0
LV_myo_gt[LV_myo_gt==2] = 1
LV_myo_stage1 = copy.deepcopy(MOCO_stage1)
LV_myo_stage1[LV_myo_stage1 ==1] = 0
LV_myo_stage1[LV_myo_stage1 ==2] = 1
LV_myo_stage2 = copy.deepcopy(MOCO_stage2)
LV_myo_stage2[LV_myo_stage2 == 1] = 0
LV_myo_stage2[LV_myo_stage2 == 2] = 1

dice_gt = dice_score_slice(LV_myo_gt)
dice_stage1 = dice_score_slice(LV_myo_stage1)
dice_stage2 = dice_score_slice(LV_myo_stage2)
print(f"LV and myo Dice for groundtruth: {dice_gt}, Dice after first stage: {dice_stage1}, Dice after second stage {dice_stage2}")

LV and myo Dice for groundtruth: 0.9007229967845952, Dice after first stage: 0.9088718138523625, Dice after second stage 0.8965059279709937


In [35]:
import copy
myo_gt = copy.deepcopy(gt)
myo_gt[myo_gt>1] = 0
myo_stage1 = copy.deepcopy(MOCO_stage1)
myo_stage1[myo_stage1 > 1] = 0
myo_stage2 = copy.deepcopy(MOCO_stage2)
myo_stage2[myo_stage2 > 1] = 0

dice_gt = dice_score_slice(myo_gt)
dice_stage1 = dice_score_slice(myo_stage1)
dice_stage2 = dice_score_slice(myo_stage2)
print(f"Myo Dice for groundtruth: {dice_gt}, Dice after first stage: {dice_stage1}, Dice after second stage {dice_stage2}")


Myo Dice for groundtruth: 0.8917304973286084, Dice after first stage: 0.9020859799846722, Dice after second stage 0.8871082419211254
