In [1]:
import os

import nibabel as nib
import numpy as np

In [16]:
# Voxel-to-world affines copied from the headers of the original NIfTI files.
# DeepReg discards this header information, so it does not calculate the
# mTRE accurately on its own; it is re-applied here instead.
case2_affine = np.array([
    [ 3.33370864e-01,  1.59845248e-01,  3.33878189e-01, -3.39173775e+01],
    [-1.82626724e-01,  4.61422384e-01, -3.86130475e-02,  4.38701019e+01],
    [-3.21150362e-01, -9.64666754e-02,  3.68540883e-01,  6.26287537e+01],
    [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  1.00000000e+00],
])

case21_affine = np.array([
    [ 4.00647670e-01, -1.04107566e-01,  2.77501583e-01, -3.58118629e+01],
    [ 1.67786613e-01,  4.64567333e-01, -6.78447485e-02, -8.16707916e+01],
    [-2.44675279e-01,  1.48105368e-01,  4.07874972e-01, -8.83094406e+00],
    [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  1.00000000e+00],
])

case23_affine = np.array([
    [  0.31273046,   0.25455737,   0.29349992, -41.70222473],
    [ -0.28844342,   0.40450525,  -0.04210297,   2.00646901],
    [ -0.2584393 ,  -0.14343421,   0.40214738,   4.35490799],
    [  0.        ,   0.        ,   0.        ,   1.        ],
])

# Grid shape of DeepReg's outputs (presumably the size every input was
# resampled to — confirm against the DeepReg config).
DeepReg_rescale = (151, 136, 119)

# Original image shapes and the per-axis factors (original / DeepReg size)
# that map a voxel index on DeepReg's grid back onto the original grid.
case2_size = (128, 161, 106)
case2_scale = np.divide(case2_size, DeepReg_rescale)

case21_size = (215, 251, 191)
case21_scale = np.divide(case21_size, DeepReg_rescale)

case23_size = (106, 147, 134)
case23_scale = np.divide(case23_size, DeepReg_rescale)

In [3]:
def extract_centroid(image):
    """
    Extract the centroid of every non-zero label in a 3D label volume.
    Adapted from: https://gist.github.com/mattiaspaul/f4183f525b1cbc65e71ad23298d6436e

    Each distinct non-zero voxel value is treated as one label (e.g. one
    landmark sphere); its centroid is the mean voxel index of all voxels
    holding that value.

    :param image:
        - shape: (dim_1, dim_2, dim_3); batched input is not supported
        - numpy array or array-like (converted with ``np.asarray``)
        - voxel values are label ids, 0 being background

    :return positions:
        - float numpy array of shape (num_labels, 3)
        - row i is the (dim_1, dim_2, dim_3) voxel-index centroid of the i-th
          label in ascending label order; shape (0, 3) if there are no labels
    """
    image = np.asarray(image)
    assert image.ndim == 3

    # Filter out background by value instead of assuming 0 is the smallest
    # unique value, so volumes without 0 voxels (or with negative values)
    # keep all of their labels.
    labels = np.unique(image)
    labels = labels[labels != 0]

    positions = np.zeros((len(labels), 3))
    for i, value in enumerate(labels):
        # argwhere gives the (k, 3) indices of this label's voxels; their mean
        # is the unweighted centroid along (dim_1, dim_2, dim_3).
        positions[i] = np.argwhere(image == value).mean(axis=0)
    return positions

In [4]:
def calculate_mTRE(xyz_true, xyz_predict):
    """
    Mean target registration error between two sets of corresponding points.

    :param xyz_true: array of shape (num_points, num_dims), e.g. xyz coordinates
    :param xyz_predict: array of the same shape holding the predicted points
    :return: mean over rows of the Euclidean distance between matching points
    """
    assert xyz_true.shape == xyz_predict.shape
    squared_offsets = (xyz_true - xyz_predict) ** 2
    distances = np.sqrt(squared_offsets.sum(axis=1))
    return distances.mean()

In [5]:
def _label_world_centroid(path, affine, scale):
    """Load one label volume and return its label centroid(s) in world coordinates."""
    label_np = nib.load(path).get_fdata()
    # Rounding snaps (possibly interpolated) voxel values back to integer
    # label ids before centroid extraction; the centroid is then rescaled
    # from DeepReg's grid to the original voxel grid and mapped by the affine.
    voxel_points = extract_centroid(np.round(label_np)) * scale
    return nib.affines.apply_affine(affine, voxel_points)


def case_TREs(pred_dir, pair_number, num_labels, affine, scale):
    """
    Compute the target registration error (TRE) of every landmark of one test pair.

    For each label ``i`` the fixed and predicted label volumes stored under
    ``<pred_dir>/pair_<pair_number>/label_<i>/`` are loaded, their label
    centroids are extracted, rescaled from DeepReg's voxel grid to the
    original voxel grid with ``scale`` and mapped to world coordinates with
    the original image ``affine``.

    :param pred_dir: folder containing the ``pair_<n>`` prediction folders
        (a trailing path separator is optional)
    :param pair_number: index of the test pair, e.g. 0 for ``pair_0``
    :param num_labels: number of labels, read as ``label_0 .. label_{num_labels - 1}``
    :param affine: 4x4 voxel-to-world affine of the original image
    :param scale: per-axis factor (original size / DeepReg size) mapping
        DeepReg voxel indices back to original voxel indices
    :return: float array of shape (num_labels,) with one TRE per label, in the
        world units of ``affine`` (presumably mm — confirm)
    """
    pair_dir = os.path.join(pred_dir, f"pair_{pair_number}")
    TREs = np.zeros(num_labels)
    for i in range(num_labels):
        label_dir = os.path.join(pair_dir, f"label_{i}")
        label_point = _label_world_centroid(
            os.path.join(label_dir, "fixed_label.nii.gz"), affine, scale)
        pred_point = _label_world_centroid(
            os.path.join(label_dir, "pred_fixed_label.nii.gz"), affine, scale)
        # Each label volume is expected to hold a single landmark, so the
        # "mean" TRE here is that landmark's TRE.
        TREs[i] = calculate_mTRE(label_point, pred_point)
    return TREs

# Calculating the mTRE for the 3 test cases

In [13]:
# Root folder of DeepReg's test-set predictions: one pair_<n>/label_<i>/
# sub-folder per test pair and landmark label.
prediction_dir = "logs/94_final_test/test/"

## Case 2

In [17]:
# Case 2 is stored as pair_0 and has 15 landmark labels.
case2_TREs = case_TREs(
    pred_dir=prediction_dir,
    pair_number=0,
    num_labels=15,
    affine=case2_affine,
    scale=case2_scale,
)
case2_TREs

array([1.79753499, 0.26693628, 0.84879033, 1.14335395, 1.88379245,
       2.06736938, 1.83329735, 1.56726897, 4.68904725, 2.36189134,
       1.88635303, 2.03235962, 4.08240861, 2.1665045 , 2.6266453 ])

In [18]:
# Average the per-landmark errors into the case-level mTRE.
case2_mTRE = case2_TREs.mean()
print(f"The mTRE for case 2 was {case2_mTRE}")

The mTRE for case 2 was 2.0835702220168844


## Case 21

In [19]:
# Case 21 is stored as pair_1 and has 16 landmark labels.
case21_TREs = case_TREs(
    pred_dir=prediction_dir,
    pair_number=1,
    num_labels=16,
    affine=case21_affine,
    scale=case21_scale,
)
case21_TREs

array([30.63088671, 31.55817349, 28.53388818, 31.26500583, 29.5532927 ,
       30.38294457, 32.87606821, 33.39476506, 31.7590634 , 33.05329981,
       31.53157843, 32.16872753, 30.05659811, 30.29417846, 29.80772696,
       32.75879295])

In [20]:
# Average the per-landmark errors into the case-level mTRE.
case21_mTRE = case21_TREs.mean()
print(f"The mTRE for case 21 was {case21_mTRE}")

The mTRE for case 21 was 31.226561899868635


## Case 23

In [21]:
# Case 23 is stored as pair_2 and has 15 landmark labels.
case23_TREs = case_TREs(
    pred_dir=prediction_dir,
    pair_number=2,
    num_labels=15,
    affine=case23_affine,
    scale=case23_scale,
)
case23_TREs

array([1.69679684, 2.44954808, 3.31133736, 2.7626878 , 3.68265614,
       1.7193722 , 3.21651715, 3.35398911, 2.37725154, 2.04183939,
       3.19872962, 3.24624183, 2.56858806, 2.95183963, 4.95450721])

In [22]:
# Average the per-landmark errors into the case-level mTRE.
case23_mTRE = case23_TREs.mean()
print(f"The mTRE for case 23 was {case23_mTRE}")

The mTRE for case 23 was 2.9021267970274662
