In [2]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import math
import scipy
import torch
import matplotlib.pyplot as plt

import implicitmodules.torch as dm

In [17]:
source_image = dm.Utilities.load_greyscale_image("../../data/images/bar_a.png", origin='lower')
target_image = dm.Utilities.load_greyscale_image("../../data/images/bar_b.png", origin='lower')

In [18]:
%matplotlib qt5
plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(source_image, origin='lower')

plt.subplot(1, 2, 2)
plt.title("Target image")
plt.imshow(target_image, origin='lower')

plt.show()

In [19]:
sig_smooth = 15
im0 = torch.tensor(scipy.ndimage.gaussian_filter(source_image, sig_smooth))
im1 = torch.tensor(scipy.ndimage.gaussian_filter(target_image, sig_smooth))

In [20]:
center = torch.tensor([[55., 85.]])

In [21]:
%matplotlib qt5

plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(im0, origin='lower')
plt.plot(center[0, 0].numpy(), center[0, 1].numpy(), 'x')

plt.subplot(1, 2, 2)
plt.title("Target image")
plt.imshow(im1, origin='lower')

plt.show()

In [22]:
translation = dm.DeformationModules.ImplicitModule0(2, 1, 200., nu=0.1, gd=center.clone().requires_grad_())

In [23]:
source = dm.Models.DeformableImage(im0.t())
target = dm.Models.DeformableImage(im1.t())

In [25]:
model = dm.Models.RegistrationModel(source, [translation], dm.Attachment.EuclideanPointwiseDistanceAttachment(), fit_gd=[False], lam=100.)

In [26]:
shoot_solver='rk4'
shoot_it = 10

costs = {}
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

In [28]:
fitter.fit(target, 100, costs=costs, options={'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'})

Starting optimization with method torch LBFGS
Initial cost={'deformation': 0.0, 'attach': 35674.7265625}


Time: 13.322812743950635
Iteration: 0
Costs
deformation=5264.208984375
attach=8195.568359375
Total cost=13459.77734375


Time: 16.566908759996295
Iteration: 1
Costs
deformation=5240.96435546875
attach=8216.6015625
Total cost=13457.56591796875


Time: 18.822844427952077
Iteration: 2
Costs
deformation=5240.96435546875
attach=8216.6015625
Total cost=13457.56591796875
Optimisation process exited with message: Convergence achieved.
Final cost=13457.56591796875
Model evaluation count=35
Time elapsed = 18.823069280944765


In [33]:
intermediates = {}
with torch.autograd.no_grad():
    deformed_image = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)[0].t()

translation_center = model.init_manifold[1].gd.detach().flatten().tolist()
translation_moment = model.init_manifold[1].cotan.detach().flatten().tolist()
translation_center_end = intermediates['states'][-1][1].gd.flatten().tolist()

print(translation_center)
print(translation_center_end)
print(translation_moment)

[55.0, 85.0]
[148.40435791015625, 113.2249984741211]
[0.003161848057061434, 0.000965333019848913]


In [34]:
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.title("Source image")
plt.imshow(source_image, origin='lower')
plt.plot(center.flatten().tolist()[0], center.flatten().tolist()[1], 'X')

plt.subplot(1, 3, 2)
plt.title("Fitted image")
plt.imshow(deformed_image, origin='lower')
plt.plot(translation_center[0], translation_center[1], 'X')
plt.plot(translation_center_end[0], translation_center_end[1], 'X')
plt.quiver(translation_center[0], translation_center[1],
           translation_moment[0], translation_moment[1])

plt.subplot(1, 3, 3)
plt.title("target image")
plt.imshow(target_image, origin='lower')

plt.show()