In [None]:
%load_ext autoreload
%autoreload 2

#
# Python module import.
#

import sys
sys.path.append("../")
import math
import copy
import pickle
import time

import torch
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

import imodal

# Run on GPU number `gpu_index` when the machine has it; otherwise fall back to the CPU
# instead of failing later when tensors are moved to a non-existent device.
gpu_index = 2
device = f'cuda:{gpu_index}' if torch.cuda.device_count() > gpu_index else 'cpu'
torch.set_default_dtype(torch.float32)
# Use the KeOps backend for imodal's computations.
# NOTE(review): presumably this affects the kernel reductions of the deformation modules;
# confirm in the imodal documentation.
imodal.Utilities.set_compute_backend('keops')

In [None]:
# Input ROI images of one subject at two time points: the baseline ('bas') and, presumably,
# the third follow-up ('fu3'). Change `subject_id` / `data_root` to process another subject;
# file names embed the subject ID zero-padded to 12 digits (27113512 -> 000027113512).
data_root = "../../../data/imagen"
subject_id = 27113512
roi_dir = f"{data_root}/{subject_id}/regions"
bas = imodal.Models.Deformable3DImage.load_from_file(f"{roi_dir}/nat_amyg_{subject_id:012d}_bas_roi.nii.gz")
fu3 = imodal.Models.Deformable3DImage.load_from_file(f"{roi_dir}/nat_amyg_{subject_id:012d}_fu3_roi.nii.gz")

In [None]:
# Rigid pre-alignment of the follow-up image (fu3) onto the baseline (bas) with a fixed
# rotation (`angles`) and translation (`offset`), presumably estimated beforehand, so the
# optimisation below starts closer to the baseline and converges faster.
# NOTE(review): the original rationale was that this does not modify results because
# implicit modules of order 1 are rotation- and translation-invariant; the model built
# below only uses a GlobalTranslation module, so confirm the claim still applies.
# NOTE(review): the return value of `apply_affine` is discarded, so it presumably updates
# fu3 in place -- re-running this cell without reloading fu3 applies the transform twice.
angles = [0.09665711, 0.018558108, 0.04608092]
offset = [1.0997878, -12.219526, 6.5309704]
rigid_deformation = imodal.Utilities.rigid_deformation3d(angles, offset)
fu3.apply_affine(rigid_deformation)

In [None]:
# Axis-aligned bounding box spanning the baseline bitmap: [0, size] along each of the
# first three axes, passed to AABB as consecutive (min, max) pairs.
deformables_shape = bas.bitmap.shape
aabb = imodal.Utilities.AABB(*(bound for axis in range(3) for bound in (0., deformables_shape[axis])))

# Control points for the LDDMM module, spread uniformly inside the box. The density is
# presumably a number of points per unit (voxel) volume; a sparser 0.005 was tried before.
points_density = 0.05
lddmm_points = aabb.fill_uniform_density(points_density)

In [None]:
# Kernel scale: twice density ** (-1/3), i.e. presumably twice the typical spacing between
# neighbouring control points when the density is given per unit volume.
scale = 2.0 / (points_density ** (1.0 / 3.0))
print(scale)
# NOTE(review): `lddmm` is not passed to the RegistrationModel below (only the global
# translation is), so this module is currently unused by the fit.
lddmm = imodal.DeformationModules.ImplicitModule0(
    3,                      # presumably the spatial dimension
    lddmm_points.shape[0],  # number of control points
    scale,
    gd=lddmm_points,
    nu=0.1,                 # presumably a regularisation weight -- TODO confirm
)

In [None]:
# Deformation module made of a single global translation; 3 is presumably the spatial
# dimension, matching the first argument passed to ImplicitModule0 above.
global_translation = imodal.DeformationModules.GlobalTranslation(3)

In [None]:
# Registration model that deforms the baseline image `bas` towards the targets given to
# the fitter (fu3), with an L2-norm image attachment term.
# NOTE(review): only `global_translation` is used as deformation module; the `lddmm`
# module built earlier is left out -- confirm this is intended.
model = imodal.Models.RegistrationModel(
    bas,
    [global_translation],
    imodal.Attachment.L2NormAttachment(),
)

In [None]:
# Move the model (which holds `bas`) and the target image to the compute device selected
# in the configuration cell; both must live on the same device for the fit.
model.to_device(device)
fu3.to_device(device)

In [None]:
# Shooting settings, reused both by the fit and by compute_deformed further down
# (presumably the ODE integration scheme and its number of steps).
shoot_solver, shoot_it = 'euler', 10
# Passed to Fitter.fit, which presumably records the cost history in it.
costs = {}
fitter = imodal.Models.Fitter(model, optimizer='torch_lbfgs')

In [None]:
# Options forwarded to the fitter: shooting settings plus the L-BFGS line search method.
fit_options = {
    'shoot_solver': shoot_solver,
    'shoot_it': shoot_it,
    'line_search_fn': 'strong_wolfe',
}
# Fit the model to fu3; 100 is presumably the maximum number of optimiser iterations.
fitter.fit([fu3], 100, costs=costs, options=fit_options)

In [None]:
# Shoot the fitted deformation to obtain the deformed baseline bitmap, and time it.
# Gradients are not needed here, so autograd tracking is disabled.
with torch.no_grad():
    intermediates = {}
    start = time.perf_counter()
    deformed_bitmap = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)[0][0]
    # CUDA kernels run asynchronously: wait for queued GPU work to finish before reading
    # the clock, otherwise the reported time may miss most of the computation.
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device)
    print("Elapsed time={} sec".format(time.perf_counter() - start))

In [None]:
# Save the deformed baseline bitmap as a NIfTI file, reusing the baseline affine so the
# result shares the input image's voxel-to-world mapping. `nib` is imported in the first cell.
# `.cpu()` is a no-op for CPU tensors and is needed if model.to_device moved them to the GPU.
nib.save(nib.Nifti1Image(deformed_bitmap.cpu().numpy(), bas.affine.cpu().numpy()), "deformed.nii.gz")