In [None]:
%reset
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.spatial import ConvexHull

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
# Load the basipetal leaf data: dots and outline curves for source and target.
# The context manager closes the file handle once the data is read.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("../../data/basipetal.pkl", 'rb') as data_file:
    data = pickle.load(data_file)

# Offset applied to the normalized source shape.
# NOTE(review): the target below is normalized without Dx/Dy; harmless while both are 0.
Dx = 0.
Dy = 0.
# Vertical extent each outline is rescaled to.
height_source = 38.
height_target = 100.

source = torch.tensor(data['source_d'])
target = torch.tensor(data['target_d'])
source_curve = torch.tensor(data['source_c'])
target_curve = torch.tensor(data['target_c'])

# Source: rescale so the curve spans `height_source` vertically, flip y so the curve's
# maximum y maps to Dy, and center x on the curve's mean x. The curve statistics are
# computed before any coordinate is overwritten.
smin, smax = torch.min(source_curve[:, 1]), torch.max(source_curve[:, 1])
smean_x = torch.mean(source_curve[:, 0])
sscale = height_source / (smax - smin)
source[:, 1] = Dy - sscale * (source[:, 1] - smax)
source[:, 0] = Dx + sscale * (source[:, 0] - smean_x)
source_curve[:, 1] = Dy - sscale * (source_curve[:, 1] - smax)
source_curve[:, 0] = Dx + sscale * (source_curve[:, 0] - smean_x)

# Target: same normalization with height `height_target` and no offset.
tmin, tmax = torch.min(target_curve[:, 1]), torch.max(target_curve[:, 1])
tmean_x = torch.mean(target_curve[:, 0])
tscale = height_target / (tmax - tmin)
target[:, 1] = - tscale * (target[:, 1] - tmax)
target[:, 0] = tscale * (target[:, 0] - tmean_x)
target_curve[:, 1] = - tscale * (target_curve[:, 1] - tmax)
target_curve[:, 0] = tscale * (target_curve[:, 0] - tmean_x)

# Keep only the left half (x <= 0) of each outline for the half-leaf fit.
# Boolean-mask indexing already returns a new tensor, so no explicit clone is needed.
source_curve_fit = source_curve[source_curve[:, 0] <= 0]
target_curve_fit = target_curve[target_curve[:, 0] <= 0]

# Convex hulls of the half and full source outlines, later used to keep grid points
# that fall inside the shape.
hull_fit = ConvexHull(source_curve_fit)
source_curve_fit_convex = source_curve_fit[hull_fit.vertices]
hull = ConvexHull(source_curve)
source_curve_convex = source_curve[hull.vertices]

# Bounding boxes: `aabb` covers both half outlines and is then squared (presumably in
# place — confirm) for plotting; the two source boxes drive the module grids.
aabb = dm.usefulfunctions.AABB.build_from_points(torch.cat([target_curve_fit, source_curve_fit]))
aabb_source_fit = dm.usefulfunctions.AABB.build_from_points(source_curve_fit)
aabb_source = dm.usefulfunctions.AABB.build_from_points(source_curve)
aabb.squared()

In [None]:
# Candidate points for the order-1 implicit module: a regular grid whose x runs from 0
# leftwards (negative step) past the left edge of the half-leaf box and whose y covers
# the box height.
pts_implicit1_step = 1.5
grid_xs = torch.arange(0., aabb_source_fit.xmin-pts_implicit1_step, step=-pts_implicit1_step)
grid_ys = torch.arange(aabb_source_fit.ymin, aabb_source_fit.ymax+pts_implicit1_step, step=pts_implicit1_step)
pts_implicit1_x, pts_implicit1_y = torch.meshgrid([grid_xs, grid_ys])

# Flatten the grid into a list of 2-D points and keep those inside the convex hull of
# the half-leaf outline.
all_candidates = dm.usefulfunctions.grid2vec(pts_implicit1_x, pts_implicit1_y)
pts_implicit1_mask = dm.usefulfunctions.is_inside_shape(source_curve_fit_convex, all_candidates)
pts_implicit1 = all_candidates[pts_implicit1_mask]

In [None]:
%matplotlib qt5

plt.axis('equal')

plt.subplot(1, 2, 1)
plt.axis(aabb.get_list())
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(source_curve_fit[:, 0].numpy(), source_curve_fit[:, 1].numpy(), '-')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), 'x')
plt.plot(pts_implicit1[:, 0].numpy(), pts_implicit1[:, 1].numpy(), '.')

plt.subplot(1, 2, 2)
plt.axis(aabb.get_list())
plt.title('Target')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(target_curve_fit[:, 0].numpy(), target_curve_fit[:, 1].numpy(), '-')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')

plt.show()

In [None]:
# Order-1 implicit module on the half-leaf grid points.
# sigma1 / nu1 / coeff1 are passed positionally to ImplicitModule1 (presumably kernel
# scale, regularisation and cost coefficient — confirm against the library).
sigma1 = 25.
nu1 = 0.001
coeff1 = 0.001
n_implicit1 = pts_implicit1.shape[0]

# Per-point constants C, tracked by autograd (optimised by the fit below).
C = torch.ones(n_implicit1, 2, 1, requires_grad=True)

# Every local frame starts at angle 0.
th = 0. * math.pi * torch.ones(n_implicit1)
R = torch.stack([dm.usefulfunctions.rot2d(angle) for angle in th])

stiefel1 = dm.manifold.Stiefel(2, n_implicit1, gd=(pts_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_()))
implicit1 = dm.implicitmodules.ImplicitModule1(stiefel1, C, sigma1, nu1, coeff1)

# Global translation module with cost coefficient 0.001.
global_trans = dm.globalmodules.GlobalTranslation(2, coeff=0.001)

In [None]:
# Set up the half-leaf model and run the fitting loop, optimising implicit1.C.
# Sources, targets and attachments are parallel lists: the outline (varifold) first,
# then the dots (pointwise distance).
# NOTE(review): the source outline is weighted with zeros while the target outline uses
# ones (the later full-leaf fit uses ones for both) — confirm this is intended.
fit_sources = [(source_curve_fit, torch.zeros(source_curve_fit.shape[0])),
               (source, torch.ones(source.shape[0]))]
fit_attachments = [dm.attachement.VarifoldAttachement([10., 50.]),
                   dm.attachement.PointwiseDistanceAttachement()]
model = dm.models.ModelCompoundWithPointsRegistration(
    fit_sources, [implicit1, global_trans], [True, True], fit_attachments,
    parameters=[implicit1.C])

fit_targets = [(target_curve_fit, torch.ones(target_curve_fit.shape[0])),
               (target, torch.ones(target.shape[0]))]
costs = model.fit(fit_targets, max_iter=2000, l=1., lr=0.1, log_interval=10)

In [None]:
# Cost history returned by the half-leaf fit.
cost_steps = range(len(costs))
plt.plot(cost_steps, costs)
plt.show()

In [None]:
%matplotlib qt5
from matplotlib.patches import Ellipse

C = model.modules[2].C.detach()
points = model.modules[1].manifold.gd.detach().view(-1, 2)
ax = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), '.')
plt.plot(points[:, 0].numpy(), points[:, 1].numpy(), 'o')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
for i in range(points.shape[0]):
    plt.plot([target[i, 0], points[i, 0]], [target[i, 1], points[i, 1]], '-')

def_grids_c = model.compute_deformation_grid(torch.tensor([aabb.xmin, aabb.ymin]),
                                                        torch.tensor([aabb.width, aabb.height]),
                                                          torch.tensor([32, 32]))
def_grid_c_x, def_grid_c_y = def_grids_c[-1][0], def_grids[-1][1]
dm.usefulfunctions.plot_grid(ax, def_grid_c_x.numpy(), def_grid_c_y.numpy(), color='C0')

epsix, epsiy = 0.0000001, 0.0000001
for i in range(source.shape[0]):
    C_i = C[i, :, 0]
    ell = Ellipse(xy=source[i], width=C_i[0] + epsix, height=C_i[1] + epsiy, angle=0.)
    ax.add_artist(ell)

plt.show()

In [None]:
# Re-shoot the fitted half-leaf modules from the model's initial manifold with
# shoot_euler (10 is presumably the number of time steps) to get the trajectory and controls.
shoot_compound_c = dm.deformationmodules.CompoundModule(model.modules)
initial_manifold_c = model.init_manifold.copy()
shoot_compound_c.manifold.fill(initial_manifold_c)
hamiltonian_c = dm.hamiltonian.Hamiltonian(shoot_compound_c)
shot_c, controls_c = dm.shooting.shoot_euler(hamiltonian_c, 10)

In [None]:
# Convert the half-leaf shooting trajectory and controls to numpy arrays for saving.
# Per step: gd of modules 0 and 1 as (k, 2) arrays, and module 2's Stiefel gd pair
# (points, frames) each reshaped to (k, 2).
# The loops run over shot_c / controls_c. `shot` and `controls` belong to the full-leaf
# model and are only defined in a later cell.
out_shot_c = [(step[0].gd.detach().view(-1, 2).numpy(),
               step[1].gd.detach().view(-1, 2).numpy(),
               (step[2].gd[0].detach().view(-1, 2).numpy(),
                step[2].gd[1].detach().view(-1, 2).numpy()))
              for step in shot_c]

# One tuple of 4 control arrays per step (one per module of the half-leaf model).
out_controls_c = [tuple(step_controls[j].detach().numpy() for j in range(4))
                  for step_controls in controls_c]

In [None]:
# Symmetrise the order-1 module for the full leaf: mirror the half-leaf points across
# x = 0, and reuse the fitted C and the frames R for the mirrored copies.
# NOTE(review): R is reused unchanged for the mirrored points. This is only a true
# reflection because every initial angle is 0 — revisit if th changes.
halfleaf_pts = pts_implicit1.detach().clone()
source_mirrored = halfleaf_pts.clone()
source_mirrored[:, 0] = -source_mirrored[:, 0]
source_c = torch.cat([halfleaf_pts, source_mirrored])

fitted_C = C.detach()
C_c = torch.cat([fitted_C, fitted_C])
R_c = torch.cat([R, R])

# Fine order-0 module: one point per source outline sample.
pts_implicit0_s = source_curve.detach().clone()

# Coarse order-0 module: grid with spacing m_scale, kept only inside the outline's hull.
m_scale = 5.4
m_xs = torch.arange(aabb_source.xmin, aabb_source.xmax, step=m_scale)
m_ys = torch.arange(aabb_source.ymin, aabb_source.ymax, step=m_scale)
pts_implicit0_m_x, pts_implicit0_m_y = torch.meshgrid([m_xs, m_ys])
m_candidates = dm.usefulfunctions.grid2vec(pts_implicit0_m_x, pts_implicit0_m_y)
m_inside = dm.usefulfunctions.is_inside_shape(source_curve_convex, m_candidates)
pts_implicit0_m = m_candidates[m_inside]

In [None]:
# Layout check on the full source outline: order-1 points (dots), outline points of the
# fine order-0 module (crosses) and the coarse order-0 grid (circles).
plt.axis('equal')
plt.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy())
for layer_pts, layer_fmt in ((source_c, '.'), (pts_implicit0_s, 'x'), (pts_implicit0_m, 'o')):
    plt.plot(layer_pts[:, 0].numpy(), layer_pts[:, 1].numpy(), layer_fmt)
plt.show()

In [None]:
# Modules for the full-leaf fit:
#  - implicit0_s: order-0 module on the outline points (second argument 1.)
#  - implicit0_m: order-0 module on the coarse interior grid (second argument m_scale)
#  - implicit1_c: symmetrised order-1 module using the fitted constants C_c
#  - global_trans_c: a single order-0 point with a very large second argument (1000.),
#    presumably acting as a global translation
# NOTE(review): the second positional argument of ImplicitModule0 is presumably the
# kernel scale — confirm against the library.
landmarks_s = dm.manifold.Landmarks(2, pts_implicit0_s.shape[0], gd=pts_implicit0_s.view(-1).requires_grad_())
implicit0_s = dm.implicitmodules.ImplicitModule0(landmarks_s, 1., 0.001, 5.)

landmarks_m = dm.manifold.Landmarks(2, pts_implicit0_m.shape[0], gd=pts_implicit0_m.view(-1).requires_grad_())
implicit0_m = dm.implicitmodules.ImplicitModule0(landmarks_m, m_scale, 0.001, 5.)

stiefel_c = dm.manifold.Stiefel(2, source_c.shape[0], gd=(source_c.view(-1).requires_grad_(), R_c.view(-1).requires_grad_()))
implicit1_c = dm.implicitmodules.ImplicitModule1(stiefel_c, C_c, sigma1, nu1, coeff1)

global_trans_c = dm.implicitmodules.ImplicitModule0.build_from_points(2, 1, 1000., 0.001, 0.001, gd=torch.zeros(2, requires_grad=True))

In [None]:
# Set up the full-leaf model and run the fitting loop. The whole source outline is
# matched to the whole target outline with a single varifold attachment.
full_modules = [implicit1_c, implicit0_s, implicit0_m, global_trans_c]
model_c = dm.models.ModelCompoundWithPointsRegistration(
    (source_curve, torch.ones(source_curve.shape[0])),
    full_modules,
    [True] * len(full_modules),
    dm.attachement.VarifoldAttachement([10., 50.]))
costs_c = model_c.fit((target_curve, torch.ones(target_curve.shape[0])),
                      max_iter=500, l=1., lr=0.1, log_interval=1)

In [None]:
# Cost history returned by the full-leaf fit.
cost_steps_c = range(len(costs_c))
plt.plot(cost_steps_c, costs_c)
plt.show()

In [None]:
# Some plots
%matplotlib qt5

fit_curve = model_c.modules[0].manifold.gd.detach().view(-1, 2)
imp_points = model_c.modules[1].manifold.gd[0].detach().view(-1, 2)

aabb = dm.usefulfunctions.AABB.build_from_points(torch.cat([target_curve, fit_curve]))
def_grids = model_c.compute_deformation_grid(torch.tensor([aabb.xmin, aabb.ymin]),
                                                        torch.tensor([aabb.width, aabb.height]),
                                                          torch.tensor([32, 32]))

ax_c = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(source_c[:, 0].numpy(), source_c[:, 1].numpy(), '.')
dm.usefulfunctions.plot_grid(ax_c, def_grids[-1][0].numpy(), def_grids[-1][1].numpy(), color='C0')
plt.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy(), '-')
plt.plot(target_curve[:, 0].numpy(), target_curve[:, 1].numpy(), '-')
plt.plot(fit_curve[:, 0].numpy(), fit_curve[:, 1].numpy(), '--')
plt.plot(imp_points[:, 0].numpy(), imp_points[:, 1].numpy(), 'x')

plt.show()

In [None]:
# Re-shoot the fitted full-leaf modules from model_c's initial manifold with shoot_euler
# (10 is presumably the number of time steps).
shoot_compound = dm.deformationmodules.CompoundModule(model_c.modules)
initial_manifold = model_c.init_manifold.copy()
shoot_compound.manifold.fill(initial_manifold)
hamiltonian_full = dm.hamiltonian.Hamiltonian(shoot_compound)
shot, controls = dm.shooting.shoot_euler(hamiltonian_full, 10)

In [None]:
# Convert the full-leaf trajectory and controls to numpy for saving.
# Per step: module 0 gd, module 1's Stiefel gd pair (points, frames), and the gd of
# modules 2 and 3, each reshaped to (k, 2).
out_shot = [(step[0].gd.detach().view(-1, 2).numpy(),
             (step[1].gd[0].detach().view(-1, 2).numpy(),
              step[1].gd[1].detach().view(-1, 2).numpy()),
             step[2].gd.detach().view(-1, 2).numpy(),
             step[3].gd.detach().view(-1, 2).numpy())
            for step in shot]

# One tuple of 5 control arrays per step.
out_controls = [tuple(step_controls[k].detach().numpy() for k in range(5))
                for step_controls in controls]


In [None]:
# We save results here
# NOTE(review): 'fit_c_curve' and 'fit_c_dots' hold fit_curve / imp_points, which the
# full-leaf plotting cell computes from model_c (the fit *using* C), not from the
# C-fitting model — confirm these keys contain what downstream consumers expect.
result_output = {}
# Initial data
result_output['source_dots'] = source.numpy()
result_output['source_curve'] = source_curve.numpy()
result_output['target_dots'] = target.numpy()
result_output['target_curve'] = target_curve.numpy()

# When fitting C
result_output['fit_c'] = C.detach().numpy()
result_output['fit_c_cost'] = costs
result_output['fit_c_curve'] = fit_curve.numpy()
result_output['fit_c_dots'] = imp_points.numpy()
result_output['fit_c_shot'] = out_shot_c
result_output['fit_c_controls'] = out_controls_c
result_output['fit_c_grid'] = def_grids_c

# When using C
result_output['fit_shot'] = out_shot
result_output['fit_controls'] = out_controls
result_output['fit_cost'] = costs_c
result_output['fit_grid'] = def_grids

# The context manager flushes and closes the file deterministically.
with open("basipetal_data.pkl", 'wb') as results_file:
    pickle.dump(result_output, results_file)

In [None]:
# We apply some deformation on the leaf shape to test the fit on a perturbed target:
# additive Gaussian noise (std 0.8), a (5, 5) translation, then a rotation by pi/16.
torch.manual_seed(1337)
target_curve_deformed = target_curve + 0.8*torch.randn_like(target_curve)
# Translation: the (2,) offset broadcasts over every row of the (n, 2) curve.
target_curve_deformed = target_curve_deformed + torch.tensor([5., 5.])
# Rotation: one matrix product right-multiplies every row vector by r. This matches
# multiplying each row separately.
# NOTE(review): the sense of rotation depends on rot2d's convention combined with
# right-multiplication — confirm if it matters.
r = dm.usefulfunctions.rot2d(math.pi/16.)
target_curve_deformed = torch.mm(target_curve_deformed, r)

aabb_deformed = dm.usefulfunctions.AABB.build_from_points(target_curve_deformed)

In [None]:
# Visual check of the perturbed target outline.
plt.axis('equal')
deformed_xy = target_curve_deformed.numpy()
plt.plot(deformed_xy[:, 0], deformed_xy[:, 1])
plt.show()

In [None]:
# Rebuild the full-leaf modules (same construction as for model_c) before fitting the
# perturbed target, so the new model gets its own module objects.
# NOTE(review): the gd tensors are views of the same pts_implicit0_s / pts_implicit0_m /
# source_c / R_c storage used for model_c.
manifold_s = dm.manifold.Landmarks(2, pts_implicit0_s.shape[0], gd=pts_implicit0_s.view(-1).requires_grad_())
implicit0_s = dm.implicitmodules.ImplicitModule0(manifold_s, 1., 0.001, 5.)

manifold_m = dm.manifold.Landmarks(2, pts_implicit0_m.shape[0], gd=pts_implicit0_m.view(-1).requires_grad_())
implicit0_m = dm.implicitmodules.ImplicitModule0(manifold_m, m_scale, 0.001, 5.)

manifold_c = dm.manifold.Stiefel(2, source_c.shape[0], gd=(source_c.view(-1).requires_grad_(), R_c.view(-1).requires_grad_()))
implicit1_c = dm.implicitmodules.ImplicitModule1(manifold_c, C_c, sigma1, nu1, coeff1)

global_trans_c = dm.implicitmodules.ImplicitModule0.build_from_points(2, 1, 1000., 0.001, 0.001, gd=torch.zeros(2, requires_grad=True))

In [None]:
# Set up the model and run the fitting loop against the perturbed target outline,
# using the same full-leaf module set and varifold attachment as model_c.
deformed_modules = [implicit1_c, implicit0_s, implicit0_m, global_trans_c]
model_deformed = dm.models.ModelCompoundWithPointsRegistration(
    (source_curve, torch.ones(source_curve.shape[0])),
    deformed_modules,
    [True] * len(deformed_modules),
    dm.attachement.VarifoldAttachement([10., 50.]))
costs_deformed = model_deformed.fit(
    (target_curve_deformed, torch.ones(target_curve_deformed.shape[0])),
    max_iter=500, l=1., lr=0.1, log_interval=1)

In [None]:
# Some plots
%matplotlib qt5

fit_curve = model_deformed.modules[0].manifold.gd.detach().view(-1, 2)

aabb = dm.usefulfunctions.AABB.build_from_points(torch.cat([target_curve, fit_curve]))
def_grids_deformed = model_deformed.compute_deformation_grid(torch.tensor([aabb_deformed.xmin, aabb_deformed.ymin]),
                                                             torch.tensor([aabb_deformed.width, aabb_deformed.height]),
                                                             torch.tensor([32, 32]))

ax_deformed = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
dm.usefulfunctions.plot_grid(ax_deformed, def_grids_deformed[-1][0].numpy(), def_grids_deformed[-1][1].numpy(), color='C0')
plt.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy(), '-')
plt.plot(target_curve_deformed[:, 0].numpy(), target_curve_deformed[:, 1].numpy(), '-')
plt.plot(fit_curve[:, 0].numpy(), fit_curve[:, 1].numpy(), '--')

plt.show()

In [None]:
# Re-shoot the modules fitted to the perturbed target from model_deformed's initial
# manifold with shoot_euler (10 is presumably the number of time steps).
shoot_deformed = dm.deformationmodules.CompoundModule(model_deformed.modules)
initial_manifold_deformed = model_deformed.init_manifold.copy()
shoot_deformed.manifold.fill(initial_manifold_deformed)
hamiltonian_deformed = dm.hamiltonian.Hamiltonian(shoot_deformed)
shot_deformed, controls_deformed = dm.shooting.shoot_euler(hamiltonian_deformed, 10)

In [None]:
# Convert the perturbed-target trajectory and controls to numpy for saving.
# Per step: module 0 gd, module 1's Stiefel gd pair (points, frames), and the gd of
# modules 2 and 3, each reshaped to (k, 2).
out_shot_deformed = [(step[0].gd.detach().view(-1, 2).numpy(),
                      (step[1].gd[0].detach().view(-1, 2).numpy(),
                       step[1].gd[1].detach().view(-1, 2).numpy()),
                      step[2].gd.detach().view(-1, 2).numpy(),
                      step[3].gd.detach().view(-1, 2).numpy())
                     for step in shot_deformed]

# One tuple of 5 control arrays per step.
out_controls_deformed = [tuple(step_controls[k].detach().numpy() for k in range(5))
                         for step_controls in controls_deformed]


In [None]:
# Results of the fit against the perturbed target.
deformed_output = {}

deformed_output['deformed_target'] = target_curve_deformed.numpy()
deformed_output['deformed_shot'] = out_shot_deformed
deformed_output['deformed_controls'] = out_controls_deformed
deformed_output['deformed_costs'] = costs_deformed
deformed_output['deformed_grid'] = def_grids_deformed

# The context manager flushes and closes the file deterministically.
with open("basipetal_deformed_data.pkl", 'wb') as deformed_file:
    pickle.dump(deformed_output, deformed_file)