In [None]:
import numpy as np
from medpy import metric

def calculate_metric_percase(pred, gt):
    """Compute the Dice coefficient and HD95 for one binary prediction/label pair.

    Both inputs are binarized (every value > 0 becomes 1) on private copies,
    so the arrays passed in by the caller are never modified.

    Args:
        pred: array-like predicted mask; values > 0 count as foreground.
        gt: array-like ground-truth mask; values > 0 count as foreground.

    Returns:
        A ``(dice, hd95)`` tuple:
          * both masks non-empty: ``metric.binary.dc`` and ``metric.binary.hd95``
            evaluated on the binarized masks;
          * prediction non-empty, ground truth empty: ``(1, 0)``;
          * every other case: ``(0, 0)``.
    """
    # Copy first so that binarization never rewrites the caller's arrays.
    pred = np.array(pred, copy=True)
    gt = np.array(gt, copy=True)
    pred[pred > 0] = 1
    gt[gt > 0] = 1

    pred_sum = pred.sum()
    gt_sum = gt.sum()

    if pred_sum > 0 and gt_sum > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    if pred_sum > 0 and gt_sum == 0:
        # NOTE(review): a prediction on an empty ground truth is scored as a
        # perfect Dice of 1 — this looks inverted; confirm it is the intended
        # evaluation convention before changing it, since it affects results.
        return 1, 0
    # NOTE(review): a missed structure gets HD95 = 0 (the best possible value),
    # which lowers the averaged HD95 — confirm this is intended.
    return 0, 0

def hd95_multi_class(pred, gt, num_classes):
    """
    Evaluate every foreground class of a multi-class segmentation.

    Args:
        pred: predicted label map; it is compared against each class id in [1, num_classes)
        gt: ground-truth label map, compared against the same class ids
        num_classes: integer, the number of categories including bg

    Returns:
        ret: array of shape (num_classes, 2). Row 0 is the placeholder (-1, -1)
            for the background; row i (i >= 1) holds the (dice, hd95) pair that
            calculate_metric_percase returns for class i.
    """
    background_placeholder = (-1, -1)
    per_class = [
        calculate_metric_percase(pred == cls, gt == cls)
        for cls in range(1, num_classes)
    ]
    return np.array([background_placeholder, *per_class])


import pickle

# Read every object that was pickled back-to-back into ours.pkl.
# NOTE: pickle can execute arbitrary code while loading — only read files from a trusted source.
data = []
with open("ours.pkl", "rb") as pkl_file:  # context manager guarantees the handle is closed
    f = pickle.Unpickler(pkl_file)
    while True:
        try:
            data.append(f.load())
        except EOFError:
            # End of file reached: all pickled records have been read.
            break


import nibabel

# Case ids whose label volumes are read below.
val_slc = [1, 2, 3, 4, 8, 22, 25, 29, 32, 35, 36, 38]
n_slc = []     # size of the last axis (slice count) of each label volume
spacings = []  # header zooms (voxel spacing) of each label volume
for idx in val_slc:
    # Load each file once. Shape and zooms are available on the image object,
    # so the voxel data itself never has to be decoded with get_fdata().
    label_img = nibabel.load(f"data/MedicalImages/bcv30/RawData/Training/label/label{idx:04d}.nii.gz")
    n_slc.append(label_img.shape[-1])
    spacing = label_img.header.get_zooms()
    spacings.append(spacing)

# Regroup the flat list of per-slice records into one volume per patient.
# Records are consumed in order: the first n_slc[0] entries of `data` form the
# first patient, the next n_slc[1] the second, and so on.
patient_preds, patient_labels = [], []
cum_idx = 0
for ns in n_slc:
    slice_ids = range(cum_idx, cum_idx + ns)
    # Stack slices on a new leading axis, then move that axis to the end.
    preds = np.transpose(np.stack([data[k]["pred"] for k in slice_ids]), (1, 2, 0))
    labels = np.transpose(np.stack([data[k]["label"] for k in slice_ids]), (1, 2, 0))
    patient_preds.append(preds)
    patient_labels.append(labels)
    cum_idx = slice_ids.stop

In [None]:
# Verify DSC computation correctness

# Join all patients along the slice axis so each class gets one global Dice score.
large_preds = np.concatenate(patient_preds, axis=-1)
large_labels = np.concatenate(patient_labels, axis=-1)

dscs = [
    metric.binary.dc(large_preds == cls, large_labels == cls)
    for cls in range(14)
]
# Average only over the selected class ids.
selected_dscs = np.take(np.array(dscs), [8, 4, 3, 2, 6, 11, 1, 7])
print(selected_dscs.mean())

In [None]:
# Statistics of organs (number of voxels, number of slices, slice variations)

organ_ids = [8, 4, 3, 2, 6, 11, 1, 7]

# Total number of voxels labelled as each organ.
for i in organ_ids:
    print((large_labels == i).sum())

# Number of slices that contain each organ, and the standard deviation of the
# per-slice voxel count over those slices.
for i in organ_ids:
    cnts = []
    for j in range(large_labels.shape[-1]):
        voxel_cnt = (large_labels[:, :, j] == i).sum()
        if voxel_cnt > 0:
            cnts.append(voxel_cnt)
    print(len(cnts), np.std(cnts))

# Mean Dice between every slice that contains the organ and the following slice.
# Iterate patient by patient: in the concatenated volume the last slice of one
# patient is followed by the first slice of the next, and that pair is not
# anatomically adjacent, so it must not be compared.
for i in organ_ids:
    cnts = []
    for volume in patient_labels:
        for j in range(volume.shape[-1] - 1):
            cur_mask = volume[:, :, j] == i
            cur_voxel_cnt = cur_mask.sum()
            # Only the current slice must contain the organ; the next slice may be empty.
            if cur_voxel_cnt > 0:
                dsc = metric.binary.dc(cur_mask, volume[:, :, j + 1] == i)
                cnts.append(dsc)
    print(len(cnts), np.mean(cnts))

In [None]:
# Compute HD95 for each organ

metrics = []
import cv2

def downsample(x, resolution):
    """Resize ``x`` with nearest-neighbour interpolation.

    Args:
        x: array handed to ``cv2.resize``.
        resolution: indexable whose first two entries are passed as the
            ``dsize`` argument of ``cv2.resize``.

    Returns:
        The array returned by ``cv2.resize``.
    """
    target_size = (resolution[0], resolution[1])
    # INTER_NEAREST picks existing values only, so no new label ids are created.
    return cv2.resize(x, target_size, interpolation=cv2.INTER_NEAREST)

organ_ids = [8, 4, 3, 2, 6, 11, 1, 7]

for i, patient_pred in enumerate(patient_preds):
    spacing_i = spacings[i]
    assert spacing_i[0] == spacing_i[1]
    # NOTE(review): presumably rescales in-plane so voxels become isotropic at
    # the slice spacing, which is why HD95 is multiplied by spacing_i[2] below;
    # 512 looks like the original in-plane image size — confirm.
    ratio = spacing_i[2] / spacing_i[1]
    size = int(512 / ratio)
    preds = downsample(patient_pred, (size, size))
    labels = downsample(patient_labels[i], (size, size))
    print(preds.shape, spacing_i[2])

    # Split the per-class (dice, hd95) rows into two parallel sequences.
    dsc, hd95 = zip(*hd95_multi_class(preds, labels, num_classes=14).tolist())

    metrics.append((np.array(dsc), np.array(hd95) * spacing_i[2]))

# One entry per patient; index 0 on axis 1 is DSC, index 1 is HD95.
metrics_arr = np.array(metrics)
mean_dsc = metrics_arr[:, 0, :].mean(axis=0)
mean_hd95 = metrics_arr[:, 1, :].mean(axis=0)

print(mean_hd95[organ_ids])
print(mean_hd95[organ_ids].mean()) # HD95
print(mean_dsc[organ_ids].mean()) # DSC