In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.spatial import ConvexHull

import implicitmodules.torch as dm

# Create every tensor without an explicit dtype as float32.
# `torch.set_default_tensor_type` is deprecated; `set_default_dtype` is the
# supported equivalent for CPU tensors.
torch.set_default_dtype(torch.float32)
# Print tensors with 12 digits after the decimal point.
torch.set_printoptions(precision=12)

In [2]:
# Load the source/target shapes.
# NOTE(security): pickle.load can execute arbitrary code; only open trusted files.
# The context manager guarantees the file handle is closed.
with open("../../data/basipetal.pkl", 'rb') as data_file:
    data = pickle.load(data_file)

# Dx, Dy: offset added to the normalised source coordinates.
# height_source / height_target: height each outline is rescaled to.
Dx = 0.
Dy = 0.
height_source = 38.
height_target = 100.

# '*_d' entries are point sets, '*_c' entries are outline curves; both are
# assumed to be (N, 2) arrays of (x, y) coordinates — TODO confirm against the data file.
source = torch.tensor(data['source_d'], dtype=torch.get_default_dtype())
target = torch.tensor(data['target_d'], dtype=torch.get_default_dtype())
source_curve = torch.tensor(data['source_c'], dtype=torch.get_default_dtype())
target_curve = torch.tensor(data['target_c'], dtype=torch.get_default_dtype())

def normalize_shape(points, curve, height, dx=0., dy=0.):
    """Rescale a point set and its outline curve so the outline has a given height.

    The outline `curve` defines the frame.
    - Both inputs are scaled by ``height / (max(curve_y) - min(curve_y))``.
    - The y axis is flipped: the curve point with the largest original y maps
      to ``y = dy``, and the curve then spans ``[dy, dy + height]``.
    - x is centred on the curve's mean x, then shifted by ``dx``.

    The same transform is applied to `points`. The inputs are not modified.

    Returns ``(points_out, curve_out, y_min, y_max, scale)``:
    - ``y_min``, ``y_max`` are the curve's original y extent.
    - ``scale`` is the scaling factor used.
    """
    y_min, y_max = torch.min(curve[:, 1]), torch.max(curve[:, 1])
    scale = height / (y_max - y_min)
    # Taken from the untouched curve, before any coordinate is rewritten.
    x_mean = torch.mean(curve[:, 0])

    def _apply(pts):
        # Write into a copy so the caller's tensor stays unchanged.
        out = pts.clone()
        out[:, 1] = dy - scale * (pts[:, 1] - y_max)
        out[:, 0] = dx + scale * (pts[:, 0] - x_mean)
        return out

    return _apply(points), _apply(curve), y_min, y_max, scale


# Source is placed at offset (Dx, Dy); target stays at the origin.
source, source_curve, smin, smax, sscale = normalize_shape(
    source, source_curve, height_source, dx=Dx, dy=Dy)
target, target_curve, tmin, tmax, tscale = normalize_shape(
    target, target_curve, height_target)

# Keep only the left half (x <= 0) of each outline.
# Boolean-mask indexing already returns a new tensor, so no explicit clone is needed.
source_curve_fit = source_curve[source_curve[:, 0] <= 0]
target_curve_fit = target_curve[target_curve[:, 0] <= 0]

# Convex hulls of the half and full source outlines.
# For 2-D input, scipy lists hull.vertices in counterclockwise order.
hull_fit = ConvexHull(source_curve_fit)
source_curve_fit_convex = source_curve_fit[hull_fit.vertices]
hull = ConvexHull(source_curve)
source_curve_convex = source_curve[hull.vertices]

# Bounding boxes: of both half outlines together, of the source half outline,
# and of the full source outline.
aabb = dm.Utilities.AABB.build_from_points(torch.cat([target_curve_fit, source_curve_fit]))
aabb_source_fit = dm.Utilities.AABB.build_from_points(source_curve_fit)
aabb_source = dm.Utilities.AABB.build_from_points(source_curve)
# `squared()` returns an AABB. Rebinding keeps the squared box whether the method
# mutates in place or returns a new object, and avoids echoing its repr.
aabb = aabb.squared()

<implicitmodules.torch.Utilities.aabb.AABB at 0x7f60593961d0>

In [3]:
# Candidate control points for the order-1 implicit module: a regular grid with
# spacing `pts_implicit1_step`, covering the source half outline's bounding box.
# x runs from 0 down towards xmin (negative step); y runs from ymin up to ymax.
pts_implicit1_step = 1.5
pts_implicit1_x, pts_implicit1_y = torch.meshgrid([
    torch.arange(0., aabb_source_fit.xmin-pts_implicit1_step, step=-pts_implicit1_step),
    torch.arange(aabb_source_fit.ymin, aabb_source_fit.ymax+pts_implicit1_step, step=pts_implicit1_step)])

# Turn the two grid arrays into a point list (presumably (N, 2) — TODO confirm).
# Then keep only the points selected by `area_shape` for the half outline
# (presumably those lying inside it).
pts_implicit1 = dm.Utilities.grid2vec(pts_implicit1_x, pts_implicit1_y)
pts_implicit1_mask = dm.Utilities.area_shape(pts_implicit1, shape=source_curve_fit)
pts_implicit1 = pts_implicit1[pts_implicit1_mask]

In [4]:
%matplotlib qt5

plt.axis('equal')

plt.subplot(1, 2, 1)
plt.title("Source")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(source_curve_fit[:, 0].numpy(), source_curve_fit[:, 1].numpy(), '-')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), 'x')
plt.plot(pts_implicit1[:, 0].numpy(), pts_implicit1[:, 1].numpy(), '.')
plt.axis('equal')

plt.subplot(1, 2, 2)
plt.title("Target")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(target_curve_fit[:, 0].numpy(), target_curve_fit[:, 1].numpy(), '-')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
plt.axis('equal')

plt.show()

In [5]:
# Hyper-parameters:
# - sigma1: presumably the kernel width of the order-1 implicit module.
# - nu1: presumably its regularisation weight — confirm against ImplicitModule1.
# - coeff_global_trans, coeff1: cost coefficients of the two deformation modules.
sigma1 = 25.
nu1 = 0.001
coeff_global_trans = 1.
coeff1 = 1.
# C tensor: shape (n_control_points, 2, 1), initialised to ones.
# requires_grad=True so gradients can reach it.
C = torch.ones(pts_implicit1.shape[0], 2, 1, requires_grad=True)
# Per-point angles, all zero here; the `0. *` factor makes the angle easy to change.
# rot2d presumably builds the 2x2 rotation matrix for each angle.
th = 0. * math.pi * torch.ones(pts_implicit1.shape[0])
R = torch.stack([dm.Utilities.rot2d(t) for t in th])
# Order-1 implicit module. The first argument 2 is presumably the ambient dimension.
# Its gd are (control point positions, rotations), cloned into new tensors that require grad.
implicit1 = dm.DeformationModules.ImplicitModule1(2, pts_implicit1.shape[0], sigma1, C, nu1, coeff1, gd=(pts_implicit1.clone().requires_grad_(), R.clone().requires_grad_()))
# Global translation module (same dimension argument 2).
global_trans = dm.DeformationModules.GlobalTranslation(2, coeff=coeff_global_trans)

In [6]:
# Set up the registration model (the fitting itself runs in the next cell).
# Source shapes: the half outline and the point set, cloned so the model has its own copies.
# Deformation modules: the implicit module and the global translation.
# Attachment terms, presumably matched to the source shapes by position:
# - VarifoldAttachment for the outline (args presumably dimension, kernel widths [10., 50.], weight 0.1).
# - EuclideanPointwiseDistanceAttachment for the points (presumably weight 50.).
# `implicit1.C` is registered as an extra optimised parameter under the name 'C'.
# `lam` presumably weights the attachment cost against the deformation cost — TODO confirm.
model = dm.Models.ModelPointsRegistration(
    [source_curve_fit.clone(), source.clone()],
    [implicit1, global_trans],
    [dm.Attachment.VarifoldAttachment(2, [10., 50.], 0.1),
     dm.Attachment.EuclideanPointwiseDistanceAttachment(50.)],
    other_parameters=[('C', [implicit1.C])], lam=1000.)

In [7]:
# Fit the model with the SciPy-based fitter.
fitter = dm.Models.ModelFittingScipy(model)

# Targets are listed in the same order as the source shapes of the model.
# 500 is presumably the maximum number of iterations.
# disp=True prints the total, attachment and deformation energies at each iteration.
# `costs` is the sequence of recorded costs, plotted in the next cell.
costs = fitter.fit([target_curve_fit, target], 500, options={}, disp=True)

Initial energy = 32710868.000
Time: 2.167531627928838
Iteration: 1 
Total energy = 24604144.0 
Attach cost = 24604070.0 
Deformation cost = 74.99101257324219
Time: 3.2619594500865787
Iteration: 2 
Total energy = 24438188.0 
Attach cost = 24438112.0 
Deformation cost = 75.67955780029297
Time: 5.426634959876537
Iteration: 3 
Total energy = 24265428.0 
Attach cost = 24265346.0 
Deformation cost = 82.13809967041016
Time: 6.537988371914253
Iteration: 4 
Total energy = 23843444.0 
Attach cost = 23843344.0 
Deformation cost = 100.704833984375
Time: 7.634637898067012
Iteration: 5 
Total energy = 23182686.0 
Attach cost = 23182536.0 
Deformation cost = 150.2554473876953
Time: 9.77721286797896
Iteration: 6 
Total energy = 20299230.0 
Attach cost = 20299060.0 
Deformation cost = 169.6476287841797
Time: 11.96157661289908
Iteration: 7 
Total energy = 19630824.0 
Attach cost = 19630628.0 
Deformation cost = 196.22303771972656
Time: 13.103087462950498
Iteration: 8 
Total energy = 18427930.0 
Attach c

Time: 55.381578692933545
Iteration: 41 
Total energy = 5000908.0 
Attach cost = 4998341.0 
Deformation cost = 2567.2490234375
Time: 56.509433928877115
Iteration: 42 
Total energy = 4862871.5 
Attach cost = 4860115.5 
Deformation cost = 2756.152587890625
Time: 57.64916377305053
Iteration: 43 
Total energy = 4754530.5 
Attach cost = 4751766.0 
Deformation cost = 2764.581787109375
Time: 58.74707826692611
Iteration: 44 
Total energy = 4649104.0 
Attach cost = 4646306.5 
Deformation cost = 2797.302001953125
Time: 59.86328904796392
Iteration: 45 
Total energy = 4510338.0 
Attach cost = 4507436.0 
Deformation cost = 2901.8955078125
Time: 62.11296476610005
Iteration: 46 
Total energy = 4437671.0 
Attach cost = 4434780.5 
Deformation cost = 2890.505126953125
Time: 64.29031814588234
Iteration: 47 
Total energy = 4386716.5 
Attach cost = 4383845.5 
Deformation cost = 2870.796630859375
Time: 65.43983035790734
Iteration: 48 
Total energy = 4283665.5 
Attach cost = 4280809.0 
Deformation cost = 2856

Time: 108.02065915009007
Iteration: 81 
Total energy = 2496909.5 
Attach cost = 2489846.0 
Deformation cost = 7063.59033203125
Time: 109.15374687104486
Iteration: 82 
Total energy = 2487376.0 
Attach cost = 2480325.25 
Deformation cost = 7050.8251953125
Time: 110.24666017806157
Iteration: 83 
Total energy = 2464480.0 
Attach cost = 2457545.75 
Deformation cost = 6934.130859375
Time: 111.40065663005225
Iteration: 84 
Total energy = 2430044.0 
Attach cost = 2423162.25 
Deformation cost = 6881.81689453125
Time: 112.58642869698815
Iteration: 85 
Total energy = 2373423.25 
Attach cost = 2366512.5 
Deformation cost = 6910.744140625
Time: 114.7995175619144
Iteration: 86 
Total energy = 2343715.25 
Attach cost = 2336933.75 
Deformation cost = 6781.44775390625
Time: 115.97918318188749
Iteration: 87 
Total energy = 2318512.5 
Attach cost = 2311804.25 
Deformation cost = 6708.13916015625
Time: 117.12387912604026
Iteration: 88 
Total energy = 2289173.25 
Attach cost = 2282499.75 
Deformation cost 

Time: 171.56888048001565
Iteration: 121 
Total energy = 1356356.0 
Attach cost = 1350832.625 
Deformation cost = 5523.37060546875
Time: 172.82805571402423
Iteration: 122 
Total energy = 1354631.0 
Attach cost = 1349093.875 
Deformation cost = 5537.07763671875
Time: 174.03556135389954
Iteration: 123 
Total energy = 1354383.25 
Attach cost = 1348832.0 
Deformation cost = 5551.279296875
Time: 175.1957543939352
Iteration: 124 
Total energy = 1353712.75 
Attach cost = 1348152.5 
Deformation cost = 5560.21484375
Time: 179.09321070997976
Iteration: 125 
Total energy = 1353692.625 
Attach cost = 1348131.5 
Deformation cost = 5561.15966796875
Time: 180.264241179917
Iteration: 126 
Total energy = 1353342.625 
Attach cost = 1347773.25 
Deformation cost = 5569.33203125
Time: 181.44446055800654
Iteration: 127 
Total energy = 1353194.75 
Attach cost = 1347599.5 
Deformation cost = 5595.27587890625
Time: 182.66164925391786
Iteration: 128 
Total energy = 1352989.875 
Attach cost = 1347400.25 
Deformat

Time: 226.7019805109594
Iteration: 160 
Total energy = 1211851.375 
Attach cost = 1206307.375 
Deformation cost = 5543.978515625
Time: 227.8319917989429
Iteration: 161 
Total energy = 1209969.875 
Attach cost = 1204432.25 
Deformation cost = 5537.68115234375
Time: 228.9744227488991
Iteration: 162 
Total energy = 1207927.25 
Attach cost = 1202369.25 
Deformation cost = 5557.94091796875
Time: 231.16103245108388
Iteration: 163 
Total energy = 1206630.875 
Attach cost = 1201058.25 
Deformation cost = 5572.64794921875
Time: 232.3322250940837
Iteration: 164 
Total energy = 1205831.625 
Attach cost = 1200221.5 
Deformation cost = 5610.13623046875
Time: 233.5084349780809
Iteration: 165 
Total energy = 1205267.875 
Attach cost = 1199689.75 
Deformation cost = 5578.18310546875
Time: 234.64122240105644
Iteration: 166 
Total energy = 1203955.5 
Attach cost = 1198375.75 
Deformation cost = 5579.8046875
Time: 235.8008994620759
Iteration: 167 
Total energy = 1202976.5 
Attach cost = 1197376.625 
Defo

Time: 291.75600091391243
Iteration: 200 
Total energy = 1106338.75 
Attach cost = 1100374.125 
Deformation cost = 5964.68115234375
Time: 295.32606382109225
Iteration: 201 
Total energy = 1106160.25 
Attach cost = 1100196.375 
Deformation cost = 5963.87109375
Time: 297.757865464082
Iteration: 202 
Total energy = 1105027.25 
Attach cost = 1099070.375 
Deformation cost = 5956.8759765625
Time: 299.06941599887796
Iteration: 203 
Total energy = 1104620.5 
Attach cost = 1098660.25 
Deformation cost = 5960.30078125
Time: 300.3116829050705
Iteration: 204 
Total energy = 1104584.125 
Attach cost = 1098613.375 
Deformation cost = 5970.69140625
Time: 301.4963630719576
Iteration: 205 
Total energy = 1104515.625 
Attach cost = 1098547.125 
Deformation cost = 5968.4580078125
Time: 302.7850733860396
Iteration: 206 
Total energy = 1104476.5 
Attach cost = 1098508.125 
Deformation cost = 5968.31640625
Time: 303.97873840504326
Iteration: 207 
Total energy = 1104456.375 
Attach cost = 1098484.375 
Deforma

Time: 359.08490674290806
Iteration: 240 
Total energy = 1058573.5 
Attach cost = 1052495.5 
Deformation cost = 6077.9697265625
Time: 361.47817961894907
Iteration: 241 
Total energy = 1058343.75 
Attach cost = 1052271.375 
Deformation cost = 6072.33544921875
Time: 363.9557052659802
Iteration: 242 
Total energy = 1057808.875 
Attach cost = 1051740.625 
Deformation cost = 6068.23193359375
Time: 366.7013792230282
Iteration: 243 
Total energy = 1057526.5 
Attach cost = 1051460.875 
Deformation cost = 6065.59228515625
Time: 369.22889096988365
Iteration: 244 
Total energy = 1057305.5 
Attach cost = 1051239.375 
Deformation cost = 6066.1591796875
Time: 371.88846493000165
Iteration: 245 
Total energy = 1057209.125 
Attach cost = 1051140.375 
Deformation cost = 6068.7880859375
Time: 373.1673039998859
Iteration: 246 
Total energy = 1057182.5 
Attach cost = 1051107.375 
Deformation cost = 6075.0830078125
Time: 374.5100028880406
Iteration: 247 
Total energy = 1056962.0 
Attach cost = 1050888.625 
D

In [8]:
%matplotlib qt5
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.plot(range(len(costs)), costs)
plt.show()

In [10]:
%matplotlib qt5
C = model.modules[2].C.detach()
R = model.modules[2].manifold.gd[1].detach()

points = model.modules[1].manifold.gd.detach().view(-1, 2)
out_curve = model.modules[0].manifold.gd.detach().view(-1, 2)
ax = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), '.')
plt.plot(points[:, 0].numpy(), points[:, 1].numpy(), 'o')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
plt.plot(target_curve_fit[:, 0].numpy(), target_curve_fit[:, 1].numpy(), '--')
plt.plot(out_curve[:, 0].numpy(), out_curve[:, 1].numpy(), '--')
for i in range(points.shape[0]):
   plt.plot([target[i, 0], points[i, 0]], [target[i, 1], points[i, 1]], '-')

def_grids_c = model.compute_deformation_grid([aabb_source.xmin, aabb_source.ymin],
                                             [aabb_source.width/2, aabb_source.height],
                                             [8, 16])

def_grid_c_x, def_grid_c_y = def_grids_c[0], def_grids_c[1]
dm.Utilities.plot_grid(ax, def_grid_c_x.numpy(), def_grid_c_y.numpy(), color='C0')

# dm.Utilities.plot_C_arrow(ax, source, C, R=R, scale=1., zorder=3, mutation_scale=10)
# dm.Utilities.plot_C_ellipse(ax, pts_implicit1, C, R=R, scale=1., color='black')

plt.xlabel("$x$")
plt.ylabel("$y$")

plt.show()