In [3]:
import nibabel as nib
import numpy as np

In [96]:
# Hard-coded 4x4 homogeneous affine matrices (last row is [0, 0, 0, 1]).
# Presumably the voxel-to-world affines of the original scans for each
# numbered case — TODO confirm their provenance.
case1_affine = np.array([[ 3.90691876e-01,  1.30452037e-01,  2.81675309e-01, -1.96299267e+01],
       [ 1.55749843e-02,  4.42654282e-01, -2.28414789e-01, -5.37206192e+01],
       [-3.09405476e-01,  1.87006667e-01,  3.44178468e-01, 2.05047684e+01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

case2_affine = np.array([[ 3.33370864e-01,  1.59845248e-01,  3.33878189e-01, -3.39173775e+01],
       [-1.82626724e-01,  4.61422384e-01, -3.86130475e-02, 4.38701019e+01],
       [-3.21150362e-01, -9.64666754e-02,  3.68540883e-01, 6.26287537e+01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

# Affine used by the evaluation cells below. It was previously a second,
# hand-copied literal identical to case2_affine; derive it instead so the two
# cannot drift apart. .copy() keeps it an independent array, as before.
tester = case2_affine.copy()


case12_affine = np.array([[  0.41881981,   0.07436115,   0.25797391, -57.30677795],
       [  0.06075069,   0.43993044,  -0.22682622, -23.42017174],
       [ -0.26014873,   0.22244968,   0.36234942, -19.91741753],
       [  0.        ,   0.        ,   0.        ,   1.        ]])

case13_affine = np.array([[ 4.08450842e-01, -1.18215449e-01,  2.58995205e-01, -2.38458767e+01],
       [ 1.57638997e-01,  4.71258372e-01, -3.36058326e-02, -7.20230942e+01],
       [-2.37171665e-01,  1.09639972e-01,  4.23698217e-01, -5.00180817e+01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

case21_affine = np.array([[4.00647670e-01, -1.04107566e-01,  2.77501583e-01, -3.58118629e+01],
       [ 1.67786613e-01,  4.64567333e-01, -6.78447485e-02, -8.16707916e+01],
       [-2.44675279e-01,  1.48105368e-01,  4.07874972e-01, -8.83094406e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, 1.00000000e+00]])

case23_affine = np.array([[  0.31273046,   0.25455737,   0.29349992, -41.70222473],
       [ -0.28844342,   0.40450525,  -0.04210297,   2.00646901],
       [ -0.2584393 ,  -0.14343421,   0.40214738,   4.35490799],
       [  0.        ,   0.        ,   0.        ,   1.        ]])

In [5]:
def extract_centroid(image, background=0):
    """
    Extract the centroid (mean voxel index) of every labelled region in a
    3-D landmark image whose landmark spheres carry integer label values.
    Adapted from: https://gist.github.com/mattiaspaul/f4183f525b1cbc65e71ad23298d6436e

    :param image:
        - 3-D array-like of shape (dim_1, dim_2, dim_3) (numpy array, or
          anything np.asarray accepts)
        - label values are compared exactly, so round interpolated
          (float) label volumes before calling
    :param background:
        - label value to ignore (default 0)

    :return positions:
        - float numpy array of shape (num_labels, 3); row i holds the mean
          (axis 0, axis 1, axis 2) voxel index of the i-th non-background
          label, labels taken in ascending order. Shape (0, 3) if the image
          contains only background.

    :raises ValueError: if ``image`` is not 3-D.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        raise ValueError(f"expected a 3-D image, got shape {image.shape}")

    labels = np.unique(image)  # sorted ascending
    # Remove only the background value. Slicing off the first unique value
    # would discard the smallest real label whenever no background voxel exists.
    labels = labels[labels != background]

    positions = np.zeros((len(labels), 3))
    for i, lab in enumerate(labels):
        # argwhere lists the (axis0, axis1, axis2) index of every voxel with
        # this label; their mean is the unweighted centre of mass.
        positions[i] = np.argwhere(image == lab).mean(axis=0)
    return positions

In [59]:
def calculate_mTRE(xyz_true, xyz_predict):
    """
    Mean target registration error between corresponding point sets.

    :param xyz_true: array-like of shape (N, D), or (D,) for a single point,
        holding the reference coordinates.
    :param xyz_predict: array-like of exactly the same shape holding the
        predicted coordinates.
    :return: mean Euclidean distance between corresponding points, in the
        units of the inputs.
    :raises ValueError: if the two inputs have different shapes.
    """
    xyz_true = np.asarray(xyz_true, dtype=float)
    xyz_predict = np.asarray(xyz_predict, dtype=float)
    # Require identical shapes rather than letting broadcasting pair up
    # points incorrectly.
    if xyz_true.shape != xyz_predict.shape:
        raise ValueError(
            f"shape mismatch: {xyz_true.shape} vs {xyz_predict.shape}"
        )
    # Per-point Euclidean distance over the last (coordinate) axis, so a
    # single 1-D point works as well as an (N, D) array.
    tre = np.linalg.norm(xyz_true - xyz_predict, axis=-1)
    return np.mean(tre)

In [31]:
# Test-output directory of the registration run being evaluated. The folder
# name presumably encodes the run timestamp (YYYYMMDD-HHMMSS) — confirm.
# Must end with "/" because the paths below are built by plain string
# concatenation (prediction_dir + "pair_0/...").
prediction_dir = "logs/20210413-124713/test/"

In [105]:
# Landmark label volumes for test pair 0: `label` is the fixed (reference)
# image and `pred_label` the image it is compared against.
# NOTE(review): confirm "moving_label" is the registration output and not the
# unwarped moving label; if it is the latter, the TRE computed later measures
# pre-registration misalignment.
label = nib.load(prediction_dir + "pair_0/label_0/fixed_label.nii.gz")
pred_label = nib.load(prediction_dir + "pair_0/label_0/moving_label.nii.gz")

# Voxel data and stored image affine of each volume.
label_np, label_affine = label.get_fdata(), label.affine
pred_label_np, pred_label_affine = pred_label.get_fdata(), pred_label.affine

In [106]:
# Centroid(s) of each rounded label volume, mapped from voxel indices through
# the hard-coded `tester` affine.
# NOTE(review): the affines read from the files (label_affine,
# pred_label_affine) are not used here — confirm `tester` is the correct
# affine for pair 0.
label_point, pred_point = (
    nib.affines.apply_affine(tester, extract_centroid(np.round(volume)))
    for volume in (label_np, pred_label_np)
)

In [109]:
# Exploratory check: index arrays (one per axis) of the zero-valued voxels.
np.nonzero(label_np == 0)

(array([  0,   0,   0, ..., 150, 150, 150]),
 array([  0,   0,   0, ..., 135, 135, 135]),
 array([  0,   1,   2, ..., 116, 117, 118]))

In [110]:
# Dimensions of the fixed label volume.
np.shape(label_np)

(151, 136, 119)

In [107]:
# Fixed-label centroid(s) after applying `tester`, one row per label.
label_point

array([[19.05989685, 53.55904318, 61.93159373]])

In [108]:
# Compared-label centroid(s) after applying `tester`, one row per label.
pred_point

array([[22.17915171, 51.24169973, 65.75868202]])

In [103]:
# NOTE(review): repeats an earlier `label_point` display cell; remove so each
# result is shown once.
label_point

array([[19.05989685, 53.55904318, 61.93159373]])

In [104]:
# NOTE(review): repeats an earlier `pred_point` display cell; remove so each
# result is shown once.
pred_point

array([[22.17915171, 51.24169973, 65.75868202]])

In [95]:
# Mean target registration error between the fixed and compared landmark
# centroids, in the output units of the `tester` affine.
# NOTE(review): the saved output (2.3968...) looks stale — for the points
# displayed in this notebook the single-point distance is about 5.45.
# Re-run after the cells that define label_point / pred_point.
calculate_mTRE(label_point, pred_point)

2.3968798792285693

In [29]:
# NOTE(review): repeats an earlier `label_point` display cell; remove so each
# result is shown once.
label_point

array([[19.05989685, 53.55904318, 61.93159373]])

In [30]:
# NOTE(review): repeats an earlier `pred_point` display cell; remove so each
# result is shown once.
pred_point

array([[22.17915171, 51.24169973, 65.75868202]])