In [None]:
%reset
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.spatial import ConvexHull

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
def normalize_leaf(points, curve, height, dx=0., dy=0.):
    """Flip, rescale and recentre a leaf (landmarks + outline) consistently.

    The scale is chosen so that the vertical extent of ``curve`` becomes
    ``height``. The y axis is flipped and the top of the raw outline (its
    maximum y) is mapped to ``dy``. Horizontally, both arrays are centred on
    the mean x of the raw outline and shifted by ``dx``. The same transform is
    applied to ``points`` so landmarks and outline stay aligned.

    Returns new tensors ``(points, curve)``; the inputs are left untouched.
    """
    y_min, y_max = torch.min(curve[:, 1]), torch.max(curve[:, 1])
    # Computed once from the raw outline, before anything is transformed.
    x_mean = torch.mean(curve[:, 0])
    scale = height / (y_max - y_min)

    def transform(xy):
        out = xy.clone()
        out[:, 1] = dy - scale * (xy[:, 1] - y_max)
        out[:, 0] = dx + scale * (xy[:, 0] - x_mean)
        return out

    return transform(points), transform(curve)


# Load the digitised leaves: landmarks ('*_d') and outline curves ('*_c').
# NOTE: pickle.load can execute arbitrary code; only open trusted data files.
with open("../../data/basipetal.pkl", 'rb') as data_file:
    data = pickle.load(data_file)

Dx = 0.
Dy = 0.
height_source = 38.
height_target = 100.

source = torch.tensor(data['source_d'], dtype=torch.get_default_dtype())
target = torch.tensor(data['target_d'], dtype=torch.get_default_dtype())
source_curve = torch.tensor(data['source_c'], dtype=torch.get_default_dtype())
target_curve = torch.tensor(data['target_c'], dtype=torch.get_default_dtype())

# The source is offset by (Dx, Dy); the target is anchored at the origin.
source, source_curve = normalize_leaf(source, source_curve, height_source, Dx, Dy)
target, target_curve = normalize_leaf(target, target_curve, height_target)

# Only the left half (x <= 0) of each outline is used in the first fit.
# Boolean-mask indexing already returns a copy.
source_curve_fit = source_curve[source_curve[:, 0] <= 0]
target_curve_fit = target_curve[target_curve[:, 0] <= 0]

# Convex hulls of the outlines, used later to keep only interior grid points.
# scipy returns int32 vertex indices; convert them to int64 for torch indexing.
hull_fit = ConvexHull(source_curve_fit.numpy())
source_curve_fit_convex = source_curve_fit[torch.as_tensor(hull_fit.vertices, dtype=torch.long)]
hull = ConvexHull(source_curve.numpy())
source_curve_convex = source_curve[torch.as_tensor(hull.vertices, dtype=torch.long)]

# Bounding boxes: a common (squared) one for plotting both half leaves, and
# ones for the source half outline and the full source outline.
aabb = dm.Utilities.AABB.build_from_points(torch.cat([target_curve_fit, source_curve_fit]))
aabb_source_fit = dm.Utilities.AABB.build_from_points(source_curve_fit)
aabb_source = dm.Utilities.AABB.build_from_points(source_curve)
aabb.squared()

In [None]:
# Control points of the order-1 implicit module: a regular grid of spacing
# `pts_implicit1_step` over the box of the fitted (left) half outline.
pts_implicit1_step = 1.5
# x is walked from 0 towards xmin with a negative step, so the grid is anchored
# on the axis x = 0; y covers the box height (plus one step so ymax is reached).
pts_implicit1_x, pts_implicit1_y = torch.meshgrid([
    torch.arange(0.,
                 aabb_source_fit.xmin - pts_implicit1_step,
                 step=-pts_implicit1_step),
    torch.arange(aabb_source_fit.ymin,
                 aabb_source_fit.ymax + pts_implicit1_step,
                 step=pts_implicit1_step),
])

# Flatten the grid into a point list, then keep only the nodes that lie
# inside the convex hull of the fitted half outline.
pts_implicit1 = dm.Utilities.grid2vec(pts_implicit1_x, pts_implicit1_y)
pts_implicit1_mask = dm.Utilities.is_inside_shape(source_curve_fit_convex,
                                                  pts_implicit1)
pts_implicit1 = pts_implicit1[pts_implicit1_mask]

In [None]:
%matplotlib qt5

plt.axis('equal')

plt.subplot(1, 2, 1)
plt.axis(aabb.get_list())
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(source_curve_fit[:, 0].numpy(), source_curve_fit[:, 1].numpy(), '-')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), 'x')
plt.plot(pts_implicit1[:, 0].numpy(), pts_implicit1[:, 1].numpy(), '.')

plt.subplot(1, 2, 2)
plt.axis(aabb.get_list())
plt.title('Target')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(target_curve_fit[:, 0].numpy(), target_curve_fit[:, 1].numpy(), '-')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')

plt.show()

In [None]:
# Hyper-parameters handed to ImplicitModule1 (positional: sigma, nu, coeff).
sigma1 = 25.
nu1 = 0.001
coeff1 = 0.001

# Growth constants: one (2, 1) block per control point, initialised to ones.
# They require grad and are given to the model as an extra parameter to fit.
C = torch.ones(pts_implicit1.shape[0], 2, 1, requires_grad=True)

# Orientation angle of each local frame; the multiplier on pi is currently 0,
# so every frame R[i] is rot2d(0).
th = 0. * math.pi * torch.ones(pts_implicit1.shape[0])
R = torch.stack([dm.Utilities.rot2d(t) for t in th])

implicit1 = dm.DeformationModules.ImplicitModule1(
    dm.Manifolds.Stiefel(
        2,
        pts_implicit1.shape[0],
        gd=(pts_implicit1.view(-1).requires_grad_(),
            R.view(-1).requires_grad_()),
    ),
    C,
    sigma1,
    nu1,
    coeff1,
)
global_trans = dm.DeformationModules.GlobalTranslation(2, coeff=0.001)

In [None]:
# Compound model for the half leaf: the half outline and the landmarks are
# moved by the order-1 implicit module plus a global translation.
# Attachments are listed in the same order as the data (presumably paired by
# position): varifold for the outline, pointwise Euclidean for the landmarks.
# The growth constants implicit1.C are passed as additional parameters.
model = dm.Models.ModelCompoundWithPointsRegistration(
    [source_curve_fit, source],
    [implicit1, global_trans],
    [
        dm.Attachment.VarifoldAttachement([10., 50.], 0.1),
        dm.Attachment.EuclideanPointwiseDistanceAttachement(50.),
    ],
    other_parameters=[implicit1.C],
)

In [None]:
# Fit the half-leaf model. Targets follow the same [outline, landmarks] order
# as the source data given to the model.
costs = model.fit(
    [target_curve_fit, target],
    max_iter=750,
    l=10.,
    step_length=1.,
    options={'max_ls': 100, 'damping': True, 'c0': 1e-4, 'c1': 0.85},
)

In [None]:
# Evolution of the three cost terms recorded by the half-leaf fit.
# A dedicated figure is created: with the qt5 backend plain plt.plot would
# draw into the last open figure (the 'Target' panel of the previous plot).
deformation_costs, attach_costs, total_costs = zip(*costs)
iterations = range(len(costs))

fig, ax_costs = plt.subplots()
ax_costs.plot(iterations, total_costs, label='total')
ax_costs.plot(iterations, deformation_costs, label='deformation')
ax_costs.plot(iterations, attach_costs, label='attachment')
ax_costs.set_title('Half-leaf fit costs')
ax_costs.set_xlabel('iteration')
ax_costs.set_ylabel('cost')
ax_costs.legend()
plt.show()

In [None]:
%matplotlib qt5
from matplotlib.patches import Ellipse
C = model.modules[2].C.detach()
torch.set_printoptions(precision=12)

points = model.modules[1].manifold.gd.detach().view(-1, 2)
out_curve = model.modules[0].manifold.gd.detach().view(-1, 2)
ax = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), '.')
plt.plot(points[:, 0].numpy(), points[:, 1].numpy(), 'o')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
plt.plot(target_curve_fit[:, 0].numpy(), target_curve_fit[:, 1].numpy(), '--')
plt.plot(out_curve[:, 0].numpy(), out_curve[:, 1].numpy(), '--')
for i in range(points.shape[0]):
    plt.plot([target[i, 0], points[i, 0]], [target[i, 1], points[i, 1]], '-')

def_grids_c = model.compute_deformation_grid(torch.tensor([aabb_source.xmin, aabb_source.ymin]),
                                             torch.tensor([aabb_source.width/2, aabb_source.height]),
                                             torch.tensor([8, 16]))
def_grid_c_x, def_grid_c_y = def_grids_c[-1][0], def_grids_c[-1][1]
dm.Utilities.plot_grid(ax, def_grid_c_x.numpy(), def_grid_c_y.numpy(), color='C0')

for i in range(source.shape[0]):
    C_i = C[i, :, 0]
    ell = Ellipse(xy=source[i], width=abs(C_i[0]), height=abs(C_i[1]), angle=0.)
    ax.add_artist(ell)

plt.show()

In [None]:
# Build control points for the whole leaf by reflecting the half-leaf order-1
# control points across x = 0 and duplicating their growth constants and frames.
# NOTE: `C` is whatever the global holds when this cell runs; the plotting cell
# above rebinds it to the fitted constants.
# NOTE(review): R is duplicated unmirrored. That is only a true mirror image
# because all frame angles are 0 here; non-zero angles would presumably need
# reflected frames.
source_mirrored = pts_implicit1.detach().clone()
source_mirrored[:, 0].neg_()
source_c = torch.cat([pts_implicit1.detach().clone(), source_mirrored])
C_c = torch.cat([C.detach(), C.detach()])
R_c = torch.cat([R, R])

# Order-0 implicit module points: the full source outline itself (`_s`) and a
# regular grid of spacing `m_scale`, restricted to the outline's convex hull (`_m`).
pts_implicit0_s = source_curve.detach().clone()
m_scale = 5.4
pts_implicit0_m_x, pts_implicit0_m_y = torch.meshgrid([
    torch.arange(aabb_source.xmin, aabb_source.xmax, step=m_scale),
    torch.arange(aabb_source.ymin, aabb_source.ymax, step=m_scale),
])
pts_implicit0_m = dm.Utilities.grid2vec(pts_implicit0_m_x, pts_implicit0_m_y)
pts_implicit0_m = pts_implicit0_m[
    dm.Utilities.is_inside_shape(source_curve_convex, pts_implicit0_m)
]

In [None]:
# Whole-leaf control points on top of the source outline, on a fresh figure
# (with the qt5 backend pyplot would otherwise reuse the last open figure).
fig, ax_points = plt.subplots()
ax_points.axis('equal')
ax_points.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy(), label='source outline')
ax_points.plot(source_c[:, 0].numpy(), source_c[:, 1].numpy(), '.', label='implicit1 (mirrored)')
ax_points.plot(pts_implicit0_s[:, 0].numpy(), pts_implicit0_s[:, 1].numpy(), 'x', label='implicit0_s')
ax_points.plot(pts_implicit0_m[:, 0].numpy(), pts_implicit0_m[:, 1].numpy(), 'o', label='implicit0_m')
ax_points.legend()
plt.show()

In [None]:
# Modules for the whole-leaf model:
#  - implicit0_s: order-0 implicit module on the outline points,
#  - implicit0_m: order-0 implicit module on the interior grid (scale m_scale),
#  - implicit1_c: the mirrored order-1 module,
#  - global_trans_c: a single order-0 point at the origin with a very wide
#    scale (1000.), presumably acting as a global translation.
# NOTE(review): `.view(-1)` shares storage with pts_implicit0_s / pts_implicit0_m /
# source_c / R_c, so these modules wrap the same tensors as any later rebuild.
implicit0_s = dm.DeformationModules.ImplicitModule0(
    dm.Manifolds.Landmarks(2, pts_implicit0_s.shape[0],
                           gd=pts_implicit0_s.view(-1).requires_grad_()),
    1., 0.001, 5.)
implicit0_m = dm.DeformationModules.ImplicitModule0(
    dm.Manifolds.Landmarks(2, pts_implicit0_m.shape[0],
                           gd=pts_implicit0_m.view(-1).requires_grad_()),
    m_scale, 0.001, 5.)
implicit1_c = dm.DeformationModules.ImplicitModule1(
    dm.Manifolds.Stiefel(2, source_c.shape[0],
                         gd=(source_c.view(-1).requires_grad_(),
                             R_c.view(-1).requires_grad_())),
    C_c, sigma1, nu1, coeff1)
global_trans_c = dm.DeformationModules.ImplicitModule0.build_from_points(
    2, 1, 1000., 0.001, 0.001, gd=torch.zeros(2, requires_grad=True))

In [None]:
# Set up the whole-leaf model and run the fit against the full target outline.
# NOTE(review): this call uses a different signature from the half-leaf model
# above: source given as (points, weights), a per-module list of booleans and
# a single attachment, with `lr`/`log_interval` instead of
# `step_length`/`options` in fit. One of the two is presumably stale
# relative to the installed implicitmodules version; confirm before running.
model_c = dm.Models.ModelCompoundWithPointsRegistration(
    (source_curve, torch.ones(source_curve.shape[0])),
    [implicit1_c, implicit0_s, implicit0_m, global_trans_c],
    [True, True, True, True],
    dm.Attachment.VarifoldAttachement([10., 50.]),
)
costs_c = model_c.fit(
    (target_curve, torch.ones(target_curve.shape[0])),
    max_iter=500,
    l=1.,
    lr=0.05,
    log_interval=1,
)

In [None]:
# Cost history of the whole-leaf fit on a dedicated figure (the qt5 backend
# would otherwise draw into the last open figure).
fig, ax_costs_c = plt.subplots()
ax_costs_c.plot(range(len(costs_c)), costs_c)
ax_costs_c.set_title('Whole-leaf fit costs')
ax_costs_c.set_xlabel('iteration')
ax_costs_c.set_ylabel('cost')
plt.show()

In [None]:
# Some plots
%matplotlib qt5

fit_curve = model_c.modules[0].manifold.gd.detach().view(-1, 2)
imp_points = model_c.modules[1].manifold.gd[0].detach().view(-1, 2)

aabb = dm.Utilities.AABB.build_from_points(torch.cat([target_curve, fit_curve]))
def_grids = model_c.compute_deformation_grid(torch.tensor([aabb_source.xmin, aabb_source.ymin]),
                                             torch.tensor([aabb_source.width, aabb_source.height]),
                                             torch.tensor([16, 16]))

ax_c = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(source_c[:, 0].numpy(), source_c[:, 1].numpy(), '.')
dm.Utilities.plot_grid(ax_c, def_grids[-1][0].numpy(), def_grids[-1][1].numpy(), color='C0')
plt.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy(), '-')
plt.plot(target_curve[:, 0].numpy(), target_curve[:, 1].numpy(), '-')
plt.plot(fit_curve[:, 0].numpy(), fit_curve[:, 1].numpy(), '--')
plt.plot(imp_points[:, 0].numpy(), imp_points[:, 1].numpy(), 'x')

plt.show()

In [None]:
# Build a perturbed copy of the target outline to test the robustness of the
# fit: additive Gaussian noise, then a small translation, then a small rotation.
torch.manual_seed(1337)  # reproducible noise
target_curve_deformed = target_curve + 0.8 * torch.randn_like(target_curve)
# Small translation by (5, 5), broadcast over all points.
target_curve_deformed = target_curve_deformed + torch.tensor([5., 5.])
# Small rotation: each point (a row vector) is multiplied on the right by r,
# done as one matrix product instead of a per-row Python loop.
# NOTE(review): whether this turns by +pi/16 or -pi/16 depends on the
# convention of dm.Utilities.rot2d; confirm if the sign matters.
r = dm.Utilities.rot2d(math.pi / 16.)
target_curve_deformed = target_curve_deformed @ r

aabb_deformed = dm.Utilities.AABB.build_from_points(target_curve_deformed)

In [None]:
# Perturbed target outline on a fresh figure (the qt5 backend would otherwise
# draw into the last open figure).
fig, ax_perturbed = plt.subplots()
ax_perturbed.axis('equal')
ax_perturbed.set_title('Perturbed target outline')
ax_perturbed.plot(target_curve_deformed[:, 0].numpy(), target_curve_deformed[:, 1].numpy())
plt.show()

In [None]:
# Rebuild the whole-leaf modules (same points, constants and hyper-parameters)
# so the perturbed fit gets module objects separate from those in model_c.
# NOTE(review): the gd tensors are `.view(-1)` views of the same storage as
# before (pts_implicit0_s, pts_implicit0_m, source_c, R_c), and C_c is the same
# tensor. If the previous fit mutated any of them in place, this fit starts
# from the mutated values; confirm.
implicit0_s = dm.DeformationModules.ImplicitModule0(
    dm.Manifolds.Landmarks(2, pts_implicit0_s.shape[0],
                           gd=pts_implicit0_s.view(-1).requires_grad_()),
    1., 0.001, 5.)
implicit0_m = dm.DeformationModules.ImplicitModule0(
    dm.Manifolds.Landmarks(2, pts_implicit0_m.shape[0],
                           gd=pts_implicit0_m.view(-1).requires_grad_()),
    m_scale, 0.001, 5.)
implicit1_c = dm.DeformationModules.ImplicitModule1(
    dm.Manifolds.Stiefel(2, source_c.shape[0],
                         gd=(source_c.view(-1).requires_grad_(),
                             R_c.view(-1).requires_grad_())),
    C_c, sigma1, nu1, coeff1)
global_trans_c = dm.DeformationModules.ImplicitModule0.build_from_points(
    2, 1, 1000., 0.001, 0.001, gd=torch.zeros(2, requires_grad=True))

In [None]:
# Set up the whole-leaf model again and fit it to the perturbed target
# (learning rate 0.1 here versus 0.05 for the unperturbed fit).
# NOTE(review): same constructor/fit signature as the unperturbed whole-leaf
# model, which differs from the half-leaf model's; confirm it matches the
# installed implicitmodules version.
model_deformed = dm.Models.ModelCompoundWithPointsRegistration(
    (source_curve, torch.ones(source_curve.shape[0])),
    [implicit1_c, implicit0_s, implicit0_m, global_trans_c],
    [True, True, True, True],
    dm.Attachment.VarifoldAttachement([10., 50.]),
)
costs_deformed = model_deformed.fit(
    (target_curve_deformed, torch.ones(target_curve_deformed.shape[0])),
    max_iter=500,
    l=1.,
    lr=0.1,
    log_interval=1,
)

In [None]:
# Cost history of the perturbed-target fit on a dedicated figure (the qt5
# backend would otherwise draw into the last open figure).
fig, ax_costs_deformed = plt.subplots()
ax_costs_deformed.plot(range(len(costs_deformed)), costs_deformed)
ax_costs_deformed.set_title('Perturbed-target fit costs')
ax_costs_deformed.set_xlabel('iteration')
ax_costs_deformed.set_ylabel('cost')
plt.show()

In [None]:
# Some plots
%matplotlib qt5

fit_curve = model_deformed.modules[0].manifold.gd.detach().view(-1, 2)

aabb = dm.Utilities.AABB.build_from_points(torch.cat([target_curve, fit_curve]))
def_grids_deformed = model_deformed.compute_deformation_grid(torch.tensor([aabb_source.xmin, aabb_source.ymin]),
                                                             torch.tensor([aabb_source.width, aabb_source.height]),
                                                             torch.tensor([32, 32]))

ax_deformed = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
dm.Utilities.plot_grid(ax_deformed, def_grids_deformed[-1][0].numpy(), def_grids_deformed[-1][1].numpy(), color='C0')
plt.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy(), '-')
plt.plot(target_curve_deformed[:, 0].numpy(), target_curve_deformed[:, 1].numpy(), '-')
plt.plot(fit_curve[:, 0].numpy(), fit_curve[:, 1].numpy(), '--')

plt.show()