In [1]:
import os

import nibabel as nib
import numpy as np

In [2]:
# 4x4 affines extracted from the original NIfTI files of each case.
# DeepReg throws away this header information and therefore doesn't accurately
# calculate mTRE; case_TREs re-applies these affines to the landmark centroids.
# Only case2, case21 and case23 are used by the mTRE cells of this notebook.

case1_affine = np.array([[ 3.90691876e-01,  1.30452037e-01,  2.81675309e-01, -1.96299267e+01],
       [ 1.55749843e-02,  4.42654282e-01, -2.28414789e-01, -5.37206192e+01],
       [-3.09405476e-01,  1.87006667e-01,  3.44178468e-01, 2.05047684e+01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

case2_affine = np.array([[ 3.33370864e-01,  1.59845248e-01,  3.33878189e-01, -3.39173775e+01],
       [-1.82626724e-01,  4.61422384e-01, -3.86130475e-02, 4.38701019e+01],
       [-3.21150362e-01, -9.64666754e-02,  3.68540883e-01, 6.26287537e+01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

case12_affine = np.array([[  0.41881981,   0.07436115,   0.25797391, -57.30677795],
       [  0.06075069,   0.43993044,  -0.22682622, -23.42017174],
       [ -0.26014873,   0.22244968,   0.36234942, -19.91741753],
       [  0.        ,   0.        ,   0.        ,   1.        ]])

case13_affine = np.array([[ 4.08450842e-01, -1.18215449e-01,  2.58995205e-01, -2.38458767e+01],
       [ 1.57638997e-01,  4.71258372e-01, -3.36058326e-02, -7.20230942e+01],
       [-2.37171665e-01,  1.09639972e-01,  4.23698217e-01, -5.00180817e+01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

case21_affine = np.array([[4.00647670e-01, -1.04107566e-01,  2.77501583e-01, -3.58118629e+01],
       [ 1.67786613e-01,  4.64567333e-01, -6.78447485e-02, -8.16707916e+01],
       [-2.44675279e-01,  1.48105368e-01,  4.07874972e-01, -8.83094406e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

case23_affine = np.array([[  0.31273046,   0.25455737,   0.29349992, -41.70222473],
       [ -0.28844342,   0.40450525,  -0.04210297,   2.00646901],
       [ -0.2584393 ,  -0.14343421,   0.40214738,   4.35490799],
       [  0.        ,   0.        ,   0.        ,   1.        ]])

# Voxel shape the images were resampled to for DeepReg (presumably the image
# shape set in the DeepReg config -- confirm). Centroids taken from DeepReg's
# outputs are multiplied by the per-case scale below before the original
# affine is applied (see case_TREs).
DeepReg_rescale = (151, 136, 119)


def deepreg_voxel_scale(original_size, rescaled_size=DeepReg_rescale):
    """
    Per-axis factors mapping voxel coordinates of the resampled volume back to
    voxel coordinates of the original volume.

    :param original_size: voxel shape of the original image.
    :param rescaled_size: voxel shape the image was resampled to
        (defaults to ``DeepReg_rescale``).
    :return: float array ``original_size / rescaled_size`` (element-wise).
    :raises ValueError: if the two shapes have different numbers of axes.

    NOTE(review): this assumes index j of the resampled volume corresponds to
    index ``j * original / rescaled`` of the original (the "asymmetric"
    resampling convention). Confirm against how DeepReg resizes; align-corners
    or half-pixel conventions differ by up to about one voxel.
    """
    original = np.asarray(original_size, dtype=float)
    rescaled = np.asarray(rescaled_size, dtype=float)
    if original.shape != rescaled.shape:
        raise ValueError(
            f"size {tuple(original_size)} and rescaled size "
            f"{tuple(rescaled_size)} have different numbers of axes"
        )
    return original / rescaled


case2_size = (128, 161, 106)  # original image shape of case 2
case2_scale = deepreg_voxel_scale(case2_size)

case21_size = (215, 251, 191)  # original image shape of case 21
case21_scale = deepreg_voxel_scale(case21_size)

case23_size = (106, 147, 134)  # original image shape of case 23
case23_scale = deepreg_voxel_scale(case23_size)

In [3]:
def extract_centroid(image):
    """
    Extract the centroid of every labelled region in a 3D label volume,
    e.g. nifti images with landmark spheres that carry integer label values.
    Adapted from: https://gist.github.com/mattiaspaul/f4183f525b1cbc65e71ad23298d6436e

    Each distinct non-zero value is one label. Its centroid is the mean voxel
    index of all voxels holding exactly that value. Values are compared for
    exact equality, so round soft/interpolated labels before calling.

    :param image: 3D array-like of shape (dim_1, dim_2, dim_3).
    :return: float array of shape (num_labels, 3). Row k is the centroid, in
        voxel index coordinates (axis order dim_1, dim_2, dim_3), of the k-th
        smallest non-zero value. Shape (0, 3) if there are no non-zero voxels.
    :raises ValueError: if ``image`` is not 3D.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        raise ValueError(f"expected a 3D image, got shape {image.shape}")

    labels = np.unique(image)
    # Remove background explicitly. Slicing off the first unique value would
    # discard a real label whenever the volume contains no 0 voxels.
    labels = labels[labels != 0]

    positions = np.zeros((len(labels), 3))
    for k, value in enumerate(labels):
        # argwhere gives one (i, j, k) index row per matching voxel; the
        # column-wise mean is the unweighted centre of mass of the mask.
        positions[k] = np.argwhere(image == value).mean(axis=0)
    return positions

In [4]:
def calculate_mTRE(xyz_true, xyz_predict):
    """
    Mean target registration error between two matched point sets.

    :param xyz_true: 2D array of reference points, one point per row.
    :param xyz_predict: array of predicted points with the same shape.
    :return: mean Euclidean distance between corresponding rows.
    """
    assert xyz_true.shape == xyz_predict.shape
    # Row-wise Euclidean length of the displacement vectors.
    per_point_error = np.linalg.norm(xyz_true - xyz_predict, axis=1)
    return per_point_error.mean()

In [5]:
def _landmark_points(path, affine, scale):
    """Load one label volume and return its label centroids mapped through ``affine``.

    The volume is rounded to integer labels, centroids are taken in voxel
    indices, multiplied per axis by ``scale``, then transformed by ``affine``.
    The NIfTI header affine of the file itself is not used.
    """
    volume = np.round(nib.load(path).get_fdata())
    return nib.affines.apply_affine(affine, extract_centroid(volume) * scale)


def case_TREs(pred_dir, pair_number, num_labels, affine, scale):
    """
    Per-label target registration errors for one test pair.

    For each i in 0 .. num_labels - 1 this reads
    ``<pred_dir>/pair_<pair_number>/label_<i>/fixed_label.nii.gz`` and
    ``.../pred_fixed_label.nii.gz``. It computes the centroids of both via
    ``_landmark_points`` and stores the mean distance between them.

    :param pred_dir: DeepReg test output directory (trailing separator optional).
    :param pair_number: index of the ``pair_<n>`` sub-directory.
    :param num_labels: number of ``label_<i>`` sub-directories to evaluate.
    :param affine: 4x4 affine (from the original NIfTI file) applied to the
        scaled voxel coordinates.
    :param scale: per-axis factor applied to voxel coordinates before ``affine``.
    :return: float array of shape (num_labels,), one error per label.
    :raises ValueError: if a label yields a different number of centroids in
        the fixed and predicted volumes.
    """
    pair_dir = os.path.join(pred_dir, f"pair_{pair_number}")
    TREs = np.zeros(num_labels)
    for i in range(num_labels):
        label_dir = os.path.join(pair_dir, f"label_{i}")
        true_points = _landmark_points(
            os.path.join(label_dir, "fixed_label.nii.gz"), affine, scale)
        pred_points = _landmark_points(
            os.path.join(label_dir, "pred_fixed_label.nii.gz"), affine, scale)
        if true_points.shape != pred_points.shape:
            # This can happen, e.g., if rounding erased a faint predicted label.
            raise ValueError(
                f"{label_dir}: {len(true_points)} centroid(s) in the fixed label "
                f"but {len(pred_points)} in the predicted label"
            )
        TREs[i] = calculate_mTRE(true_points, pred_points)

    return TREs

# Calculating the mTRE for the 3 test cases

In [18]:
# Folder with the prediction data (case_TREs expects pair_<n>/label_<i>/
# sub-folders inside it). The middle path component looks like a run
# timestamp; change run_name to evaluate a different run.
run_name = "20210413-124713"
prediction_dir = f"logs/{run_name}/test/"

## Case 2

In [19]:
# Test pair 0 is case 2, which has 15 landmark labels.
case2_TREs = case_TREs(
    prediction_dir, pair_number=0, num_labels=15, affine=case2_affine, scale=case2_scale
)
case2_TREs

array([5.02955917, 4.39510395, 4.40802551, 5.00299152, 4.47532721,
       5.00878451, 2.57974281, 5.40922841, 6.8145309 , 5.04973411,
       3.94488403, 6.7973698 , 6.69415665, 9.15325141, 4.45252542])

In [20]:
# Case-level mTRE: the average of the per-label errors above.
case2_mTRE = case2_TREs.mean()
print(f"The mTRE for case 2 was {case2_mTRE}")

The mTRE for case 2 was 5.281014361060188


## Case 21

In [21]:
# Test pair 1 is case 21, which has 16 landmark labels.
case21_TREs = case_TREs(
    prediction_dir, pair_number=1, num_labels=16, affine=case21_affine, scale=case21_scale
)
case21_TREs

array([6.09522505, 3.75540878, 5.29771185, 3.74406199, 4.67369326,
       4.47779977, 4.072249  , 5.12406763, 4.99874554, 3.58955554,
       5.36138275, 4.25330097, 4.21532132, 3.8714921 , 3.91247806,
       3.86929028])

In [22]:
# Case-level mTRE: the average of the per-label errors above.
case21_mTRE = case21_TREs.mean()
print(f"The mTRE for case 21 was {case21_mTRE}")

The mTRE for case 21 was 4.456986492822626


## Case 23

In [23]:
# Test pair 2 is case 23, which has 15 landmark labels.
case23_TREs = case_TREs(
    prediction_dir, pair_number=2, num_labels=15, affine=case23_affine, scale=case23_scale
)
case23_TREs

array([6.57246287, 6.88693783, 7.87822869, 6.49209298, 7.42894543,
       8.2068717 , 6.58697157, 8.14573094, 8.14917004, 7.88881776,
       5.19166396, 5.99010256, 5.29910762, 5.9052913 , 8.21326861])

In [24]:
# Case-level mTRE: the average of the per-label errors above.
case23_mTRE = case23_TREs.mean()
print(f"The mTRE for case 23 was {case23_mTRE}")

The mTRE for case 23 was 6.989044259099567
