In [None]:
%load_ext autoreload
%autoreload 2

#
# Python module import.
#

import sys
sys.path.append("../")
import math
import copy
import pickle
import time

import torch
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

import imodal

# Compute configuration: run on the CPU, in single precision, with imodal's
# KeOps kernel backend.
device = "cpu"
torch.set_default_dtype(torch.float32)
imodal.Utilities.set_compute_backend("keops")

In [None]:
# Amygdala ROI masks of one IMAGEN subject at baseline (bas) and at the two
# follow-ups (fu2, fu3). The subject id and directory are written once so the
# three paths cannot drift apart and switching subject is a one-line change.
SUBJECT_ID = "27113512"
REGIONS_DIR = "../../../data/imagen/{}/regions".format(SUBJECT_ID)


def load_roi(timepoint, subject_id=SUBJECT_ID, regions_dir=REGIONS_DIR):
    """Load the amygdala ROI of `subject_id` at `timepoint` ("bas", "fu2", "fu3").

    Returns the imodal.Models.Deformable3DImage read from
    `<regions_dir>/nat_amyg_<12-digit id>_<timepoint>_roi.nii.gz`.
    """
    # File names carry the subject id zero-padded to 12 digits.
    filename = "nat_amyg_{}_{}_roi.nii.gz".format(subject_id.zfill(12), timepoint)
    return imodal.Models.Deformable3DImage.load_from_file("{}/{}".format(regions_dir, filename))


bas = load_roi("bas")
fu2 = load_roi("fu2")
fu3 = load_roi("fu3")

In [None]:
def apply_rigid_deformation(image, offset, angles):
    """Apply a rigid transform (translation `offset`, rotation `angles`) to `image`.

    Delegates to imodal.Utilities.apply_rigid_deformation, called with the same
    (image, offset, angles) argument order the registration cell uses. Nothing
    is returned; the image is presumably modified in place (the registration
    cell ignores the return value and saves the image right after) — confirm.
    NOTE(review): the unit of `angles` is whatever imodal expects — confirm.
    """
    # The former body built the transform via rigid_deformation3d and discarded
    # it (the applying call was commented out), so the function was a no-op.
    imodal.Utilities.apply_rigid_deformation(image, offset, angles)

In [None]:
# Rigid pre-alignment of the follow-ups onto the baseline: starting the LDDMM
# fit closer to its target makes it converge faster.
# Each transform is a translation offset plus three rotation angles.
fu2_offset = [6.9561116e+01, 3.0417949e+01, -5.6975957e+01]
# NOTE(review): fu2's angles are used as written (the degree -> radian
# conversion is disabled for them) while fu3's are converted below — confirm
# which unit each set was exported in.
fu2_angles = [1.7219040e+00, 4.5891569e-02, -1.6196764e+00]

fu3_offset = [1.0997878e+00, -1.2219526e+01, 6.5309704e+00]
fu3_angles = [
    angle / 180. * math.pi  # degrees -> radians
    for angle in (9.6657110e-02, 1.8558108e-02, 4.6080920e-02)
]

# NOTE(review): the return value is ignored and the images are saved right
# after, so apply_rigid_deformation presumably modifies them in place —
# re-running this cell without reloading the images would apply it twice.
imodal.Utilities.apply_rigid_deformation(fu2, fu2_offset, fu2_angles)
imodal.Utilities.apply_rigid_deformation(fu3, fu3_offset, fu3_angles)

fu2.save_to_file("fu2_registered.nii.gz")
fu3.save_to_file("fu3_registered.nii.gz")

In [None]:
# Baseline ROI voxels (bitmap > 0), presumably mapped to world coordinates by
# the image affine.
bas_points = imodal.Utilities.apply_linear_transform_3d(
    imodal.Utilities.mask_to_indices(bas.bitmap > 0.),
    bas.affine,
)
# NOTE(review): this box around all points is not used by any later cell.
aabb = imodal.Utilities.AABB.build_from_points(bas_points)

# Split the points on the sign of the first coordinate so each amygdala gets
# its own tight bounding box.
# NOTE(review): assumes the affine puts the mid-sagittal plane at x = 0 — confirm.
bas_points_left = bas_points[bas_points[:, 0] < 0.]
bas_points_right = bas_points[bas_points[:, 0] >= 0.]
aabb_left = imodal.Utilities.AABB.build_from_points(bas_points_left)
aabb_right = imodal.Utilities.AABB.build_from_points(bas_points_right)

print(aabb_left.shape, aabb_right.shape, sep="\n")

In [None]:
# LDDMM control points: fill each amygdala's bounding box at a uniform density
# tied to the kernel width.
lddmm_sigma = 5.
# NOTE(review): 1/sigma**2 points per unit volume; in 3D one point per
# sigma-sized cube would be 1/sigma**3 — confirm the intended spacing.
lddmm_points_density = 1. / lddmm_sigma ** 2
print(lddmm_points_density)

lddmm_points_left = aabb_left.fill_uniform_density(lddmm_points_density)
lddmm_points_right = aabb_right.fill_uniform_density(lddmm_points_density)
lddmm_points = torch.cat((lddmm_points_left, lddmm_points_right))

print(f"LDDMM control points={lddmm_points.shape[0]}")

In [None]:
# Expected number of control points (density x box volume) for the left box,
# the right box and both together — presumably to compare with the actual
# count printed above.
print(
    lddmm_points_density * aabb_left.volume,
    lddmm_points_density * aabb_right.volume,
    lddmm_points_density * aabb_left.volume + lddmm_points_density * aabb_right.volume,
    sep="\n",
)

In [None]:
# Implicit order-0 deformation module carried by the LDDMM control points.
# NOTE(review): positional args presumably read (dimension, number of points,
# sigma); nu=0.1 looks like a regularisation weight — confirm in imodal's docs.
lddmm = imodal.DeformationModules.ImplicitModule0(
    3,
    lddmm_points.shape[0],
    lddmm_sigma,
    nu=0.1,
    gd=lddmm_points,
)

In [None]:
# NOTE(review): this module is built but never used — the RegistrationModel
# below only receives [lddmm]. Pass it there if a global translation is
# wanted, or drop this cell.
global_translation = imodal.DeformationModules.GlobalTranslation(
    3,  # presumably the spatial dimension
)

In [None]:
# Registration model deforming a deep copy of the baseline (so `bas` itself is
# left untouched) with the LDDMM module, scored by an L2-norm attachment.
model = imodal.Models.RegistrationModel(
    copy.deepcopy(bas),
    [lddmm],
    imodal.Attachment.L2NormAttachment(),
)

In [None]:
# Move the fitting target and the model (which owns its own copy of the
# baseline) to `device`; `bas` itself is not moved.
fu3.to_device(device)
model.to_device(device)

In [None]:
# Shooting settings: Euler integrator with `shoot_it` steps (presumably the
# number of time steps).
shoot_solver = "euler"
shoot_it = 10
# Passed to Fitter.fit below; presumably filled with the cost history.
costs = {}
fitter = imodal.Models.Fitter(model, optimizer="torch_lbfgs")


In [None]:
# Fit the baseline onto fu3. 100 is presumably the iteration budget;
# 'line_search_fn' mirrors torch.optim.LBFGS's option of the same name.
fitter.fit(
    [fu3],
    100,
    costs=costs,
    options={
        'shoot_solver': shoot_solver,
        'shoot_it': shoot_it,
        'line_search_fn': 'strong_wolfe',
    },
)


In [None]:
# Shoot the fitted model without gradient tracking and time the evaluation.
# NOTE(review): [0][0] presumably selects the bitmap of the first deformed
# image — confirm against compute_deformed's return structure.
with torch.no_grad():
    intermediates = {}
    start = time.perf_counter()
    deformed_bitmap = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)[0][0]
    print(f"Elapsed time={time.perf_counter() - start} sec")

In [None]:
%matplotlib qt5
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(fu3.bitmap[:, :, 59])
plt.subplot(1, 2, 2)
plt.imshow(deformed_bitmap[:, :, 59])
plt.show()

In [None]:
%matplotlib qt5
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(fu3.bitmap[:, 22, :])
plt.subplot(1, 2, 2)
plt.imshow(deformed_bitmap[:, 22, :])
plt.show()

In [None]:
# Save the deformed baseline as NIfTI with the baseline's affine (presumably
# the deformed bitmap stays on the baseline's voxel grid). `nib` is already
# imported in the first cell, so it is not re-imported here.
nib.save(
    nib.Nifti1Image(deformed_bitmap.cpu().numpy(), bas.affine.cpu().numpy()),
    "deformed.nii.gz",
)