In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import math
import scipy
import scipy.ndimage
import torch
import matplotlib.pyplot as plt

import implicitmodules.torch as dm

In [2]:
# Read the two greyscale bitmaps (source first, then target) with a lower-left origin.
source_image, target_image = [
    dm.Utilities.load_greyscale_image(image_path, origin='lower')
    for image_path in ("../../data/images/bar_a.png", "../../data/images/bar_b.png")
]

In [3]:
%matplotlib qt5
plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(source_image, origin='lower')

plt.subplot(1, 2, 2)
plt.title("Target image")
plt.imshow(target_image, origin='lower')

plt.show()

In [53]:
# Standard deviation of the Gaussian kernel handed to scipy.ndimage.gaussian_filter.
sig_smooth = 10

# Blur both bitmaps with the same kernel before wrapping them as torch tensors.
# Requires the `scipy.ndimage` submodule to be imported explicitly: a bare
# `import scipy` does not guarantee the attribute exists on older SciPy releases.
im0 = torch.tensor(scipy.ndimage.gaussian_filter(source_image, sig_smooth))
im1 = torch.tensor(scipy.ndimage.gaussian_filter(target_image, sig_smooth))

# Wrap the smoothed bitmaps as deformable images; `extent=None` leaves the
# extent to the library, and output='points' selects the point-based output.
source = dm.Models.DeformableImage(im0, extent=None, output='points')
target = dm.Models.DeformableImage(im1, extent=None, output='points')

In [12]:
deformed_source = dm.Utilities.deformed_intensities(source.points, source.bitmap, source.extent)

%matplotlib qt5

plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(source.bitmap, origin='lower', extent=source.extent)
plt.plot(center[0, 0].numpy(), center[0, 1].numpy(), 'x')

plt.subplot(1, 2, 2)
plt.title("Deformed source image")
plt.imshow(deformed_source, origin='lower', extent=target.extent)

plt.show()

In [20]:
# Place the control point at the relative position (0.275, 0.425), scaled by
# the upper corner (xmax, ymax) of the source extent.
_source_bounds = source.extent
print(_source_bounds)

_relative_position = torch.tensor([[0.275, 0.425]])
_upper_corner = torch.tensor([[_source_bounds.xmax, _source_bounds.ymax]])
center = _relative_position*_upper_corner
print(center)

Utilities.AABB {'xmin': 0, 'ymin': 0.0, 'xmax': 1.0, 'ymax': 1.0}
tensor([[0.2750, 0.4250]])


In [54]:
%matplotlib qt5

plt.subplot(1, 2, 1)
plt.title("Source image")
plt.imshow(source.bitmap, origin='lower', extent=source.extent)
plt.plot(center[0, 0].numpy(), center[0, 1].numpy(), 'x')

plt.subplot(1, 2, 2)
plt.title("Target image")
plt.imshow(target.bitmap, origin='lower', extent=target.extent)

plt.show()

In [64]:
from imageio import imread
from torch.nn.functional import avg_pool2d
import numpy as np

# Floating-point dtype shared by the helpers below (torch's current default).
dtype = torch.get_default_dtype()

def grid(W):
    """Return the nodes of a regular W x W grid as a (W*W, 2) tensor.

    Coordinates take the values k / W for k = 0, ..., W-1. Each row is an
    (x, y) pair; x varies fastest along the rows of the result.
    """
    ticks = torch.arange(0., W).type(dtype) / W
    # Default meshgrid indexing: the first output varies along dim 0.
    rows, cols = torch.meshgrid([ticks, ticks])
    nodes = torch.stack((cols, rows), dim=2)
    return nodes.view(-1, 2)

def load_image(fname) :
    """Load an image file as an inverted greyscale array.

    Args:
        fname: path handed to `imread`.

    Returns:
        A float array equal to ``1 - intensity / 255`` where `intensity` is
        the mean over the channel axis for multi-channel images, or the pixel
        value itself for single-channel (2-D) images. For 8-bit input this
        maps black to 1 and white to 0.
    """
    img = np.asarray(imread(fname), dtype=float)
    # Only average when a channel axis exists: a true greyscale file is read
    # as a 2-D array, and reducing it over axis=2 would raise an error.
    if img.ndim == 3:
        img = np.mean(img, axis=2)
    return 1 - img / 255.                  # black = 1, white = 0

def as_measure(fname, size, sig_smooth):
    """Turn an image file into a discrete measure on a size x size grid.

    The image is loaded with `load_image`, blurred with a Gaussian filter of
    standard deviation `sig_smooth`, average-pooled down to `size` x `size`
    cells and normalised so that the weights sum to one.

    Args:
        fname: path of the image file.
        size: number of grid nodes per axis; the image side lengths must be
            the same whole multiple of it (within the pooling's rounding).
        sig_smooth: standard deviation passed to `scipy.ndimage.gaussian_filter`.

    Returns:
        A pair ``(samples, weights)``: `samples` is the (size*size, 2) tensor
        from `grid(size)` and `weights` the flattened (size*size,) tensor.

    Raises:
        ValueError: if the image is smaller than `size`, or if pooling does
            not produce exactly `size` x `size` weights (the samples and
            weights would otherwise have different lengths).
    """
    img = load_image(fname)
    weights = torch.tensor(scipy.ndimage.gaussian_filter(img, sig_smooth), dtype=dtype)

    # Pooling window (and stride) derived from the first image dimension only.
    sampling = weights.shape[0] // size
    if sampling < 1:
        raise ValueError(
            f"image of shape {tuple(weights.shape)} is too small for a grid of size {size}")

    # avg_pool2d expects a (batch, channel, H, W) input, hence the two unsqueezes.
    weights = avg_pool2d( weights.unsqueeze(0).unsqueeze(0), sampling).squeeze(0).squeeze(0)
    if tuple(weights.shape) != (size, size):
        raise ValueError(
            f"pooled weights have shape {tuple(weights.shape)}, expected {(size, size)}; "
            "use a square image whose side matches the requested grid size")
    weights = weights / weights.sum()

    samples = grid( size )
    return samples, weights.view(-1)

# Convert both bar images into measures on a 200 x 200 grid (smoothing 30.).
s, t = [
    as_measure(image_path, 200, 30.)
    for image_path in ("../../data/images/bar_a.png", "../../data/images/bar_b.png")
]
print(s[1].shape)

# Sanity check: evaluate the attachment configured with loss='sinkhorn'
# between the source and target measures.
loss = dm.Attachment.GeomlossAttachment(loss='sinkhorn', blur=1., scaling=0.9)
l = loss(s, t)
print(l)

torch.Size([40000])


tensor(0.0993)


In [21]:
# The module's geometrical descriptor starts at `center`; a clone is used so
# that enabling gradients does not touch the original tensor.
# NOTE(review): the positional arguments (2, 1, 0.9) are presumably the
# dimension, the number of points and the kernel scale - confirm against the API.
_translation_gd = center.clone().requires_grad_()
translation = dm.DeformationModules.ImplicitModule0(2, 1, 0.9, nu=0.5, gd=_translation_gd)

In [22]:
# Alternative attachment kept for reference:
#   dm.Attachment.EuclideanPointwiseDistanceAttachment()
_sinkhorn_attachment = dm.Attachment.GeomlossAttachment(loss='sinkhorn', blur=0.05, scaling=0.9)

# Registration of `source` driven by the single translation module; the
# module's geometrical descriptor is not fitted (fit_gd=[False]).
model = dm.Models.RegistrationModel(
    source,
    [translation],
    _sinkhorn_attachment,
    fit_gd=[False],
    lam=100.,
)


In [23]:
# Shooting configuration: integrator name and number of integration steps.
shoot_solver, shoot_it = 'rk4', 10

# Filled by the fitter with the cost history.
costs = dict()
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

In [24]:
# Forward the integrator chosen above together with the step count, so the
# optimisation shoots with the same solver that is later passed to
# `model.compute_deformed(shoot_solver, shoot_it, ...)`. Previously only
# 'shoot_it' was given and `shoot_solver` was left unused during fitting.
# NOTE(review): assumes `options` accepts a 'shoot_solver' key alongside
# 'shoot_it' - confirm against the installed implicitmodules version.
fitter.fit(target, 100, costs=costs,
           options={'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'})

Compiling libKeOpstorchdc0db9a518 in /home/leander/.cache/pykeops-1.3-cpython-36/build-libKeOpstorchdc0db9a518:
       formula: Max_SumShiftExp_Reduction(( B - (P * (SqDist(X,Y) / IntCst(2)) ) ),0)
       aliases: X = Vi(0,2); Y = Vj(1,2); B = Vj(2,1); P = Pm(3,1); 
       dtype  : float32
... 

Done.


Starting optimization with method torch LBFGS
Initial cost={'deformation': 0.0, 'attach': nan}


Compiling libKeOpstorch8a78b869ab in /home/leander/.cache/pykeops-1.3-cpython-36/build-libKeOpstorch8a78b869ab:
       formula: Grad_WithSavedForward(Max_SumShiftExp_Reduction(( B - (P * (SqDist(X,Y) / IntCst(2)) ) ),0), Var(0,2,0), Var(4,2,0), Var(5,2,0))
       aliases: X = Vi(0,2); Y = Vj(1,2); B = Vj(2,1); P = Pm(3,1); Var(4,2,0); Var(5,2,0); 
       dtype  : float32
... 

Done.


ValueError: arange: cannot compute length

In [12]:
# Shoot once more with the fitted model, without tracking gradients, and keep
# the intermediate states of the integration.
intermediates = dict()
with torch.autograd.no_grad():
    _deformed = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)
    deformed_image = _deformed[0][0]

# Entry 1 of the manifold is read as the translation module's state.
# NOTE(review): presumably entry 0 belongs to the image points - confirm.
_initial_state = model.init_manifold[1]
_final_state = intermediates['states'][-1][1]

translation_center = _initial_state.gd.detach().flatten().tolist()
translation_moment = _initial_state.cotan.detach().flatten().tolist()
translation_center_end = _final_state.gd.flatten().tolist()

# Initial position, final position, then initial momentum.
for _values in (translation_center, translation_center_end, translation_moment):
    print(_values)

torch.Size([40000, 2])
tensor([[  0.0000,   0.0000],
        [  0.9950,   0.0000],
        [  1.9900,   0.0000],
        ...,
        [196.0150, 198.0050],
        [197.0100, 198.0050],
        [198.0050, 198.0050]])


[54.72500228881836, 84.57500457763672]
[151.8264617919922, 118.12496948242188]
[0.005146985873579979, 0.001814595190808177]


torch.Size([40000, 2])
tensor([[  0.0000,   0.0000],
        [  0.9950,   0.0000],
        [  1.9900,   0.0000],
        ...,
        [196.0150, 198.0050],
        [197.0100, 198.0050],
        [198.0050, 198.0050]])


[54.72500228881836, 84.57500457763672]
[151.8264617919922, 118.12496948242188]
[0.005146985873579979, 0.001814595190808177]


In [15]:
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.title("Source image")
plt.imshow(source_image, origin='lower', extent=source.extent)
plt.plot(center.flatten().tolist()[0], center.flatten().tolist()[1], 'X')

plt.subplot(1, 3, 2)
plt.title("Fitted image")
plt.imshow(deformed_image, origin='lower', extent=source.extent)
plt.plot(translation_center[0], translation_center[1], 'X')
plt.plot(translation_center_end[0], translation_center_end[1], 'X')
plt.quiver(translation_center[0], translation_center[1],
           translation_moment[0], translation_moment[1], scale=0.001)

plt.subplot(1, 3, 3)
plt.title("target image")
plt.imshow(target_image, origin='lower', extent=target.extent)

plt.show()