In [None]:
%reset
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted data files.
DATA_PATH = "../../data/basipetal.pkl"
with open(DATA_PATH, 'rb') as data_file:
    data = pickle.load(data_file)

# Offsets applied to the normalised source shape, and the heights the two shapes are scaled to.
Dx = 0.
Dy = 0.
height_source = 38.
height_target = 100.


def _normalize_shape(points, curve, height, dx=0., dy=0.):
    """Rescale a shape (sample points + contour) into a common frame.

    The same affine map is applied to ``points`` and ``curve`` (column 0 is x, column 1 is y):

        y' = dy - scale * (y - y_max)
        x' = dx + scale * (x - x_mean)

    where ``scale = height / (y_max - y_min)`` and ``y_min``, ``y_max``, ``x_mean`` are
    taken from ``curve``. The y axis is flipped: the curve's largest original y maps to
    ``dy`` and its smallest to ``dy + height``. The curve's mean x maps to ``dx``.

    Returns new ``(points, curve)`` tensors; the inputs are not modified.
    Raises ValueError if the curve has zero vertical extent.
    """
    y_min, y_max = torch.min(curve[:, 1]), torch.max(curve[:, 1])
    if y_max == y_min:
        raise ValueError("curve has zero vertical extent; cannot rescale it to a height")
    x_mean = torch.mean(curve[:, 0])
    scale = height / (y_max - y_min)

    def _apply(pts):
        # clone() keeps any extra columns and leaves the caller's tensor untouched
        out = pts.clone()
        out[:, 1] = dy - scale * (pts[:, 1] - y_max)
        out[:, 0] = dx + scale * (pts[:, 0] - x_mean)
        return out

    return _apply(points), _apply(curve)


source, source_curve = _normalize_shape(
    torch.tensor(data['source_d']), torch.tensor(data['source_c']), height_source, Dx, Dy)
# The target uses no offsets (equivalent to dx = dy = 0).
target, target_curve = _normalize_shape(
    torch.tensor(data['target_d']), torch.tensor(data['target_c']), height_target)

# Common bounding box of both contours, made square, used to frame the plots.
aabb = dm.usefulfunctions.AABB.build_from_points(torch.cat([target_curve, source_curve]))
aabb.squared()

In [None]:
%matplotlib qt5

plt.axis('equal')

plt.subplot(1, 2, 1)
plt.axis(aabb.get_list())
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy(), '-')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), 'x')

plt.subplot(1, 2, 2)
plt.axis(aabb.get_list())
plt.title('Target')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(target_curve[:, 0].numpy(), target_curve[:, 1].numpy(), '-')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')

plt.show()

In [None]:
# Parameters forwarded to ImplicitModule1 (presumably kernel scale, regularisation
# and cost coefficient -- confirm against the library).
sigma1 = 30.
nu1 = 0.001
coeff1 = 0.001
# Initial frame angle shared by every source point, in radians.
theta0 = 0. * math.pi

# Per-point constants C of shape (N, 2, 1), initialised to ones; requires_grad so they can be fitted.
C = torch.ones(source.shape[0], 2, 1, requires_grad=True)
th = theta0 * torch.ones(source.shape[0])
R = torch.stack([dm.usefulfunctions.rot2d(t) for t in th])

# NOTE(review): source.view(-1) / R.view(-1) share storage with `source` / `R`; if the
# library updates these descriptors in place, `source` changes too -- confirm.
implicit1 = dm.implicitmodules.ImplicitModule1(
    dm.manifold.Stiefel(2, source.shape[0], gd=(source.view(-1).requires_grad_(), R.view(-1).requires_grad_())),
    C, sigma1, nu1, coeff1)
global_trans = dm.globalmodules.GlobalTranslation(2, coeff=0.001)

In [None]:
# Register the source points onto the target points with the implicit module plus a
# global translation, using a pointwise distance attachment. implicit1.C is handed to
# the model as an extra parameter.
# All-ones per-point values passed alongside the points (presumably weights).
source_weights = torch.ones(source.shape[0])
target_weights = torch.ones(target.shape[0])

model = dm.models.ModelCompoundWithPointsRegistration(
    (source, source_weights),
    [implicit1, global_trans],
    [True, True],
    dm.attachement.PointwiseDistanceAttachement(),
    parameters=[implicit1.C],
)
costs = model.fit((target, target_weights), max_iter=2000, l=1., lr=0.05, log_interval=50)

In [None]:
# Cost values returned by model.fit (assumes one entry per optimiser step -- confirm).
fig, ax = plt.subplots()
ax.plot(range(len(costs)), costs)
ax.set_title('Fitting cost (pointwise model)')
ax.set_xlabel('step')
ax.set_ylabel('cost')
plt.show()

In [None]:
%matplotlib qt5
from matplotlib.patches import Ellipse

C = model.modules[1].C
points = model.modules[1].manifold.gd[0].detach().view(-1, 2)
ax = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), '.')
plt.plot(points[:, 0].numpy(), points[:, 1].numpy(), 'o')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
for i in range(points.shape[0]):
    plt.plot([target[i, 0], points[i, 0]], [target[i, 1], points[i, 1]], '-')

epsix, epsiy = 0.0000001, 0.0000001
for i in range(source.shape[0]):
    C_i = C[i, :, 0]
    ell = Ellipse(xy=source[i], width=C_i[0] + epsix, height=C_i[1] + epsiy, angle=0.)
    ax.add_artist(ell)

plt.show()

In [None]:
# Mirrored copy of the source points: reflect x -> -x. The normalisation put the source
# curve's mean x at Dx, so for Dx == 0 this is a reflection about the shape's axis.
source_mirrored = source.clone()
source_mirrored[:, 0] = -source_mirrored[:, 0]
source_c = torch.cat([source, source_mirrored])

# Read the fitted constants explicitly from the first model instead of the global `C`,
# whose value depends on which earlier cells were executed (hidden state).
C_fit = model.modules[1].C.detach()
C_c = torch.cat([C_fit, C_fit])
# NOTE(review): the mirrored points reuse the original frames R; this equals a true
# reflection only while the initial angle is 0 -- revisit if `th` becomes non-zero.
R_c = torch.cat([R, R])
assert source_c.shape[0] == 2 * source.shape[0]
assert C_c.shape[0] == R_c.shape[0] == source_c.shape[0], (C_c.shape, R_c.shape, source_c.shape)

implicit_c = dm.implicitmodules.ImplicitModule1(dm.manifold.Stiefel(2, source_c.shape[0], gd=(source_c.view(-1).requires_grad_(), R_c.view(-1).requires_grad_())), C_c, sigma1, nu1, coeff1)
# NOTE(review): the first model uses dm.globalmodules.GlobalTranslation, while this one uses a
# single-point ImplicitModule0 at the origin -- confirm the difference is intended.
global_trans_c = dm.implicitmodules.ImplicitModule0.build_from_points(2, 1, 1000., 0.001, 0.001, gd=torch.zeros(2, requires_grad=True))

In [None]:
# Register the source curve onto the target curve with the mirrored implicit module and
# a translation module, using a varifold attachment ([10., 50.] presumably kernel scales).
# All-ones per-point values passed alongside the curves (presumably weights).
source_curve_weights = torch.ones(source_curve.shape[0])
target_curve_weights = torch.ones(target_curve.shape[0])

model_c = dm.models.ModelCompoundWithPointsRegistration(
    (source_curve, source_curve_weights),
    [implicit_c, global_trans_c],
    [True, True],
    dm.attachement.VarifoldAttachement([10., 50.]),
)
costs_c = model_c.fit((target_curve, target_curve_weights), max_iter=100, l=1., lr=0.1, log_interval=1)

In [None]:
# Cost values returned by model_c.fit (assumes one entry per optimiser step -- confirm).
fig, ax = plt.subplots()
ax.plot(range(len(costs_c)), costs_c)
ax.set_title('Fitting cost (mirrored varifold model)')
ax.set_xlabel('step')
ax.set_ylabel('cost')
plt.show()

In [None]:
# Some plots
%matplotlib qt5

fit_curve = model_c.modules[0].manifold.gd.detach().view(-1, 2)
imp_points = model_c.modules[1].manifold.gd[0].detach().view(-1, 2)

aabb = dm.usefulfunctions.AABB.build_from_points(torch.cat([target_curve, fit_curve]))

#plt.axis(aabb.get_list())
plt.axis('equal')
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(source_c[:, 0].numpy(), source_c[:, 1].numpy(), '.')
plt.plot(source_curve[:, 0].numpy(), source_curve[:, 1].numpy(), '-')
#plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'o')
plt.plot(target_curve[:, 0].numpy(), target_curve[:, 1].numpy(), '-')
plt.plot(fit_curve[:, 0].numpy(), fit_curve[:, 1].numpy(), '--')
plt.plot(imp_points[:, 0].numpy(), imp_points[:, 1].numpy(), 'x')

plt.show()
