In [None]:
%reset
%load_ext autoreload
%autoreload 2

import time

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch
import geomloss

import defmod as dm

torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
# Load the pair of greyscale images to register: `source_image` is the one handed to the
# model below, `target_image` is the one it is fitted towards.
source_image, target_image = (
    dm.sampling.load_greyscale_image(path)
    for path in ("../data/heart_a.png", "../data/heart_b.png")
)

In [None]:
# Bounding box of the source image (x extent from shape[0], y extent from shape[1]).
aabb = dm.usefulfunctions.AABB(0., source_image.shape[0], 0., source_image.shape[1])

# Scale handed to the implicit module (presumably its kernel width); grid points are
# spaced half a sigma apart.
sigma = 7.
step = sigma/2.

# Regular grid of positions covering the bounding box.
x, y = torch.meshgrid([
    torch.arange(aabb.xmin, aabb.xmax, step=step),
    torch.arange(aabb.ymin, aabb.ymax, step=step),
])

# Flatten the grid into point coordinates (later viewed as (N, 2)) and wrap them as 2D landmarks.
gd = dm.usefulfunctions.grid2vec(x, y).contiguous()
landmarks = dm.manifold.Landmarks(2, gd.shape[0], gd=gd.view(-1))

# Implicit module acting on those landmarks.
# NOTE(review): meaning of the trailing 0. argument is defined in defmod — confirm.
trans = dm.implicitmodules.ImplicitModule0(landmarks, sigma, 0.)

In [None]:
# Overlay the grid points `gd` (the landmarks of the implicit module) on the source image.
# cmap='gray' renders the greyscale image the same way as the results cell does.
plt.imshow(source_image, cmap='gray')
plt.scatter(gd.view(-1, 2)[:, 0].numpy(), gd.view(-1, 2)[:, 1].numpy())
plt.title("Source image and control points")

plt.show()

In [None]:
import numpy as np
# `scipy.ndimage.filters` is a deprecated namespace; the public `scipy.ndimage` module is
# imported under the same `fi` alias so existing uses of `fi.` keep working.
import scipy.ndimage as fi


def gkern2(kernlen=21, nsig=3):
    """Return a 2D Gaussian kernel as a NumPy array.

    The kernel is obtained by Gaussian-smoothing a discrete Dirac delta placed at the
    centre of a ``kernlen x kernlen`` grid.

    Parameters
    ----------
    kernlen : int
        Side length of the square kernel, in pixels.
    nsig : float
        Standard deviation of the Gaussian in pixels. It is passed as ``sigma`` to
        ``scipy.ndimage.gaussian_filter``; despite its name it is *not* a number of
        standard deviations.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(kernlen, kernlen)``. Its entries sum to 1 when
        ``kernlen`` is at least the filter support ``2*int(4*nsig + 0.5) + 1``
        (SciPy's default ``truncate=4.0``).
    """
    # Discrete Dirac delta at the centre of an all-zero grid.
    inp = np.zeros((kernlen, kernlen))
    inp[kernlen // 2, kernlen // 2] = 1
    # Smoothing the Dirac yields the sampled Gaussian impulse response.
    return fi.gaussian_filter(inp, nsig)


def gaussian_filtering(img, kr=50, sigma=10):
    """Blur a 2D image tensor with a Gaussian kernel.

    Parameters
    ----------
    img : torch.Tensor
        Floating-point image of shape ``(H, W)``.
    kr : int, optional
        Kernel radius in pixels; the kernel is ``(2*kr + 1) x (2*kr + 1)``.
    sigma : float, optional
        Standard deviation of the Gaussian, in pixels.

    Returns
    -------
    torch.Tensor
        Blurred image with the same shape, dtype and device as ``img``. Pixels outside
        the image count as zero (``conv2d`` zero padding), so borders are darkened.
    """
    kd = 2*kr + 1
    frame_res = img.shape
    # Build the kernel in img's dtype/device: conv2d requires input and weight to match,
    # and this notebook's default dtype is float64 (a float32-only kernel would fail).
    kernel = torch.tensor(gkern2(kd, sigma), dtype=img.dtype, device=img.device).reshape(1, 1, kd, kd)
    # conv2d expects (batch, channels, H, W); padding=kr keeps the output size equal to the input.
    blurred = torch.nn.functional.conv2d(img.reshape(1, 1, frame_res[0], frame_res[1]), kernel, stride=1, padding=kr)
    return blurred.reshape(frame_res)


In [None]:
# Build the registration model on the source image with the implicit module `trans`,
# using a Sinkhorn (p=1) samples loss from geomloss, then fit it towards the target image.
# NOTE(review): the meaning of `[True]` and of `l` is defined in defmod — confirm.
my_model = dm.models.ModelCompoundImageRegistration(source_image, [trans], [True], geomloss.SamplesLoss("sinkhorn", p=1))
# time.clock() was removed in Python 3.8; perf_counter() is the portable wall-clock timer.
start_time = time.perf_counter()
costs = my_model.fit(target_image, lr=2e-5, l=1000., max_iter=500, log_interval=1)
print("Elapsed time:", time.perf_counter() - start_time)

In [None]:
# Visualise the registration: source | model output with deformation grid overlay | target.
# NOTE(review): `it` is forwarded to compute_deformation_grid; its meaning (presumably a
# number of integration steps) is defined in defmod — confirm.
it = 5
# Evaluate the fitted model; presumably returns the deformed source image (detached below for plotting).
sampled_out = my_model()
# NOTE(review): the grid bounds [0, 0]-[32, 32] and the 16x16 size are hard-coded rather than
# derived from the image shape — confirm they cover the image.
grid_x, grid_y = my_model.compute_deformation_grid(torch.tensor([0., 0.]), torch.tensor([32., 32.]), torch.Size([16, 16]), it=it, intermediate=True)

# NOTE(review): switching to the Qt5 backend mid-notebook needs a Qt display (fails on headless
# kernels) and changes the backend for every later figure; consider moving it to the setup cell.
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.imshow(source_image, cmap='gray')
ax = plt.subplot(1, 3, 2)
plt.imshow(sampled_out.detach().numpy(), cmap='gray')
# Draw the deformation grid on top of the model output.
dm.usefulfunctions.plot_grid(ax, grid_x.numpy(), grid_y.numpy(), color='C0')
# Reset the limits to the image extent with the y axis inverted (ymin > ymax), matching
# imshow's top-down orientation after the grid overlay.
plt.axis([0., sampled_out.shape[0]-1, sampled_out.shape[1]-1, 0.])
plt.subplot(1, 3, 3)
plt.imshow(target_image, cmap='gray')

plt.show()

In [None]:
# Cost values returned by `fit`, in the order they were recorded (log_interval=1 in the fit call,
# so presumably one entry per iteration).
plt.plot(range(len(costs)), costs)
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.title("Registration cost")
plt.show()