In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import math
import scipy
import torch
import matplotlib.pyplot as plt

import implicitmodules.torch as dm

In [None]:
# Load the source and target greyscale images with origin='lower', matching
# the origin='lower' used by every imshow call in this notebook.
source_image, target_image = (
    dm.Utilities.load_greyscale_image(path, origin='lower')
    for path in ("../../data/images/bar_a.png", "../../data/images/bar_b.png")
)

%matplotlib qt5
plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(source_image, origin='lower')

plt.subplot(1, 2, 2)
plt.title("Target image")
plt.imshow(target_image, origin='lower')

plt.show()

In [None]:
# Exploratory check: sample the source image with sample_from_greyscale, then
# push those samples back through deformed_intensities and show both side by
# side (presumably to check the round trip reproduces the image — TODO confirm).
# Distinct names are used so that re-running this cell never overwrites the
# `im0` / `im1` images that the registration model is built from.
pixel_pos, im0_sampled = dm.Utilities.sample_from_greyscale(source_image, 0., centered=False, normalise_weights=False, normalise_position=False)
# NOTE(review): the -0.5 shift presumably converts between pixel-corner and
# pixel-centre coordinates — confirm against dm.Utilities conventions.
im1_resampled = dm.Utilities.deformed_intensities(pixel_pos - 0.5, im0_sampled.view_as(source_image))

# Fresh figure: under the interactive qt5 backend the last figure is still current.
fig, (ax_sampled, ax_resampled) = plt.subplots(1, 2)
ax_sampled.imshow(im0_sampled.view_as(source_image), origin='lower')
ax_resampled.imshow(im1_resampled, origin='lower')

plt.show()

In [None]:
# Blur source and target with the same Gaussian kernel (standard deviation
# `sig_smooth`, in pixel units) and wrap the NumPy results as torch tensors;
# im0 is the model's input image, im1 the fitting target.
sig_smooth = 15
im0, im1 = (torch.tensor(scipy.ndimage.gaussian_filter(image, sig_smooth))
            for image in (source_image, target_image))

In [None]:
# Initial position of the translation module: a single row holding (x, y),
# which is how it is plotted (column 0 as x, column 1 as y).
center = torch.tensor([55., 85.]).reshape(1, 2)

In [None]:
%matplotlib qt5

plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(im0, origin='lower')
plt.plot(center[0, 0].numpy(), center[0, 1].numpy(), 'x')

plt.subplot(1, 2, 2)
plt.title("Target image")
plt.imshow(im1, origin='lower')

plt.show()

In [None]:
# Translation deformation module built with ImplicitModule0, positioned at a
# gradient-tracking copy of `center` (the clone keeps `center` itself untouched).
# NOTE(review): positional args (2, 1, 200.) presumably mean dimension,
# number of points and kernel scale — confirm against the ImplicitModule0 signature.
initial_gd = center.clone().requires_grad_()
translation = dm.DeformationModules.ImplicitModule0(2, 1, 200., nu=0.1, gd=initial_gd)

In [None]:
# Registration model that deforms the smoothed source `im0` with the single
# translation module, scored by a Euclidean pointwise attachment term.
# NOTE(review): fit_gd=[False] presumably keeps the module position fixed during
# optimisation and lam weights the attachment term — confirm in dm.Models.
attachment = dm.Attachment.EuclideanPointwiseDistanceAttachment()
model = dm.Models.ModelImageRegistration(
    im0, [translation], attachment,
    fit_gd=[False], lam=100.)

In [None]:
# Shooting settings: integration scheme name and number of integration steps.
shoot_solver, shoot_it = 'rk4', 10

# `costs` is handed to fitter.fit(costs=...); the fitter optimises with torch L-BFGS.
costs = dict()
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

In [None]:
# Fit the model to the smoothed target for 100 iterations.
# Pass the same shooting solver that compute_deformed is called with afterwards,
# so the deformation being optimised matches the one being displayed; previously
# only `shoot_it` was forwarded and the fitter used its own default solver.
# NOTE(review): assumes Fitter.fit reads the solver from the 'shoot_solver'
# option, as it reads 'shoot_it' — confirm against dm.Models.Fitter.
fitter.fit(im1.clone(), 100, costs=costs,
           options={'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'})

In [None]:
# Shoot the fitted model without tracking gradients, recording intermediate
# states so the final position of the translation module can be read off.
intermediates = {}
with torch.no_grad():
    deformed_image = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)

# NOTE(review): index 1 presumably selects the translation module's manifold
# (index 0 being the image) — confirm against ModelImageRegistration.
translation_center = model.init_manifold[1].gd.detach().flatten().tolist()
translation_moment = model.init_manifold[1].cotan.detach().flatten().tolist()
final_state = intermediates['states'][-1]
translation_center_end = final_state[1].gd.flatten().tolist()

for quantity in (translation_center, translation_center_end, translation_moment):
    print(quantity)

In [None]:
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.title("Source image")
plt.imshow(source_image, origin='lower')
plt.plot(center.flatten().tolist()[0], center.flatten().tolist()[1], 'X')

plt.subplot(1, 3, 2)
plt.title("Fitted image")
plt.imshow(deformed_image, origin='lower')
plt.plot(translation_center[0], translation_center[1], 'X')
plt.plot(translation_center_end[0], translation_center_end[1], 'X')
plt.quiver(translation_center[0], translation_center[1],
           translation_moment[0], translation_moment[1])

plt.subplot(1, 3, 3)
plt.title("target image")
plt.imshow(target_image, origin='lower')

plt.show()