In [None]:
import torch
import monai
from pathlib import Path
from spatial_derivatives import JacobianOfDDF
from dti_warp import WarpDTI, TensorTransformType, PolarDecompositionMode, MseLossDTI, aoe_dti
import util
import shutil
import time
import ants

In [None]:
# The compute device is chosen in the model-setup cell (`device = 'cpu'`);
# a CUDA device set here was always overridden before use.

# Every volume is padded to this spatial size before batching.
spatial_size = (144,144,144)

data_dir = Path('./dti_fit_images_test/')
fa_dir = data_dir/'fa'
dti_dir = data_dir/'dti'
# One record per DTI volume, paired with the FA map that has the same file name.
# Sorted so record order does not depend on the filesystem's listing order.
data = [
    {'dti': str(path), 'fa': str(fa_dir/path.name), "filename": path.name}
    for path in sorted(dti_dir.glob('*'))
]

In [None]:
# Keys of each `data` record whose values are image file paths.
k = ['fa', 'dti']

# Load each image, put its channel axis first, pad it with a constant up to
# `spatial_size`, then convert it to a tensor.
_load_steps = [
    monai.transforms.LoadImageD(keys=k),
    monai.transforms.EnsureChannelFirstD(keys=k),
    monai.transforms.SpatialPadD(keys=k, spatial_size=spatial_size, mode="constant"),
    monai.transforms.ToTensorD(keys=k),
]
transform = monai.transforms.Compose(_load_steps)

In [None]:
device = 'cpu'

# Jacobian helper for displacement fields; compute_metrics counts voxels where
# its output is negative as folds (presumably the Jacobian determinant — confirm).
jac = JacobianOfDDF(device)

# Warps DTI volumes along a displacement field. Tensors are reoriented with the
# FINITE_STRAIN scheme, using HALLEY_DYNAMIC_WEIGHTS polar decomposition
# (num_iterations presumably controls that iteration — confirm in dti_warp).
warp_dti = WarpDTI(
    device=device,
    num_iterations=9,
    polar_decomposition_mode=PolarDecompositionMode.HALLEY_DYNAMIC_WEIGHTS,
    tensor_transform_type=TensorTransformType.FINITE_STRAIN,
)

# Nearest-neighbour warp for scalar images (currently unused below).
warp_scalar = monai.networks.blocks.Warp(mode='nearest')

# DTI mean-squared-error metric used by compute_metrics.
mse_dti = MseLossDTI(device)

In [None]:
# Where transformed samples are cached: "disk" (PersistentDataset) or "memory" (CacheDataset).
caching = "disk"
# Seed for the DataLoader shuffle, so the same image pair is drawn on every run.
SHUFFLE_SEED = 0

if caching == "disk":
    cache_dir = Path('./PersistentDatasetCacheDir')
    # Start from an empty cache on every run. PersistentDataset keys its entries
    # on the input records, so entries produced by an older `transform` could
    # otherwise be reused.
    if cache_dir.exists():
        shutil.rmtree(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    ds = monai.data.PersistentDataset(data, transform, cache_dir=cache_dir/'train')

elif caching == "memory":
    ds = monai.data.CacheDataset(data, transform)

else:
    # Without this branch `ds` would be undefined and the DataLoader line would
    # fail with an unrelated NameError.
    raise ValueError(f'caching must be "disk" or "memory", got {caching!r}')

dl = monai.data.DataLoader(
    ds,
    shuffle=True,
    batch_size=1,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [None]:
# Take the first two records from the shuffled loader. Record 1 is the fixed
# image and record 2 is the moving image (see the argument order of ants.registration).
if len(ds) < 2:
    # next() on an exhausted iterator would only raise a bare StopIteration.
    raise ValueError(f"Need at least two subjects to form a registration pair, found {len(ds)}")
it = iter(dl)
d1 = next(it)
d2 = next(it)

# Batched (batch_size=1), channel-first image tensors.
dti1 = d1['dti']
dti2 = d2['dti']
fa1 = d1['fa']
fa2 = d2['fa']

In [None]:
def get_nii_path(ants_transforms):
    """Return the first path in ``ants_transforms`` that has a ``.nii`` suffix.

    Raises:
        ValueError: if none of the paths is a NIfTI file.
    """
    nii_paths = [p for p in ants_transforms if '.nii' in Path(p).suffixes]
    if not nii_paths:
        raise ValueError(f"No NIfTI transform found among {ants_transforms!r}")
    return nii_paths[0]


def ants_model(fa1, fa2, dti1, dti2):
    """Register the moving FA ``fa2`` onto the fixed FA ``fa1`` with ANTs 'SyN'.

    ``dti1`` and ``dti2`` are not used. They are accepted so the signature matches
    the ``model(fa1, fa2, dti1, dti2)`` call in ``compute_metrics``.

    Returns:
        ddf: the NIfTI displacement field from ``fwdtransforms``, permuted with
            ``(3, 4, 0, 1, 2)``. This is presumably ITK's (X, Y, Z, 1, 3) layout
            turned into (1, 3, X, Y, Z); TODO confirm.
        fa2_warped: ANTs' warped moving FA image with two leading singleton axes added.
    """
    # [0, 0] drops the batch and channel axes of the loader tensors.
    fa1_ants = ants.from_numpy(fa1.numpy()[0, 0])
    fa2_ants = ants.from_numpy(fa2.numpy()[0, 0])
    ants_reg = ants.registration(fa1_ants, fa2_ants, type_of_transform='SyN')
    fa2_warped = torch.tensor(ants_reg['warpedmovout'].numpy())[None, None]

    # NOTE(review): fwdtransforms also contains a non-NIfTI transform (presumably
    # the affine stage of 'SyN'). Only the NIfTI warp is returned as `ddf`, so
    # warping with `ddf` may not match `fa2_warped`, which ANTs produced with the
    # full transform list. Confirm whether the affine should be composed in.
    fwdtransform_path = get_nii_path(ants_reg['fwdtransforms'])
    ddf = monai.transforms.LoadImage(image_only=True)(fwdtransform_path).permute((3, 4, 0, 1, 2))

    # ANTs writes its transforms to temporary files and never removes them.
    # Delete them once the warp is in memory so repeated calls don't fill the temp dir.
    for transform_path in {*ants_reg['fwdtransforms'], *ants_reg['invtransforms']}:
        Path(transform_path).unlink(missing_ok=True)

    return ddf, fa2_warped
    

def compute_metrics(dti1, dti2, fa1, fa2, model):
    """Register the moving pair (fa2/dti2) onto the fixed pair (fa1/dti1) and score the result.

    Args:
        dti1, fa1: fixed DTI and FA images (batched tensors from the loader).
        dti2, fa2: moving DTI and FA images.
        model: callable ``model(fa1, fa2, dti1, dti2) -> (ddf, fa2_warped)``
            returning a displacement field and the model's own warped moving FA.

    Returns:
        Tuple ``(fa_mse, fa_ncc, dti_mse, aoe, folds, t)``:
        - fa_mse: mean squared difference between ``fa2_warped`` and ``fa1``.
        - fa_ncc: negated ``util.ncc_loss(fa1, fa2_warped)``.
        - dti_mse: ``mse_dti`` between ``dti1`` and ``dti2`` warped by ``ddf``.
        - aoe: ``aoe_dti`` between ``dti1`` and the warped ``dti2``, with ``fa1``
          as the third argument.
        - folds: number of voxels where ``jac(ddf)`` is negative.
        - t: wall-clock seconds spent in ``model``.
    """
    # perf_counter is monotonic and high-resolution, so it suits measuring durations.
    start_time = time.perf_counter()
    ddf, fa2_warped = model(fa1, fa2, dti1, dti2)
    t = time.perf_counter() - start_time

    # The metrics need no gradients. The model call stays outside this context
    # so models that optimise internally with autograd still work.
    with torch.no_grad():
        fa_mse = ((fa2_warped - fa1) ** 2).mean().item()
        fa_ncc = -util.ncc_loss(fa1, fa2_warped).item()
        folds = (jac(ddf) < 0).sum().item()

        dti2_warped = warp_dti(dti2, ddf)

        dti_mse = mse_dti(dti1, dti2_warped).item()
        # Compare against the *registered* moving tensor, as dti_mse does above.
        # Passing the unwarped dti2 ignored `ddf` and measured the
        # pre-registration misalignment instead.
        aoe = aoe_dti(dti1, dti2_warped, fa1).item()

    return fa_mse, fa_ncc, dti_mse, aoe, folds, t

In [None]:
fa_mse, fa_ncc, dti_mse, aoe, folds, t = compute_metrics(dti1, dti2, fa1, fa2, ants_model)

# Print each metric on its own line as "<label> <value>".
for label, value in [
    ("FA mse:", fa_mse),
    ("FA ncc:", fa_ncc),
    ("dti mse:", dti_mse),
    ("dti AOE:", aoe),
    ("folds:", folds),
    ("time:", t),
]:
    print(label, value)