In [4]:
import os

import nibabel as nib
import numpy as np

In [5]:
# Geometry taken from the original NIfTI files of the three test cases.
# DeepReg throws this header information away and therefore does not
# accurately calculate mTRE; it is kept here so that DeepReg voxel coordinates
# can be mapped back to the original image space.

# 4x4 voxel-to-world affine of each case's original image.
case2_affine = np.array([
    [ 3.33370864e-01,  1.59845248e-01,  3.33878189e-01, -3.39173775e+01],
    [-1.82626724e-01,  4.61422384e-01, -3.86130475e-02,  4.38701019e+01],
    [-3.21150362e-01, -9.64666754e-02,  3.68540883e-01,  6.26287537e+01],
    [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  1.00000000e+00],
])

case21_affine = np.array([
    [ 4.00647670e-01, -1.04107566e-01,  2.77501583e-01, -3.58118629e+01],
    [ 1.67786613e-01,  4.64567333e-01, -6.78447485e-02, -8.16707916e+01],
    [-2.44675279e-01,  1.48105368e-01,  4.07874972e-01, -8.83094406e+00],
    [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  1.00000000e+00],
])

case23_affine = np.array([
    [ 0.31273046,  0.25455737,  0.29349992, -41.70222473],
    [-0.28844342,  0.40450525, -0.04210297,   2.00646901],
    [-0.2584393,  -0.14343421,  0.40214738,   4.35490799],
    [ 0.0,         0.0,         0.0,          1.0],
])

# Common voxel grid that the images were resized to for DeepReg.
DeepReg_rescale = (151, 136, 119)

# Original image sizes (voxels) and the per-axis factor that maps a DeepReg
# voxel index back onto the original grid: original size / DeepReg size.
case2_size = (128, 161, 106)
case2_scale = np.divide(case2_size, DeepReg_rescale)

case21_size = (215, 251, 191)
case21_scale = np.divide(case21_size, DeepReg_rescale)

case23_size = (106, 147, 134)
case23_scale = np.divide(case23_size, DeepReg_rescale)

In [6]:
def extract_centroid(image):
    """
    Compute the centroid, in voxel indices, of every non-zero label in a 3-D volume.

    Intended for volumes with landmark spheres whose voxels carry integer
    values according to labels; the centroid of each sphere is the landmark
    position.
    Adapted from: https://gist.github.com/mattiaspaul/f4183f525b1cbc65e71ad23298d6436e

    :param image: 3-D array-like of shape (dim_1, dim_2, dim_3). Voxels equal
        to 0 are background; every other distinct value is one label.
    :return positions: float array of shape (num_labels, 3). Row k is the mean
        (axis-0, axis-1, axis-2) voxel index of the k-th label, with labels in
        ascending order. Shape is (0, 3) if the volume contains only zeros.
    :raises AssertionError: if ``image`` is not 3-D.
    """
    image = np.asarray(image)
    assert image.ndim == 3, f"expected a 3-D volume, got shape {image.shape}"

    # Filter the value 0 explicitly rather than dropping the first unique value:
    # a volume without background voxels (or containing negative values) would
    # otherwise lose a real label or treat background as a landmark.
    labels = np.unique(image)
    labels = labels[labels != 0]

    positions = np.zeros((len(labels), 3))
    for row, value in enumerate(labels):
        # argwhere lists the (i, j, k) index of every voxel with this label;
        # their mean is the unweighted centroid, with no full-volume meshgrids.
        positions[row] = np.argwhere(image == value).mean(axis=0)
    return positions

In [7]:
def calculate_mTRE(xyz_true, xyz_predict):
    """
    Mean target registration error between two corresponding point sets.

    :param xyz_true: array-like of shape (num_points, 3), or a single point of
        shape (3,), holding the reference coordinates.
    :param xyz_predict: array-like with the same shape as ``xyz_true`` holding
        the predicted coordinates of the same points in the same order.
    :return: mean Euclidean distance between corresponding points, in the units
        of the input coordinates.
    :raises AssertionError: if the two inputs differ in shape.
    """
    xyz_true = np.asarray(xyz_true, dtype=float)
    xyz_predict = np.asarray(xyz_predict, dtype=float)
    assert xyz_true.shape == xyz_predict.shape, (
        f"shape mismatch: {xyz_true.shape} vs {xyz_predict.shape}"
    )
    # Per-point Euclidean distance over the last (coordinate) axis; axis=-1 is
    # axis=1 for 2-D input and also accepts a single 1-D point.
    TRE = np.linalg.norm(xyz_true - xyz_predict, axis=-1)
    return np.mean(TRE)

In [5]:
def _landmark_world_point(path, affine, scale):
    """Load one label volume and return its landmark centroid(s) mapped through ``affine``."""
    # Round to integer label values (presumably because the warped prediction
    # is interpolated) so extract_centroid sees discrete labels.
    label_np = np.round(nib.load(path).get_fdata())
    # Rescale DeepReg voxel indices onto the original grid, then apply the
    # original voxel-to-world affine.
    return nib.affines.apply_affine(affine, extract_centroid(label_np) * scale)


def case_TREs(pred_dir, pair_number, num_labels, affine, scale):
    """
    Per-landmark TRE for one test pair of a DeepReg prediction folder.

    For each landmark index ``i`` in ``range(num_labels)``, the function compares
    two label files:
    ``<pred_dir>/pair_<pair_number>/label_<i>/fixed_label.nii.gz`` (reference)
    and ``.../pred_fixed_label.nii.gz`` (prediction). Each centroid is mapped
    back to the original image grid with ``scale`` and then through ``affine``,
    and the distance between the two points is recorded.

    :param pred_dir: DeepReg test-output directory; a trailing separator is optional.
    :param pair_number: index of the ``pair_<n>`` sub-folder to evaluate.
    :param num_labels: number of ``label_<i>`` sub-folders (landmarks) to evaluate.
    :param affine: 4x4 voxel-to-world affine of the original fixed image.
    :param scale: per-axis factor (original size / DeepReg size), shape (3,).
    :return: float array of shape (num_labels,) with the TRE of each landmark.
    """
    # os.path.join keeps paths valid whether or not pred_dir ends with "/".
    pair_dir = os.path.join(pred_dir, f"pair_{pair_number}")
    TREs = np.zeros(num_labels)
    for i in range(num_labels):
        label_dir = os.path.join(pair_dir, f"label_{i}")
        label_point = _landmark_world_point(
            os.path.join(label_dir, "fixed_label.nii.gz"), affine, scale
        )
        pred_point = _landmark_world_point(
            os.path.join(label_dir, "pred_fixed_label.nii.gz"), affine, scale
        )
        TREs[i] = calculate_mTRE(label_point, pred_point)
    return TREs

# Calculating the mTRE for the 3 test cases

In [6]:
# DeepReg test-output folder; case_TREs reads pair_<n>/label_<i>/ sub-folders from it.
prediction_dir: str = "logs/91_final_test/test/"

## Case 2

In [7]:
# Case 2 is test pair 0 with 15 landmarks.
case2_TREs = case_TREs(
    pred_dir=prediction_dir,
    pair_number=0,
    num_labels=15,
    affine=case2_affine,
    scale=case2_scale,
)
case2_TREs

array([5.23071247, 4.76445739, 4.63371352, 5.22994398, 4.54881859,
       5.13412461, 2.93244122, 5.42828216, 7.07171662, 4.91866433,
       4.29902994, 6.54738469, 6.8095317 , 8.37997244, 4.29391131])

In [8]:
# Case-level mTRE: the average of the per-landmark TREs.
case2_mTRE = case2_TREs.mean()
print(f"The mTRE for case 2 was {case2_mTRE}")

The mTRE for case 2 was 5.348180331814035


## Case 21

In [9]:
# Case 21 is test pair 1 with 16 landmarks.
case21_TREs = case_TREs(
    pred_dir=prediction_dir,
    pair_number=1,
    num_labels=16,
    affine=case21_affine,
    scale=case21_scale,
)
case21_TREs

array([6.09022382, 3.63618927, 4.7318172 , 3.60586097, 4.58234045,
       4.40075041, 4.12417652, 4.90767851, 4.66422887, 3.25541754,
       4.75885731, 4.23768257, 3.57989204, 3.82470047, 3.86659602,
       3.74109874])

In [10]:
# Case-level mTRE: the average of the per-landmark TREs.
case21_mTRE = case21_TREs.mean()
print(f"The mTRE for case 21 was {case21_mTRE}")

The mTRE for case 21 was 4.250469418760417


## Case 23

In [11]:
# Case 23 is test pair 2 with 15 landmarks.
case23_TREs = case_TREs(
    pred_dir=prediction_dir,
    pair_number=2,
    num_labels=15,
    affine=case23_affine,
    scale=case23_scale,
)
case23_TREs

array([5.89834658, 6.34542079, 7.60935995, 6.03107151, 7.04189091,
       7.65584598, 6.243868  , 7.84795153, 7.48142772, 7.44519124,
       5.05002922, 5.72686916, 5.24219633, 5.80966683, 7.79706751])

In [12]:
# Case-level mTRE: the average of the per-landmark TREs.
case23_mTRE = case23_TREs.mean()
print(f"The mTRE for case 23 was {case23_mTRE}")

The mTRE for case 23 was 6.615080216831815


In [11]:
# Reuse the configured prediction folder instead of repeating the literal path,
# so the manual check below always reads the same data as case_TREs above.
pred_dir = prediction_dir

In [12]:
# Inputs for a manual, single-landmark re-run of case_TREs (case 2, landmark 0).
pair_number, i = 0, 0
affine, scale = case2_affine, case2_scale

In [13]:
# One iteration of case_TREs, unrolled at module level for inspection.
# Reference landmark.
label = nib.load(pred_dir + f"pair_{pair_number}/label_{i}/fixed_label.nii.gz")
label_np = label.get_fdata()
label_point = nib.affines.apply_affine(affine, extract_centroid(np.round(label_np)) * scale)

# Predicted (warped) landmark.
pred_label = nib.load(pred_dir + f"pair_{pair_number}/label_{i}/pred_fixed_label.nii.gz")
pred_label_np = pred_label.get_fdata()
pred_point = nib.affines.apply_affine(affine, extract_centroid(np.round(pred_label_np)) * scale)

In [14]:
# TRE of the single landmark computed by hand above.
calculate_mTRE(xyz_true=label_point, xyz_predict=pred_point)

5.230712466074962

In [None]:
# Cross-check instead of a verbatim repeat of the manual cell above: the
# hand-computed TRE for this pair/landmark must equal what case_TREs returns
# for the same inputs.
manual_TRE = calculate_mTRE(label_point, pred_point)
batch_TRE = case_TREs(pred_dir, pair_number, i + 1, affine, scale)[i]
assert np.isclose(manual_TRE, batch_TRE), (manual_TRE, batch_TRE)
manual_TRE