In [None]:
%load_ext autoreload
%autoreload 2

import sys
import math
import pickle
import random

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pymesh

sys.path.append("../../")

import implicitmodules.torch as dm

In [None]:
# Curves to register. The target is built with different generate_boudin
# arguments and then shifted by +0.15 along the first (x) coordinate.
# NOTE(review): the meaning of generate_boudin's seven positional arguments is
# not visible in this notebook — confirm against implicitmodules.torch.Utilities.
source_boudin = dm.Utilities.generate_boudin(
    0.8, 0.5, 1., 0.3, 40, 20, 10)
target_boudin = dm.Utilities.generate_boudin(
    0.75, 0.6, 0.8, 0.1, 40, 20, 10) + torch.tensor([0.15, 0.])

In [None]:
%matplotlib qt5
# Overlay the source and target curves before registration. The curves are
# labelled so the figure is readable on its own.
plt.plot(source_boudin[:, 0].numpy(), source_boudin[:, 1].numpy(), label='source')
plt.plot(target_boudin[:, 0].numpy(), target_boudin[:, 1].numpy(), label='target')
plt.title('Source and target boudins')
plt.legend()
plt.axis('equal')
plt.show()

In [None]:
# Seed the implicit module. Points are sampled with uniform density over the
# source shape, inside its bounding box enlarged by a factor 1.2 (presumably
# restricted to the interior by area_shape — confirm). Each point gets:
#   - a frame from rot2d(0.) (presumably the identity rotation),
#   - a constant C of shape (N, 2, 1): 1 on the first component, 0 on the second,
#   - one learnable angle, later turned into frames by the `precompute` callback.
aabb = dm.Utilities.AABB.build_from_points(source_boudin).scale(1.2)

implicit1_points = dm.Utilities.fill_area_uniform_density(
    dm.Utilities.area_shape, aabb, 50., shape=source_boudin)
implicit1_rot = dm.Utilities.rot2d(0.).repeat(len(implicit1_points), 1, 1)

implicit1_c = torch.tensor([[1.], [0.]]).repeat(len(implicit1_points), 1, 1)

angles = torch.zeros(len(implicit1_points), requires_grad=True)

print(implicit1_points.shape)

In [None]:
%matplotlib qt5
# Visual check that the implicit module's points cover the source shape.
# Transposing an (N, 2) tensor yields the x row and the y row for unpacking.
plt.plot(*source_boudin.T.numpy())
plt.plot(*implicit1_points.T.numpy(), '.')
plt.axis('equal')
plt.show()

In [None]:
# Deformation modules: a global translation and an ImplicitModule1 carried by
# the points and frames built above. The leading `2` passed to both
# constructors is presumably the ambient dimension (2D) — confirm.
sigma = 0.3  # scale parameter handed to ImplicitModule1

global_translation = dm.DeformationModules.GlobalTranslation(2)

implicit1 = dm.DeformationModules.ImplicitModule1(
    2,
    len(implicit1_points),
    sigma,
    implicit1_c,
    nu=0.1,
    gd=(implicit1_points, implicit1_rot),
)

In [None]:
# Wrap the raw curves as deformable point sets for the registration model.
source_deformable, target_deformable = (
    dm.Models.DeformablePoints(source_boudin),
    dm.Models.DeformablePoints(target_boudin),
)

def precompute(init_manifold, modules, parameters):
    """Rebuild the implicit module's frames from the optimised angles.

    Passed to ``RegistrationModel`` as ``precompute_callback``; presumably
    invoked before each shooting. It converts ``parameters['angles']['params'][0]``
    into rotation matrices with ``rot2d_vec``. It then stores them as the second
    component of ``init_manifold[2].gd`` and keeps the first component (the
    points) unchanged.

    ``modules`` is not used. Index 2 assumes the manifold order
    [source, global translation, implicit module] implied by how the model is
    built — NOTE(review): confirm against RegistrationModel.
    """
    angle_values = parameters['angles']['params'][0]
    frames = dm.Utilities.rot2d_vec(angle_values)

    implicit_manifold = init_manifold[2]
    implicit_manifold.gd = (implicit_manifold.gd[0], frames)


# Varifold data attachment with two kernel scales, and the registration model.
# `angles` is exposed through `other_parameters` and consumed by `precompute`.
# `lam` presumably weights the attachment term — confirm.
sigmas_varifold = [0.1, 0.6]
attachment = dm.Attachment.VarifoldAttachment(2, sigmas_varifold)

model = dm.Models.RegistrationModel(
    [source_deformable],
    [global_translation, implicit1],
    [attachment],
    lam=1000.,
    precompute_callback=precompute,
    other_parameters={'angles': {'params': [angles]}},
)

In [None]:
# Fit with the torch L-BFGS optimizer and a strong-Wolfe line search. Shooting
# uses the 'euler' solver with 10 steps. The second argument to `fit` (50) is
# presumably the iteration budget. `costs` is handed to `fit`, presumably to be
# filled with cost values.
shoot_solver = 'euler'
shoot_it = 10

costs = {}
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

fitter.fit(
    target_deformable,
    50,
    costs=costs,
    options=dict(shoot_solver=shoot_solver,
                 shoot_it=shoot_it,
                 line_search_fn='strong_wolfe'),
)

In [None]:
# Shoot the fitted model with the same solver settings used for fitting.
# `intermediates` is passed so compute_deformed can record intermediate states.
# `[0][0]` presumably selects the deformed source points — confirm.
intermediates = {}
deformed = (
    model
    .compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)[0][0]
    .detach()
)

In [None]:
# Frames of the implicit module after fitting: the second component of
# init_manifold[2].gd, the slot `precompute` fills from the angles.
# The tensor is detached so it can be plotted.
basis = (
    model
    .init_manifold[2]
    .gd[1]
    .detach()
)

In [None]:
def plot_2d_basis(points, basis, **kwords):
    """Draw the two column vectors of each per-point 2x2 matrix as headless segments.

    Parameters
    ----------
    points : tensor, presumably of shape (N, 2)
        Anchor positions; columns 0 and 1 are used as x and y.
    basis : tensor, presumably of shape (N, 2, 2)
        Per-point matrices. Column 0 (``basis[:, :, 0]``) is drawn in blue and
        column 1 (``basis[:, :, 1]``) in red.
    **kwords
        Forwarded to both ``plt.quiver`` calls (e.g. ``scale``). It must not
        contain ``color``, ``headlength`` or ``headwidth``, which are set here.
    """
    xs, ys = points[:, 0].numpy(), points[:, 1].numpy()
    # Zero head length/width turns the quiver arrows into plain segments.
    for column, colour in ((0, 'blue'), (1, 'red')):
        plt.quiver(xs, ys,
                   basis[:, 0, column].numpy(), basis[:, 1, column].numpy(),
                   color=colour, headlength=0., headwidth=0., **kwords)


In [None]:
%matplotlib qt5
# Compare the registration result with the target. The three curves are
# labelled so the figure stands alone: source (thin black), target (green),
# and the deformed source after fitting (blue).
plt.plot(source_boudin[:, 0].numpy(), source_boudin[:, 1].numpy(), lw=0.4, color='black', label='source')
plt.plot(target_boudin[:, 0].numpy(), target_boudin[:, 1].numpy(), lw=1., color='green', label='target')
plt.plot(deformed[:, 0].numpy(), deformed[:, 1].numpy(), lw=1., color='blue', label='deformed source')
plt.title('Registration result')
plt.legend()
plt.axis('equal')
plt.show()

In [None]:
%matplotlib qt5
# Fitted frames of the implicit module, drawn at the seeded points over the
# source curve (first column in blue, second in red; see plot_2d_basis).
plt.plot(*source_boudin.T.numpy(), lw=1., color='black')
plot_2d_basis(implicit1_points, basis, scale=30.)
plt.axis('equal')
plt.show()
