This notebook benchmarks several registration models on how well they align diffusion tensor images.

In [1]:
import torch
import monai
import numpy as np
import pandas as pd
from pathlib import Path
from spatial_derivatives import JacobianOfDDF
from dti_warp import WarpDTI, TensorTransformType, PolarDecompositionMode, MseLossDTI, aoe_dti
import util
import shutil
import time
import ants
import importlib.util
import sys
from collections import defaultdict, namedtuple
from customRandAffine import AffineAugmentationDTI

KeyboardInterrupt: 

In [None]:
# Device used by the neural registration models, and the spatial size that
# every volume is padded to by the preprocessing transform.
device = torch.device('cuda')
spatial_size = (144, 144, 144)

# Test-set layout: <data_dir>/dti/<name> and <data_dir>/fa/<name>, where the
# FA image of a subject carries the same file name as its DTI image.
data_dir = Path('./dti_fit_images_test/')
fa_dir = data_dir / 'fa'
dti_dir = data_dir / 'dti'

# One record per file found in the DTI directory; the FA path is derived from
# the DTI file name rather than discovered by a second directory listing.
data = [
    {
        'dti': str(dti_path),
        'fa': str(fa_dir / dti_path.name),
        "filename": dti_path.name,
    }
    for dti_path in dti_dir.glob('*')
]

In [None]:
# Keys of the sample dictionaries that hold image paths; every transform below
# is applied to both of them.
k = ['fa', 'dti']

# Preprocessing pipeline applied to each record of `data`, in this order:
# load the image files, make the arrays channel-first, pad the spatial
# dimensions towards `spatial_size` using "constant" padding, and convert the
# result to tensors.
transform = monai.transforms.Compose([
    monai.transforms.LoadImageD(keys=k),
    monai.transforms.EnsureChannelFirstD(keys=k),
    monai.transforms.SpatialPadD(keys=k, spatial_size=spatial_size, mode="constant"),
    monai.transforms.ToTensorD(keys=k),
])

In [None]:
# Device passed to the model factories and used for `.to(device)` in the model
# wrappers. NOTE(review): this rebinds `device` to a plain string, whereas an
# earlier cell binds it to `torch.device('cuda')` -- confirm which is intended.
device = 'cuda'

# The helpers used for metric computation are all constructed for the CPU;
# the model wrappers move their outputs back to the CPU before returning.
jac = JacobianOfDDF('cpu')

# Warps a DTI volume by a displacement field, using nearest-neighbour
# sampling and the finite-strain tensor transform configured here.
warp_dti = WarpDTI(
    device='cpu',
    mode='nearest',
    tensor_transform_type=TensorTransformType.FINITE_STRAIN,
    polar_decomposition_mode=PolarDecompositionMode.HALLEY_DYNAMIC_WEIGHTS,
    num_iterations = 9
)

# Warps a scalar image (the FA image) by a displacement field.
warp_scalar = monai.networks.blocks.Warp(mode='nearest')

mse_dti = MseLossDTI('cpu')

# Only referenced from a commented-out line in the evaluation loop, so it is
# currently unused. NOTE(review): the meaning of the 0.8 argument is not
# visible in this notebook -- check AffineAugmentationDTI before relying on it.
affine_aug = AffineAugmentationDTI(spatial_size, 0.8)

In [None]:
# Dataset caching strategy: "disk" caches preprocessed samples in a directory
# via PersistentDataset, "memory" keeps them in RAM via CacheDataset.
caching = "disk"

if caching == "disk":
    cache_dir = Path('./PersistentDatasetCacheDir')
    # Always start from an empty cache directory (presumably so that entries
    # written by an earlier run with different data are never picked up).
    if cache_dir.exists():
        shutil.rmtree(cache_dir)
    cache_dir.mkdir(exist_ok=True)

    ds = monai.data.PersistentDataset(data, transform, cache_dir=cache_dir/'train')

elif caching == "memory":
    ds = monai.data.CacheDataset(data, transform)

else:
    # Fail here with a clear message instead of a NameError on `ds` below.
    raise ValueError(
        f"Unknown caching mode {caching!r}; expected 'disk' or 'memory'"
    )

# Shuffled, one sample per batch: the evaluation loop draws two consecutive
# batches at a time and treats them as an image pair.
dl = monai.data.DataLoader(ds, shuffle=True, batch_size=1)

In [None]:
def import_module_from_path(module_path, module_name):
    """Load a Python source file as a module under an explicit name.

    The module is registered in ``sys.modules`` and bound as a global of this
    notebook/module under ``module_name``, so it can be used afterwards as if
    it had been imported normally. This makes it possible to load files whose
    names are not valid Python identifiers.

    Args:
        module_path: Path of the ``.py`` file to execute.
        module_name: Name to register the module under.

    Returns:
        The loaded module object.

    Raises:
        ImportError: If no import spec/loader can be created for the path.
        Exception: Whatever the module's own top-level code raises; in that
            case the partially initialised module is unregistered again.
    """
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for {module_name!r} from {module_path!r}"
        )
    module = importlib.util.module_from_spec(spec)

    # Register before executing, so that code in the file which looks itself
    # up by name during import finds the module.
    globals()[module_name] = module
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind under this name.
        sys.modules.pop(module_name, None)
        globals().pop(module_name, None)
        raise
    return module
    
# Model definition files to benchmark, paired with the module name each one is
# registered under. The file names contain dashes, so they cannot be loaded
# with a plain `import` statement.
for _model_file, _module_name in (
    ("models_to_benchmark/2022-09-13-deformable-2be0f3bd.py", 'module_2be0f3bd'),
    ("models_to_benchmark/dti-2022-10-23a-e11e483.py", 'module_e11e483'),
    ("models_to_benchmark/dti-2022-10-24a-10da6f4.py", 'module_10da6f4'),
    ("models_to_benchmark/dti-2022-10-20b-e17e67307.py", 'module_e17e67307'),
    ("models_to_benchmark/dti-2022-10-20b-0a3ef0d334f.py", 'module_0a3ef0d334f'),
):
    import_module_from_path(_model_file, _module_name)


In [None]:
def noop_model(fa1, fa2, dti1, dti2):
    """Baseline "model" that performs no registration at all.

    Args:
        fa1: Target FA image; only its 5-D shape is used.
        fa2: Moving FA image, returned unchanged as the "warped" image.
        dti1: Target DTI; only its dtype is used (for the displacement field).
        dti2: Unused; present so all models share one call signature.

    Returns:
        Tuple ``(ddf, fa2_warped, t)``: an all-zero displacement field with 3
        channels and the batch/spatial sizes of ``fa1``, the untouched ``fa2``,
        and an elapsed time of 0.
    """
    batch, _, size_x, size_y, size_z = fa1.shape
    zero_ddf = torch.zeros((batch, 3, size_x, size_y, size_z), dtype=dti1.dtype)
    return zero_ddf, fa2, 0

In [None]:
def get_nii_path(ants_transforms):
    """Return the first path in `ants_transforms` whose suffixes include '.nii'.

    Raises IndexError when no such path is present.
    """
    nii_candidates = [
        candidate for candidate in ants_transforms
        if '.nii' in Path(candidate).suffixes
    ]
    return nii_candidates[0]

def ants_model(fa1, fa2, dti1, dti2):
    """Register fa2 onto fa1 with ANTs SyN, driven by the FA images only.

    Args:
        fa1: Target FA image; element [0, 0] is handed to ANTs as fixed image.
        fa2: Moving FA image; element [0, 0] is handed to ANTs as moving image.
        dti1, dti2: Unused; present so all models share one call signature.

    Returns:
        Tuple ``(ddf, fa2_warped, t)``: the displacement field loaded from the
        forward transform file, the warped moving image with two leading
        singleton dimensions added, and the wall-clock seconds spent inside
        ``ants.registration``.
    """
    fixed_image = ants.from_numpy(fa1.cpu().numpy()[0,0])
    moving_image = ants.from_numpy(fa2.cpu().numpy()[0,0])

    # Only the registration call itself is timed, not the conversions.
    tic = time.time()
    registration = ants.registration(fixed_image, moving_image, type_of_transform='SyN')
    t = time.time() - tic

    warped_array = registration['warpedmovout'].numpy()
    fa2_warped = torch.tensor(warped_array)[None, None]

    # NOTE(review): only the '.nii' entry of the forward transforms is loaded;
    # if that list also contains an affine component, it is not reflected in
    # the returned ddf -- confirm this is intended.
    warp_file = get_nii_path(registration['fwdtransforms'])
    image_loader = monai.transforms.LoadImage(image_only=True)
    # Move the last two axes of the loaded volume to the front.
    ddf = image_loader(warp_file).permute((3,4,0,1,2))
    return ddf, fa2_warped, t

In [None]:
net_2be0f3bd = module_2be0f3bd.create_model(device)
def model_2be0f3bd(fa1, fa2, dti1, dti2):
    """Register fa2 onto fa1 with the 2be0f3bd network and time the inference.

    Args:
        fa1: Target FA image.
        fa2: Moving FA image.
        dti1, dti2: Unused; present so all models share one call signature.

    Returns:
        Tuple ``(ddf, fa2_warped, t)`` with both tensors moved to the CPU and
        ``t`` the elapsed inference time in seconds (host-to-device transfer
        of the inputs and device-to-host transfer of the outputs excluded).
    """
    fa1_d, fa2_d = fa1.to(device), fa2.to(device)
    # CUDA work is launched asynchronously, so without synchronising the
    # timer could stop before the GPU has finished (or include pending
    # transfers). Synchronise on both sides of the timed region.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    ddf,fa2_warped = net_2be0f3bd.forward_inference(fa1_d, fa2_d)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter() - start_time
    return ddf.cpu(), fa2_warped.cpu(), t

In [None]:
net_e11e483 = module_e11e483.create_model(device)
def model_e11e483(fa1, fa2, dti1, dti2):
    """Register with the e11e483 network, which takes the two DTIs as input.

    Args:
        fa1: Unused; present so all models share one call signature.
        fa2: Moving FA image, warped with the predicted displacement field.
        dti1: Target DTI passed to the network.
        dti2: Moving DTI passed to the network.

    Returns:
        Tuple ``(ddf, fa2_warped, t)`` with both tensors moved to the CPU and
        ``t`` the elapsed seconds for the network call plus the FA warp.
    """
    dti1_d, dti2_d = dti1.to(device), dti2.to(device)
    fa2_d = fa2.to(device)
    # CUDA work is launched asynchronously, so without synchronising the
    # timer could stop before the GPU has finished (or include pending
    # transfers). Synchronise on both sides of the timed region.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    ddf = net_e11e483(dti1_d, dti2_d, return_warp_only=True)
    fa2_warped = warp_scalar(fa2_d, ddf)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter() - start_time
    return ddf.cpu(), fa2_warped.cpu(), t

In [None]:
net_10da6f4 = module_10da6f4.create_model(device)
def model_10da6f4(fa1, fa2, dti1, dti2):
    """Register with the 10da6f4 network, which takes the two DTIs as input.

    Args:
        fa1: Unused; present so all models share one call signature.
        fa2: Moving FA image, warped with the predicted displacement field.
        dti1: Target DTI passed to the network.
        dti2: Moving DTI passed to the network.

    Returns:
        Tuple ``(ddf, fa2_warped, t)`` with both tensors moved to the CPU and
        ``t`` the elapsed seconds for the network call plus the FA warp.
    """
    dti1_d, dti2_d = dti1.to(device), dti2.to(device)
    fa2_d = fa2.to(device)
    # CUDA work is launched asynchronously, so without synchronising the
    # timer could stop before the GPU has finished (or include pending
    # transfers). Synchronise on both sides of the timed region.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    ddf = net_10da6f4(dti1_d, dti2_d, return_warp_only=True)
    fa2_warped = warp_scalar(fa2_d, ddf)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter() - start_time
    return ddf.cpu(), fa2_warped.cpu(), t

In [None]:
net_e17e67307 = module_e17e67307.create_model(device)
def model_e17e67307(fa1, fa2, dti1, dti2):
    """Register with the e17e67307 network, which takes the two DTIs as input.

    Args:
        fa1: Unused; present so all models share one call signature.
        fa2: Moving FA image, warped with the predicted displacement field.
        dti1: Target DTI passed to the network.
        dti2: Moving DTI passed to the network.

    Returns:
        Tuple ``(ddf, fa2_warped, t)`` with both tensors moved to the CPU and
        ``t`` the elapsed seconds for the network call plus the FA warp.
    """
    dti1_d, dti2_d = dti1.to(device), dti2.to(device)
    fa2_d = fa2.to(device)
    # CUDA work is launched asynchronously, so without synchronising the
    # timer could stop before the GPU has finished (or include pending
    # transfers). Synchronise on both sides of the timed region.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    ddf = net_e17e67307(dti1_d, dti2_d, return_warp_only=True)
    fa2_warped = warp_scalar(fa2_d, ddf)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter() - start_time
    return ddf.cpu(), fa2_warped.cpu(), t

In [None]:
net_0a3ef0d334f = module_0a3ef0d334f.create_model(device)
def model_0a3ef0d334f(fa1, fa2, dti1, dti2):
    """Register with the 0a3ef0d334f network, which takes the two DTIs as input.

    Args:
        fa1: Unused; present so all models share one call signature.
        fa2: Moving FA image, warped with the predicted displacement field.
        dti1: Target DTI passed to the network.
        dti2: Moving DTI passed to the network.

    Returns:
        Tuple ``(ddf, fa2_warped, t)`` with both tensors moved to the CPU and
        ``t`` the elapsed seconds for the network call plus the FA warp.
    """
    dti1_d, dti2_d = dti1.to(device), dti2.to(device)
    fa2_d = fa2.to(device)
    # CUDA work is launched asynchronously, so without synchronising the
    # timer could stop before the GPU has finished (or include pending
    # transfers). Synchronise on both sides of the timed region.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    ddf = net_0a3ef0d334f(dti1_d, dti2_d, return_warp_only=True)
    fa2_warped = warp_scalar(fa2_d, ddf)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter() - start_time
    return ddf.cpu(), fa2_warped.cpu(), t

In [None]:
# Record of everything measured for one model on one image pair.
Metrics = namedtuple(
    "Metrics",
    ["fa_mse", "fa_ncc", "dti_mse", "weighted_dti_mse", "aoe", "folds", "t"],
)

def compute_metrics(dti1, dti2, fa1, fa2, model):
    """Run `model` on one image pair and score the resulting alignment.

    Args:
        dti1, fa1: Target DTI and FA images.
        dti2, fa2: Moving DTI and FA images.
        model: Callable ``model(fa1, fa2, dti1, dti2)`` returning
            ``(ddf, fa2_warped, t)``.

    Returns:
        A ``Metrics`` tuple of Python scalars: FA mean squared error, negated
        ``util.ncc_loss`` on the FA images, DTI MSE, DTI MSE weighted by
        ``fa1*0.75+0.25``, the ``aoe_dti`` value, the number of voxels where
        the Jacobian computed by ``jac`` is negative, and the time ``t``
        reported by the model.
    """
    ddf, fa2_warped, t = model(fa1, fa2, dti1, dti2)

    # Scalar-image metrics use the warped FA image returned by the model.
    fa_residual = fa2_warped - fa1
    fa_mse = (fa_residual**2).mean().item()
    fa_ncc = -util.ncc_loss(fa1, fa2_warped).item()

    # Count negative entries of the Jacobian output for this field.
    negative_jacobian = jac(ddf) < 0
    folds = negative_jacobian.sum().item()

    # Tensor metrics: the moving DTI is warped here with the model's field,
    # independently of how the model produced its warped FA image.
    dti2_warped = warp_dti(dti2, ddf)
    dti_mse = mse_dti(dti1, dti2_warped).item()
    fa_weighting = fa1*0.75+0.25
    weighted_dti_mse = mse_dti(dti1, dti2_warped, weighting=fa_weighting).item()
    aoe = aoe_dti(dti1, dti2_warped, fa1).item()

    return Metrics(
        fa_mse=fa_mse,
        fa_ncc=fa_ncc,
        dti_mse=dti_mse,
        weighted_dti_mse=weighted_dti_mse,
        aoe=aoe,
        folds=folds,
        t=t,
    )

In [None]:
# Models to evaluate, keyed by the label used in the result tables.
models = {
    "no-op": noop_model,
    "ants": ants_model,
    "fa-driven": model_2be0f3bd,
    "dti-driven-no-affaug-less-reg": model_e17e67307,
    "dti-driven-no-affaug": model_e11e483,
    "dti-driven-affaug": model_10da6f4,
    "dti-driven-L2-noaff-fullFAwt": model_0a3ef0d334f,
}

# Number of passes to make over the test data
num_passes = 1

# model label -> metric name -> list with one value per evaluated image pair
model_metrics = {model_key: defaultdict(list) for model_key in models}

for i in range(num_passes):
    print(f"pass {i+1}/{num_passes}")

    # Zipping one iterator with itself yields consecutive, non-overlapping
    # pairs of batches; a trailing unpaired batch is dropped.
    dl_iter = iter(dl)
    for j, (d1, d2) in enumerate(zip(dl_iter, dl_iter), start=1):
        print(f"\timg pair {j}/{len(ds)//2}")

        dti1, dti2 = d1['dti'], d2['dti']
        fa1, fa2 = d1['fa'], d2['fa']

        # Affine augmentation of the pair is currently disabled:
        # fa1, fa2, dti1, dti2 = affine_aug(fa1, fa2, dti1, dti2)

        print('\t\t',end='')
        for model_key, model in models.items():
            print(model_key[:6],'...',end='')
            metrics = compute_metrics(dti1, dti2, fa1, fa2, model)
            for metric_name, metric_value in zip(metrics._fields, metrics):
                model_metrics[model_key][metric_name].append(metric_value)
        print()
            
        
        

In [None]:
# Collapse the per-pair metric lists into one mean and one median per model
# and metric, keeping the column order of the Metrics tuple.
metric_names = list(Metrics._fields)
df_dict_means = defaultdict(list)
df_dict_medians = defaultdict(list)

for model_key, per_model_values in model_metrics.items():
    samples_by_metric = [per_model_values[name] for name in metric_names]
    df_dict_means[model_key].extend(np.mean(samples) for samples in samples_by_metric)
    df_dict_medians[model_key].extend(np.median(samples) for samples in samples_by_metric)

# One row per model, one column per metric.
table_layout = dict(orient='index', columns=metric_names)
df_means = pd.DataFrame.from_dict(df_dict_means, **table_layout)
df_medians = pd.DataFrame.from_dict(df_dict_medians, **table_layout)

In [None]:
# Show the table of per-model metric means (rows: models, columns: metrics);
# the bare expression on the last line is rendered as the cell output.
print("Means of metrics:")
df_means

In [None]:
# Show the table of per-model metric medians (rows: models, columns: metrics);
# the bare expression on the last line is rendered as the cell output.
print("Medians of metrics:")
df_medians

In [None]:
# Persist both summary tables as CSV. The output directory is created first,
# because to_csv does not create missing parent directories and would fail
# on a fresh checkout.
output_dir = Path('evaluation_tables')
output_dir.mkdir(parents=True, exist_ok=True)
df_means.to_csv(output_dir/'evaluation_means.csv')
df_medians.to_csv(output_dir/'evaluation_medians.csv')