In [61]:
import reg_mri
import os
from glob import glob
from utils import compute_mean_dice
import nibabel as nib
from scipy.spatial.distance import dice
import numpy as np
import itk
import SimpleITK as sitk
import scipy.ndimage
import scipy
import matplotlib.pyplot as plt
from transforms_dict import getRegistrationEvalInverseTransformForMRI, SaveTransformForMRI
from tqdm import tqdm
import monai
import subprocess
import torch
import torch.nn as nn

## To get landmarks + dices results, use reg_landmarks.ipynb

No registration
Labels: Dice: 0.8880 / Haussdorf: 10.6292 / IoU: 0.8002 / Sens: 0.8953 / Spec: 0.9818
Landmarks: Mean: 4.313 / Max: X.XXXX / Min: X.XXXX
Landmarks: Barycenter: Mean: 4.282

Affine
Labels: Dice: 0.9772 / Haussdorf: 8.7946 / IoU: 0.9299 / Sens: 0.9876 / Spec: 1.0205
Landmarks: Mean: 1.556 / Max: X.XXXX / Min: X.XXXX
Landmarks Barycenter: Mean: 1.556

SyN
Labels: Dice: 0.9884 / Haussdorf: 8.5734 / IoU: 0.9495 / Sens: 0.9970 / Spec: 1.0233
Landmarks: Mean: 1.179 / Max: X.XXXX / Min: X.XXXX
Landmarks Barycenter: Mean: 1.054

base-0.1
Loss: loss: 0.1401
Labels: Dice: 0.9709 / Haussdorf: 8.0573 / IoU: 0.9437 / Sens: 0.9671 / Spec: 0.9962
Landmarks: Mean: 1.3358 / Max: 2.3629 / Min: 0.6060
Landmarks Barycenter: Mean: 0.9387

affine-0.1
Loss: 0.1492
Labels: Dice: 0.9740 / Haussdorf: 7.4006 / IoU: 0.9494 / Sens: 0.9682 / Spec: 0.9970
Landmarks: Mean: 1.3451 / Max: 2.3593 / Min: 0.5665
Landmarks Barycenter: Mean: 0.9929

base-8.0
Loss: 0.1390
Labels: Dice: 0.9589 / Haussdorf: 8.4820 / IoU: 0.9215 / Sens: 0.9729 / Spec: 0.9914
Landmarks: Mean: 1.2353 / Max: 2.1447 / Min: 0.6253
Landmarks Barycenter: Mean: 1.0061

affine-8.0
Loss: 0.1339
Labels: Dice: 0.9591 / Haussdorf: 8.5658 / IoU: 0.9219 / Sens: 0.9748 / Spec: 0.9911
Landmarks: Mean: 1.2158 / Max: 2.1326 / Min: 0.6025
Landmarks Barycenter: Mean: 1.0335

In [64]:
outfolders = [
    "paper-old-0.1",
    "paper-affine-0.1",
    "paper-old-8.0",
    "paper-affine-8.0",
]

ddfs_truth = sorted(glob(os.path.join("output", "Feminad", "ANTS", "DeformableWarp", "*")))
log10jacs_truth = sorted(glob(os.path.join("output", "Feminad", "ANTS", "Log10Jacobian_Deformable", "*")))

print('-'*10)
for i, outfolder in enumerate(outfolders):
    megaoutfolder = 'output/Feminad/' + outfolder + '/'
    
    ddfs_pred = sorted(glob(os.path.join(megaoutfolder, "DeformableWarp", "*")))
    ddf_L1_distance = 0
    ddf_L2_distance = 0
    ddf_var = 0
    for j in range(len(ddfs_pred)):
        ddf_pred = nib.load(ddfs_pred[j]).get_fdata().reshape((1,3,128,128,128))
        ddf_truth = nib.load(ddfs_truth[j]).get_fdata().reshape((1,3,128,128,128))  
        pred = torch.from_numpy(ddf_pred)
        truth = torch.from_numpy(ddf_truth)     
        ddf_L1_distance += nn.L1Loss()(pred, truth)
        ddf_L2_distance += nn.MSELoss()(pred, truth)
        ddf_var += torch.var(pred)
    ddf_L1_distance /= len(ddfs_pred)
    ddf_L2_distance /= len(ddfs_pred)
    ddf_var /= len(ddfs_pred)
    
    log10jacs_pred = sorted(glob(os.path.join(megaoutfolder, "Log10Jacobian_Deformable", "*")))    
    log10jac_L1_distance = 0
    log10jac_L2_distance = 0
    log10jac_var = 0
    for j in range(len(log10jacs_pred)):
        log10jac_pred = nib.load(log10jacs_pred[j]).get_fdata().reshape((1,1,128,128,128))
        log10jac_truth = nib.load(log10jacs_truth[j]).get_fdata().reshape((1,1,128,128,128))
        pred = torch.from_numpy(log10jac_pred)
        truth = torch.from_numpy(log10jac_truth)  
        log10jac_L1_distance += nn.L1Loss()(pred, truth)
        log10jac_L2_distance += nn.MSELoss()(pred, truth)
        log10jac_var += torch.var(pred)
    log10jac_L1_distance /= len(log10jacs_pred)
    log10jac_L2_distance /= len(log10jacs_pred)
    log10jac_var /= len(log10jacs_pred)
    
    print(outfolder)
    print("DDF: L2Mean: {:.4f} / L1Mean: {:.4f} / Variance: {:.4f}".format(ddf_L2_distance.item(), ddf_L1_distance.item(), ddf_var.item()))
    print("Log10Jac: L2Mean: {:.4f} / L1Mean: {:.4f} / Variance:  {:.4f}".format(log10jac_L2_distance.item(), log10jac_L1_distance.item(), log10jac_var.item()))
    print('-'*10)

----------
paper-old-0.1
DDF: L2Mean: 0.8263 / L1Mean: 0.6330 / Variance: 0.1864
Log10Jac: L2Mean: 0.0387 / L1Mean: 0.1324 / Variance:  0.0215
----------
paper-affine-0.1
DDF: L2Mean: 2.9852 / L1Mean: 1.3324 / Variance: 2.4058
Log10Jac: L2Mean: 0.0517 / L1Mean: 0.1514 / Variance:  0.0359
----------
paper-old-8.0
DDF: L2Mean: 0.8054 / L1Mean: 0.6156 / Variance: 0.0995
Log10Jac: L2Mean: 0.0193 / L1Mean: 0.1052 / Variance:  0.0011
----------
paper-affine-8.0
DDF: L2Mean: 0.9940 / L1Mean: 0.7318 / Variance: 0.2316
Log10Jac: L2Mean: 0.0195 / L1Mean: 0.1059 / Variance:  0.0013
----------


paper-old-0.1
DDF: L2Mean: 0.8263 / L1Mean: 0.6330 / Variance: 0.1864
Log10Jac: L2Mean: 0.0387 / L1Mean: 0.1324 / Variance:  0.0215

paper-affine-0.1
DDF: L2Mean: 2.9852 / L1Mean: 1.3324 / Variance: 2.4058
Log10Jac: L2Mean: 0.0517 / L1Mean: 0.1514 / Variance:  0.0359

paper-old-8.0
DDF: L2Mean: 0.8054 / L1Mean: 0.6156 / Variance: 0.0995
Log10Jac: L2Mean: 0.0193 / L1Mean: 0.1052 / Variance:  0.0011

paper-affine-8.0
DDF: L2Mean: 0.9940 / L1Mean: 0.7318 / Variance: 0.2316
Log10Jac: L2Mean: 0.0195 / L1Mean: 0.1059 / Variance:  0.0013

## To get landmarks + dices results, use reg_landmarks.ipynb
No registration
Loss: X / Dice: 0.8880 / Haussdorf: 10.6292 / IoU: 0.8002 / Sens: 0.8953 / Spec: 0.9818 / LMean: 4.313 / LMax: X / LMin: X / LBarycenter: 4.282

Affine
Loss: X / Dice: 0.9772 / Haussdorf: 8.7946 / IoU: 0.9299 / Sens: 0.9876 / Spec: 1.0205 / LMean: 1.556 / LMax: X / LMin: X / LBarycenter: 1.556

SyN
Loss: X / Dice: 0.9884 / Haussdorf: 8.5734 / IoU: 0.9495 / Sens: 0.9970 / Spec: 1.0233 / LMean: 1.179 / LMax: X / LMin: X / LBarycenter: 1.054

base-0.1
Loss: 0.1401 / Dice: 0.9709 / Haussdorf: 8.0573 / IoU: 0.9437 / Sens: 0.9671 / Spec: 0.9962 / LMean: 1.3358 / LMax: 2.3629 / LMin: 0.6060 / LBarycenter: 0.9387 / DDFL2Mean: 0.8263 / DDFL1Mean: 0.6330 / DDFVar: 0.1864 / JacL2Mean: 0.0387 / JacL1Mean: 0.1324 / JacVar:  0.0215

affine-0.1
Loss: 0.1492 / Dice: 0.9740 / Haussdorf: 7.4006 / IoU: 0.9494 / Sens: 0.9682 / Spec: 0.9970 / LMean: 1.3451 / LMax: 2.3593 / LMin: 0.5665 / LBarycenter: 0.9929 / DDFL2Mean: 2.9852 / DDFL1Mean: 1.3324 / DDFVar: 2.4058 / JacL2Mean: 0.0517 / JacL1Mean: 0.1514 / JacVar: 0.0359

base-8.0
Loss: 0.1390 / Dice: 0.9589 / Haussdorf: 8.4820 / IoU: 0.9215 / Sens: 0.9729 / Spec: 0.9914 / LMean: 1.2353 / LMax: 2.1447 / LMin: 0.6253 / LBarycenter: 1.0061 / DDFL2Mean: 0.8054 / DDFL1Mean: 0.6156 / DDFVar: 0.0995 / JacL2Mean: 0.0193 / JacL1Mean: 0.1052 / JacVar:  0.0011

affine-8.0
Loss: 0.1339 / Dice: 0.9591 / Haussdorf: 8.5658 / IoU: 0.9219 / Sens: 0.9748 / Spec: 0.9911 / LMean: 1.2158 / LMax: 2.1326 / LMin: 0.6025 / LBarycenter: 1.0335 /DDFL2Mean: 0.9940 / DDFL1Mean: 0.7318 / DDFVariance: 0.2316 / JacL2Mean: 0.0195 / JacL1Mean: 0.1059 / JacVar:  0.0013

In [94]:
outfolders = [
    "paper-old-0.1",
    "paper-affine-0.1",
    "paper-old-8.0",
    "paper-affine-8.0",
]

jacs_truth = sorted(glob(os.path.join("output", "Feminad", "ANTS", "Jacobian_Deformable", "*")))
truth_var = 0
truth_folding = 0
for j in range(len(jacs_truth)):
    jac_truth = nib.load(jacs_truth[j]).get_fdata().reshape((1,1,128,128,128))  
    truth = torch.from_numpy(jac_truth)  
    truth_var += torch.var(truth)
    truth_folding += np.sum(np.array(truth) >= 0) / (128*128*128)
truth_var /= len(jacs_truth)
truth_folding /= len(jacs_truth)
print('ANTS')
print("Jac: Var: {:.4f} / Folding: {:.4f}".format(truth_var, truth_folding))
print('-'*10)
    
for i, outfolder in enumerate(outfolders):
    megaoutfolder = 'output/Feminad/' + outfolder + '/'    
    jacs_pred = sorted(glob(os.path.join(megaoutfolder, "Jacobian_Deformable", "*")))
    jacs_var = 0
    jacs_folding = 0
    for j in range(len(jacs_pred)):
        jac_pred = nib.load(jacs_pred[j]).get_fdata().reshape((1,1,128,128,128))
        pred = torch.from_numpy(jac_pred)   
        jacs_var += torch.var(pred)
        jacs_folding += np.sum(np.array(pred) >= 0) / (128*128*128)
    jacs_var /= len(jacs_pred)    
    jacs_folding /= len(jacs_pred)
    
    print(outfolder)
    print("Jac: Var: {:.4f} / Folding: {:.4f}".format(jacs_var, jacs_folding))
    print('-'*10)

ANTS
Jac: Var: 0.0197 / Folding: 1.0000
----------
paper-old-0.1
Jac: Var: 0.0173 / Folding: 1.0000
----------
paper-affine-0.1
Jac: Var: 0.0252 / Folding: 1.0000
----------
paper-old-8.0
Jac: Var: 0.0012 / Folding: 1.0000
----------
paper-affine-8.0
Jac: Var: 0.0014 / Folding: 1.0000
----------
