In [None]:
%load_ext autoreload
%autoreload 2

import pickle
import math
import copy

# The deformation module library is not installed as a package yet, so its path has to be added manually
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

# Use float32 as the default floating-point dtype. This replaces the deprecated
# torch.set_default_tensor_type(torch.FloatTensor); the default device is
# already the CPU, so the behaviour is the same.
torch.set_default_dtype(torch.float32)

# Select KeOps as the compute backend of implicitmodules.
# NOTE(review): the varifold attachment built later explicitly passes backend='torch'.
dm.Utilities.set_compute_backend('keops')

In [None]:
# Load the source and target shapes.
# WIP: the two .pkl files are meant to be merged into a single dictionary.
# Each file is indexed as [0] -> image (displayed with imshow below) and
# [1] -> point array, presumably with columns (x, y, label) -- TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("../../data/diffuse.pkl", 'rb') as source_file:
    data_source = pickle.load(source_file)
with open("../../data/diffuset.pkl", 'rb') as target_file:
    data_target = pickle.load(target_file)

# Offset applied to the normalised source, and the heights the source and
# target shapes are rescaled to.
Dx = 0.
Dy = 0.
height_source = 32.
height_target = 136.


def normalize_points(points, height, dx=0., dy=0.):
    """Rescale a point array so that its vertical extent equals ``height``.

    Both coordinates are multiplied by ``height / (ymax - ymin)``. The y axis
    is flipped: the original maximum maps to ``dy`` and the original minimum
    to ``dy + height``. The x coordinates are centred on their mean and then
    shifted by ``dx``. Any columns after the first two (such as the label
    column) are left untouched.

    Args:
        points: 2-D tensor with at least two columns (column 0 = x, column 1 = y).
        height: desired extent of the y coordinates after rescaling.
        dx: horizontal offset applied after centring and scaling.
        dy: vertical offset applied after scaling.

    Returns:
        A new tensor; ``points`` itself is not modified.

    Raises:
        ValueError: if all points share the same y coordinate (zero height).
    """
    ymin, ymax = torch.min(points[:, 1]), torch.max(points[:, 1])
    if ymax == ymin:
        raise ValueError("cannot normalise points with zero vertical extent")
    scale = height / (ymax - ymin)
    normalized = points.clone()
    normalized[:, 1] = dy - scale * (points[:, 1] - ymax)
    normalized[:, 0] = dx + scale * (points[:, 0] - torch.mean(points[:, 0]))
    return normalized


source = normalize_points(
    torch.tensor(data_source[1]).type(torch.get_default_dtype()),
    height_source, Dx, Dy)
target = normalize_points(
    torch.tensor(data_target[1]).type(torch.get_default_dtype()),
    height_target)

# Split the points by label (column 2).
pos_source = source[source[:, 2] == 2, 0:2]
# NOTE(review): both implicit modules are placed on the same label-1 points;
# confirm this is intended.
pos_implicit0 = source[source[:, 2] == 1, 0:2]
pos_implicit1 = source[source[:, 2] == 1, 0:2]
pos_target = target[target[:, 2] == 2, 0:2]

# Bounding box of the target points, made square in place (presumably).
# Not used by the cells shown in this notebook.
aabb = dm.Utilities.AABB.build_from_points(pos_target)
aabb.squared()

In [None]:
# Some plots
%matplotlib qt5

plt.subplot(2, 2, 1)
plt.title("Source")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')
plt.axis('equal')

plt.subplot(2, 2, 2)
plt.title("Target")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.axis('equal')

plt.subplot(2, 2, 3)
plt.imshow(data_source[0])

plt.subplot(2, 2, 4)
plt.imshow(data_target[0])

plt.show()

In [None]:
# Setting up the deformation modules. The leading argument 2 is presumably the
# ambient dimension (the shapes are 2-D point sets).

# Global translation module
global_translation = dm.DeformationModules.GlobalTranslation(2)

# Local translation module (order-0 implicit module) on the implicit0 points.
# sigma0: kernel scale, nu0 / coeff0: presumably regularisation and cost
# weights -- TODO confirm against the ImplicitModule0 signature.
sigma0 = 15.
nu0 = 0.001
coeff0 = 100.
implicit0 = dm.DeformationModules.ImplicitModule0(
    2, len(pos_implicit0), sigma0, nu0, coeff0,
    gd=pos_implicit0.clone().requires_grad_())

# Elastic module (order-1 implicit module) on the implicit1 points.
sigma1 = 15.
nu1 = 0.001
coeff1 = 0.01
K = 10
# One (2, 1) constant per point, all equal to K.
C = K * torch.ones(len(pos_implicit1), 2, 1)

# Initial frame angles (all zero; change the factor in front of pi to rotate
# them) and the corresponding per-point rotation matrices.
th = 0. * math.pi * torch.ones(len(pos_implicit1))
R = torch.stack(list(map(dm.Utilities.rot2d, th)))

# The geometric descriptor is a (points, rotations) pair; both are cloned so
# the optimiser works on leaf tensors of its own.
implicit1 = dm.DeformationModules.ImplicitModule1(
    2, len(pos_implicit1), sigma1, C, nu1, coeff1,
    gd=(pos_implicit1.clone().requires_grad_(), R.clone().requires_grad_()))

In [None]:
# Set up the registration model and its optimiser; fitting happens in the next cell.
# The source points are transported by the three modules and compared with the
# target through a varifold attachment. [10., 50.] are presumably the varifold
# kernel scales and lam the attachment weight -- TODO confirm.
# NOTE(review): the attachment uses backend='torch' even though the global
# compute backend was set to 'keops'.
model = dm.Models.ModelPointsRegistration(
    [pos_source],
    [global_translation, implicit0, implicit1],
    [dm.Attachment.VarifoldAttachment(2, [10., 50.], backend='torch')],
    lam=100.,
)
fitter = dm.Models.ModelFittingScipy(model)

In [None]:
# Run the fit against the target shape. 50 is presumably the iteration budget,
# with progress logged every 10 iterations. The returned `costs` sequence is
# plotted against its index in the last cell.
costs = fitter.fit(
    [pos_target],
    50,
    log_interval=10,
)

In [None]:
# Results
modules = dm.DeformationModules.CompoundModule(copy.copy(model.modules))
modules.manifold.fill(model.init_manifold)
dm.HamiltonianDynamic.shoot(dm.HamiltonianDynamic.Hamiltonian(modules), 10, 'euler')
out = modules.manifold[0].gd.detach().numpy()
shot_implicit0 = modules.manifold[2].gd.detach().numpy()
shot_implicit1 = modules.manifold[3].gd[0].detach().numpy()

%matplotlib qt5
plt.subplot(1, 3, 1)
plt.title("Source")
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')
plt.axis('equal')

plt.subplot(1, 3, 2)
plt.title("Deformed source")
plt.plot(out[:, 0], out[:, 1], '-')
plt.plot(shot_implicit0[:, 0], shot_implicit0[:, 1], 'x')
plt.plot(shot_implicit1[:, 0], shot_implicit1[:, 1], '.')
plt.axis('equal')

plt.subplot(1, 3, 3)
plt.title("Deformed source and target")
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.plot(out[:, 0], out[:, 1], '-')
plt.axis('equal')
plt.show()

In [None]:
# Evolution of the cost with iterations.
# Create a fresh figure/axes. With an interactive backend the results figure
# stays current after plt.show(), and plt.title()/plt.plot() would otherwise
# draw the cost curve into its last panel.
fig, ax = plt.subplots()
ax.set_title("Cost")
ax.set_xlabel("Iteration(s)")
ax.set_ylabel("Cost")
ax.plot(range(len(costs)), costs, lw=0.8)
ax.grid()
plt.show()