In [None]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not installed as a package yet, so we add its path manually.
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.spatial import ConvexHull

import implicitmodules.torch as dm

# Use 32-bit floats for every tensor created without an explicit dtype.
# torch.set_default_tensor_type is deprecated in recent PyTorch releases;
# torch.set_default_dtype is the supported way to select the default dtype.
torch.set_default_dtype(torch.float32)
# Print floating point tensor values with 12 digits of precision.
torch.set_printoptions(precision=12)

In [None]:
# Load the data set stored in basipetal.pkl. The context manager closes the
# file handle as soon as the pickle has been read.
# NOTE: unpickling can execute arbitrary code; only load a file you trust.
with open("../../data/basipetal.pkl", 'rb') as data_file:
    data = pickle.load(data_file)

# Offsets added to the normalised source coordinates (Dx, Dy) and the vertical
# extents the source and target curves are rescaled to.
Dx = 0.
Dy = 0.
height_source = 38.
height_target = 100.

# Convert the raw arrays to tensors of the default floating point dtype.
# Presumably '*_d' holds the landmark points and '*_c' the outline curves
# (that is how they are used below) — confirm against the data file.
source = torch.tensor(data['source_d'], dtype=torch.get_default_dtype())
target = torch.tensor(data['target_d'], dtype=torch.get_default_dtype())
source_curve = torch.tensor(data['source_c'], dtype=torch.get_default_dtype())
target_curve = torch.tensor(data['target_c'], dtype=torch.get_default_dtype())

# Normalise the source data: rescale so that the source curve spans
# height_source vertically, flip the y axis so that the largest original y maps
# to Dy, and centre x on the mean x of the curve (shifted by Dx).
# The statistics are taken from the curve *before* anything is modified and are
# applied to both the points and the curve.
# NOTE: the tensors are updated in place, so re-running this cell without
# reloading the data would transform already-normalised coordinates again.
smin, smax = source_curve[:, 1].min(), source_curve[:, 1].max()
sscale = height_source / (smax - smin)
source_mean_x = source_curve[:, 0].mean()
for cloud in (source, source_curve):
    cloud[:, 1] = Dy - sscale * (cloud[:, 1] - smax)
    cloud[:, 0] = Dx + sscale * (cloud[:, 0] - source_mean_x)

# Same normalisation for the target, scaled to height_target and without offset.
tmin, tmax = target_curve[:, 1].min(), target_curve[:, 1].max()
tscale = height_target / (tmax - tmin)
target_mean_x = target_curve[:, 0].mean()
for cloud in (target, target_curve):
    cloud[:, 1] = - tscale * (cloud[:, 1] - tmax)
    cloud[:, 0] = tscale * (cloud[:, 0] - target_mean_x)

# Keep only the part of each curve with x <= 0 for the fit. Boolean-mask
# indexing returns a new tensor, so the full curves are left untouched.
source_curve_fit = source_curve[source_curve[:, 0] <= 0]
target_curve_fit = target_curve[target_curve[:, 0] <= 0]

# Convex hulls of the full and of the restricted source curve; the hull
# vertices index back into the corresponding curve.
hull = ConvexHull(source_curve)
hull_fit = ConvexHull(source_curve_fit)
source_curve_convex = source_curve[hull.vertices]
source_curve_fit_convex = source_curve_fit[hull_fit.vertices]

# Bounding boxes (dm.Utilities.AABB) of the source curve, of its restricted
# version, and of the restricted source and target curves together.
aabb_source = dm.Utilities.AABB.build_from_points(source_curve)
aabb_source_fit = dm.Utilities.AABB.build_from_points(source_curve_fit)
aabb = dm.Utilities.AABB.build_from_points(torch.cat([target_curve_fit, source_curve_fit]))
# NOTE(review): presumably makes the joint box square; the return value is not
# stored, so confirm whether squared() works in place.
aabb.squared()

In [None]:
# Spacing of the regular grid the implicit module points are sampled from.
pts_implicit1_step = 1.5

# Grid axes: x runs from 0 in negative steps towards the xmin of the restricted
# source curve's box, y runs from its ymin to its ymax (one extra step of margin
# in the arange bounds).
grid_axis_x = torch.arange(0., aabb_source_fit.xmin-pts_implicit1_step, step=-pts_implicit1_step)
grid_axis_y = torch.arange(aabb_source_fit.ymin, aabb_source_fit.ymax+pts_implicit1_step, step=pts_implicit1_step)
pts_implicit1_x, pts_implicit1_y = torch.meshgrid([grid_axis_x, grid_axis_y])

# Flatten the grid to a list of points and keep those selected by area_shape
# (presumably the points lying inside the restricted source curve — confirm).
pts_implicit1 = dm.Utilities.grid2vec(pts_implicit1_x, pts_implicit1_y)
pts_implicit1_mask = dm.Utilities.area_shape(pts_implicit1, shape=source_curve_fit)
pts_implicit1 = pts_implicit1[pts_implicit1_mask]

In [None]:
%matplotlib qt5

plt.axis('equal')

plt.subplot(1, 2, 1)
plt.title("Source")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(source_curve_fit[:, 0].numpy(), source_curve_fit[:, 1].numpy(), '-')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), 'x')
plt.plot(pts_implicit1[:, 0].numpy(), pts_implicit1[:, 1].numpy(), '.')
plt.axis('equal')

plt.subplot(1, 2, 2)
plt.title("Target")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(target_curve_fit[:, 0].numpy(), target_curve_fit[:, 1].numpy(), '-')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
plt.axis('equal')

plt.show()

In [None]:
# Hyper-parameters handed to the deformation modules below.
sigma1 = 25.
nu1 = 0.001
coeff_global_trans = 1.
coeff1 = 1.

nb_pts_implicit1 = pts_implicit1.shape[0]

# One (2, 1) block of ones per implicit point; requires_grad is set because C
# is registered as an extra model parameter when the model is built.
C = torch.ones(nb_pts_implicit1, 2, 1, requires_grad=True)
# One angle per implicit point, all zero here (0 * pi), turned into one rot2d
# result per point and stacked.
th = 0. * math.pi * torch.ones(nb_pts_implicit1)
R = torch.stack([dm.Utilities.rot2d(angle) for angle in th])

# Geometrical descriptor of the implicit module: point positions and the
# per-point R, both cloned and marked as requiring gradients.
implicit1_gd = (pts_implicit1.clone().requires_grad_(), R.clone().requires_grad_())
implicit1 = dm.DeformationModules.ImplicitModule1(2, nb_pts_implicit1, sigma1, C, nu1, coeff1, gd=implicit1_gd)
global_trans = dm.DeformationModules.GlobalTranslation(2, coeff=coeff_global_trans)

In [None]:
# Set up the registration model; the fit itself is run in a later cell.
# Two point sets are registered: the restricted source curve and the source
# points (cloned so the originals stay available for plotting).
registered_sources = [source_curve_fit.clone(), source.clone()]
deformation_modules = [implicit1, global_trans]
# One attachment term per registered point set, in the same order.
attachments = [dm.Attachment.VarifoldAttachment(2, [10., 50.], 0.1),
               dm.Attachment.EuclideanPointwiseDistanceAttachment(50.)]

model = dm.Models.ModelPointsRegistration(
    registered_sources,
    deformation_modules,
    attachments,
    other_parameters=[('C', [implicit1.C])], lam=1000.)

In [None]:
# Fit the model to the targets, listed in the same order as the registered
# source point sets (curve first, then points).
fitter = dm.Models.ModelFittingScipy(model)

fit_targets = [target_curve_fit, target]
# The second positional argument (500) is presumably the iteration cap — confirm
# against ModelFittingScipy.fit. The returned costs are plotted afterwards.
costs = fitter.fit(
    fit_targets,
    500,
    options={},
    disp=True)

In [None]:
%matplotlib qt5
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.plot(range(len(costs)), costs)
plt.show()

In [None]:
%matplotlib qt5
C = model.modules[2].C.detach()
R = model.modules[2].manifold.gd[1].detach()

points = model.modules[1].manifold.gd.detach().view(-1, 2)
out_curve = model.modules[0].manifold.gd.detach().view(-1, 2)
ax = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), '.')
plt.plot(points[:, 0].numpy(), points[:, 1].numpy(), 'o')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
plt.plot(target_curve_fit[:, 0].numpy(), target_curve_fit[:, 1].numpy(), '--')
plt.plot(out_curve[:, 0].numpy(), out_curve[:, 1].numpy(), '--')
for i in range(points.shape[0]):
   plt.plot([target[i, 0], points[i, 0]], [target[i, 1], points[i, 1]], '-')

def_grids_c = model.compute_deformation_grid([aabb_source.xmin, aabb_source.ymin],
                                             [aabb_source.width/2, aabb_source.height],
                                             [8, 16])

def_grid_c_x, def_grid_c_y = def_grids_c[0], def_grids_c[1]
dm.Utilities.plot_grid(ax, def_grid_c_x.numpy(), def_grid_c_y.numpy(), color='C0')

# dm.Utilities.plot_C_arrow(ax, source, C, R=R, scale=1., zorder=3, mutation_scale=10)
# dm.Utilities.plot_C_ellipse(ax, pts_implicit1, C, R=R, scale=1., color='black')

plt.xlabel("$x$")
plt.ylabel("$y$")

plt.show()