In [None]:
%reset
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
# Load the point-cloud ("dots") and curve datasets. The `with` blocks close the
# file handles, which bare `open(...)` calls would leave open.
# NOTE: pickle.load can execute arbitrary code -- only load trusted data files.
with open("../../data/basipetal_dots.pkl", 'rb') as f:
    data = pickle.load(f)
with open("../../data/basipetal_curve.pkl", 'rb') as f:
    data_curve = pickle.load(f)

# Offsets applied to the normalized source (x, y).
Dx = 0.
Dy = 0.
# Vertical extent of each shape after rescaling.
height_source = 38.
height_target = 100.

# Columns 0 and 1 are treated as x and y below; presumably data['source'] and
# data['target'] are (N, 2) arrays of coordinates -- TODO confirm.
source = torch.tensor(data['source'])
target = torch.tensor(data['target'])

# Normalize the source with a single uniform scale chosen so that its y-extent
# equals height_source.
# - y is flipped: the point with the largest raw y lands at y = Dy, and the one
#   with the smallest raw y lands at y = Dy + height_source.
# - x is centred on its mean, then shifted by Dx.
smin, smax = torch.min(source[:, 1]), torch.max(source[:, 1])
sscale = height_source / (smax - smin)
source[:, 1] = Dy - sscale * (source[:, 1] - smax)
source[:, 0] = Dx + sscale * (source[:, 0] - torch.mean(source[:, 0]))

# Same normalization for the target, without offsets: y ends up in
# [0, height_target] and x is centred on its mean.
tmin, tmax = torch.min(target[:, 1]), torch.max(target[:, 1])
tscale = height_target / (tmax - tmin)
target[:, 1] = - tscale * (target[:, 1] - tmax)
target[:, 0] = tscale * (target[:, 0] - torch.mean(target[:, 0]))

# The normalized source points double as the initial positions of the implicit
# module (same tensor, not a copy).
pos_implicit1 = source

In [None]:
# Some plots
%matplotlib qt5
plt.axis('equal')
#plt.subplot(1, 2, 1)
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')

#plt.subplot(1, 2, 2)
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')

plt.show()

In [None]:
# Hyperparameters handed to ImplicitModule1 below, in the slots after C.
# Presumably they are the kernel width, a regularization weight and a cost
# coefficient -- confirm against the implicitmodules documentation.
sigma1 = 30.
nu1 = 0.001
coeff1 = 0.001

n_implicit1 = pos_implicit1.shape[0]

# Per-point tensor of shape (n_implicit1, 2, 1), initialized to ones.
# requires_grad=True because it is optimized: presumably this is the tensor that
# the model-building cell passes as `parameters=[implicit1.C]`.
C = torch.ones(n_implicit1, 2, 1, requires_grad=True)

# Disabled alternative initialization: a cubic profile in y equal to K at
# y = Dy and 0 at y = Dy + L.
# NOTE(review): C is a leaf that requires grad, so these in-place writes would
# raise unless done under torch.no_grad() or before enabling requires_grad.
# K, L = 10, height_source
# a, b = -2 / L ** 3, 3 / L ** 2
# C[:, 1, 0] = (K * (a * (L - pos_implicit1[:, 1] + Dy) ** 3  + b * (L - pos_implicit1[:, 1] + Dy) ** 2))
# C[:, 0, 0] = 1. * C[:, 1, 0]

# Initial rotation angle of every frame: zero. Change the multiplier of pi to
# start from a different angle.
th = 0. * math.pi * torch.ones(n_implicit1)
R = torch.stack([dm.usefulfunctions.rot2d(angle) for angle in th])

# Geometric descriptors of the implicit module: flattened point positions and
# flattened rotation matrices, both marked as requiring grad.
implicit1_gd = (pos_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_())
implicit1_manifold = dm.manifold.Stiefel(2, n_implicit1, gd=implicit1_gd)
implicit1 = dm.implicitmodules.ImplicitModule1(implicit1_manifold, C, sigma1, nu1, coeff1)

# Global translation module on a single 2-D point at the origin. The very large
# third argument is presumably a kernel width that makes the motion
# near-uniform -- confirm against build_and_fill's signature.
global_trans = dm.implicitmodules.ImplicitModule0.build_and_fill(
    2, 1, 10000., 0.001, 0.001, gd=torch.zeros(2, requires_grad=True))

In [None]:
# Set up the registration model (the fitting loop itself runs in the next cell).
# Each source point is paired with a weight of 1. implicit1.C is handed over via
# `parameters=` so that it is optimized as well.
source_weighted = (pos_implicit1, torch.ones(pos_implicit1.shape[0]))
model = dm.models.ModelCompoundWithPointsRegistration(
    source_weighted,
    [implicit1, global_trans],
    [True, True],
    dm.attachement.PointwiseDistanceAttachement(),
    parameters=[implicit1.C],
)

In [None]:
# Fit the model to the target, with every target point weighted 1. `costs`
# collects the cost values returned by the optimizer. log_interval=1 presumably
# logs every iteration -- raise it to reduce output.
target_weighted = (target, torch.ones(target.shape[0]))
costs = model.fit(
    target_weighted,
    max_iter=1000,
    l=1.,
    lr=0.08,
    log_interval=1,
)

In [None]:
# Convergence curve: the cost values returned by model.fit, in order.
fig, ax = plt.subplots()
ax.plot(costs)  # x defaults to range(len(costs))
ax.set(title='Registration cost during fitting', xlabel='step', ylabel='cost')
plt.show()

In [None]:
%matplotlib qt5
from matplotlib.patches import Ellipse

C = model.modules[1].C
points = model.modules[1].manifold.gd[0].detach().view(-1, 2)
ax = plt.subplot(111, aspect='equal')
plt.axis('equal')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), '.')
plt.plot(points[:, 0].numpy(), points[:, 1].numpy(), 'o')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), 'x')
for i in range(points.shape[0]):
    plt.plot([target[i, 0], points[i, 0]], [target[i, 1], points[i, 1]], '-')

epsix, epsiy = 0.0000001, 0.0000001
for i in range(source.shape[0]):
    C_i = C[i, :, 0]
    ell = Ellipse(xy=source[i], width=C_i[0] + epsix, height=C_i[1] + epsiy, angle=0.)
    ax.add_artist(ell)

plt.show()

In [None]:
# Normalize the curve dataset (loaded into `data_curve` by the data-loading cell)
# like the dots, so it can be drawn in the same frame.
# - Columns 0-1 are treated as (x, y).
# - Only rows whose column 2 equals 2 are kept for the curves below.
# TODO: confirm this data layout.
source_curve = torch.tensor(data_curve['source'])
target_curve = torch.tensor(data_curve['target'])


Dx = 0.
Dy = 0.
height_source_curve = 38.
height_target_curve = 100.

# Source curve: same rescaling and y-flip as the source dots. x is additionally
# shifted by the mean x of the normalized source dots.
smin_curve, smax_curve = torch.min(source_curve[:, 1]), torch.max(source_curve[:, 1])
sscale_curve = height_source_curve / (smax_curve - smin_curve)
source_curve[:, 1] = Dy - sscale_curve * (source_curve[:, 1] - smax_curve)
source_curve[:, 0] = Dx + sscale_curve * (source_curve[:, 0] - torch.mean(source_curve[:, 0])) + torch.mean(source[:, 0])

# Target curve: rescaled and centred like the target dots. Unlike the dots, y is
# NOT negated here (it ends up <= 0); the negation happens when the curve points
# are extracted below.
tmin_curve, tmax_curve = torch.min(target_curve[:, 1]), torch.max(target_curve[:, 1])
tscale_curve = height_target_curve / (tmax_curve - tmin_curve)
target_curve[:, 1] = tscale_curve * (target_curve[:, 1] - tmax_curve)
target_curve[:, 0] = tscale_curve * (target_curve[:, 0] - torch.mean(target_curve[:, 0]))

pos_source_curve = source_curve[source_curve[:, 2] == 2, 0:2]
# NOTE(review): the leading minus negates BOTH columns. It brings y back to the
# non-negative range used for the target dots, but it also mirrors x, which the
# source curve does not get -- confirm the x-mirror is intended.
pos_target_curve = - target_curve[target_curve[:, 2] == 2, 0:2]

# Bounding box of the target curve, presumably enlarged to a square by
# squared(). It is built from pos_target_curve; the name `pos_target` used
# before was never defined.
aabb = dm.usefulfunctions.AABB.build_from_points(pos_target_curve)
aabb.squared()

In [None]:
# NOTE(review): this whole cell is disabled (commented out). It refers to
# `data_source` and `data_target`, which are not defined in the visible cells,
# so uncommenting it as-is would raise NameError on the imshow calls.
# # Some plots
# %matplotlib qt5

# plt.axis('equal')

# plt.subplot(2, 2, 1)
# plt.axis(aabb.get_list())
# plt.title('Source')
# plt.xlabel('x')
# plt.ylabel('y')
# plt.plot(pos_source_curve[:, 0].numpy(), pos_source_curve[:, 1].numpy(), '-')
# plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), 'x')

# plt.subplot(2, 2, 2)
# plt.axis(aabb.get_list())
# plt.title('Target')
# plt.xlabel('x')
# plt.ylabel('y')
# plt.plot(pos_target_curve[:, 0].numpy(), pos_target_curve[:, 1].numpy(), '-')

# plt.subplot(2, 2, 3)
# plt.imshow(data_source[0])

# plt.subplot(2, 2, 4)
# plt.imshow(data_target[0])

# plt.show()