In [None]:
%load_ext autoreload
%autoreload 2

import math
import os
import pickle
import random
import sys

import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pygmsh
import meshio
import pymesh

# Make the in-repository `implicitmodules` package importable without installing it
# (assumes the notebook is executed two directories below the repo root — TODO confirm).
sys.path.append("../../")

import implicitmodules.torch as dm

# Select the 'keops' compute backend for the library's kernel computations
# (presumably requires pykeops to be installed — verify in the environment).
dm.Utilities.set_compute_backend('keops')

In [None]:
# Data location. Set the IMPLICITMODULES_DATA environment variable to point elsewhere
# instead of editing a hard-coded absolute path. Joining with "" guarantees a trailing
# separator, so other cells can keep building paths as `data_folder + filename`.
data_folder = os.path.join(os.environ.get("IMPLICITMODULES_DATA", "/home/leander/data/"), "")


def load_mesh(filename):
    """Read a triangle mesh located in `data_folder`.

    Returns:
        (mesh, points, triangles): the meshio mesh, its vertex coordinates as a tensor
        of the default float dtype, and its triangle vertex indices as a torch.long tensor.
    """
    mesh = meshio.read(os.path.join(data_folder, filename))
    points = torch.tensor(mesh.points, dtype=torch.get_default_dtype())
    # Cast indices straight to int64 (torch.long). An explicit numpy cast is kept so that
    # torch.tensor accepts the array even if the file stores unsigned indices. The former
    # round-trip through int32 only narrowed the values for no benefit.
    triangles = torch.tensor(mesh.cells_dict['triangle'].astype(np.int64), dtype=torch.long)
    return mesh, points, triangles


source_mesh, source_points, source_triangles = load_mesh("armadillo.ply")
target_mesh, target_points, target_triangles = load_mesh("armadillo_deformed.ply")

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d', proj_type='ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color=(0., 1., 0., 0.2))
ax.plot_trisurf(target_points[:, 0].numpy(), target_points[:, 1].numpy(), target_points[:, 2].numpy(), triangles=target_triangles, color=(1., 0., 0., 1.))
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [None]:
# Rigid shift applied to the arm-region mesh before use
# (presumably to line it up with the source arm — TODO confirm the intent of +10 in y).
area_offset = torch.tensor([0., 10., 0.])

# Load the arm region of the deformed armadillo, shift it, and take its convex hull.
area_mesh = meshio.read(data_folder+"armadillo_deformed_arm.ply")
area_mesh_points = torch.tensor(area_mesh.points, dtype=torch.get_default_dtype()) + area_offset
area_points, area_hull = dm.Utilities.extract_convex_hull(area_mesh_points)

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d', proj_type='ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color=(0., 1., 0., 0.2))
ax.plot_trisurf(area_mesh_points.numpy()[:, 0], area_mesh_points.numpy()[:, 1], area_mesh_points.numpy()[:, 2], triangles=area_hull, color=(1., 0., 0., 1.))
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [None]:
implicit1_density = 0.05

# Sample the implicit module's points inside the convex hull of the arm region.
implicit1_points = dm.Utilities.fill_area_uniform_density(dm.Utilities.area_convex_hull, dm.Utilities.AABB.build_from_points(area_mesh_points), implicit1_density, scatter=area_mesh_points.numpy())
implicit1_count = implicit1_points.shape[0]

# Initialise R and C to the state `precompute` produces for zero angles and unit growth:
# identity frames (presumably what compute_basis returns for zero angles) and C = [1, 1, 0]
# per point. torch.empty would leave uninitialised memory, possibly NaN/inf, in tensors that
# are handed to the module constructor before precompute ever runs.
implicit1_r = torch.eye(3).repeat(implicit1_count, 1, 1)
implicit1_c = torch.cat([torch.ones(implicit1_count, 2, 1), torch.zeros(implicit1_count, 1, 1)], dim=1)

# Parameters registered as `other_parameters` of the model and consumed by `precompute`:
# per-point growth constants along two axes and rotation angles about x, y, z.
growth_constants = torch.ones(implicit1_count, 2, 1, requires_grad=True)
angles = torch.zeros(implicit1_count, 3, requires_grad=True)

print(implicit1_points.shape)

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d', proj_type='ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color=(0., 1., 0., 0.2))
ax.plot_trisurf(area_mesh_points.numpy()[:, 0], area_mesh_points.numpy()[:, 1], area_mesh_points.numpy()[:, 2], triangles=area_hull, color=(1., 0., 0., 0.1))
plt.plot(implicit1_points.numpy()[:, 0], implicit1_points.numpy()[:, 1], implicit1_points.numpy()[:, 2], 'o')
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [None]:
# Kernel scale of the implicit module, derived from the sampling density.
# NOTE(review): if `implicit1_density` means points per unit volume, the mean spacing is
# density**(-1/3), whereas this uses density**(+1/3) — confirm the intended exponent.
sigma = 2. * implicit1_density ** (1. / 3.)
print(sigma)

# Rigid translation of the whole shape, in 3D.
global_translation = dm.DeformationModules.GlobalTranslation(3)

# Implicit module of order 1 on the sampled points. R and C are placeholders here;
# `precompute` overwrites them from the optimised angles/growth constants.
implicit1 = dm.DeformationModules.ImplicitModule1(
    3,
    implicit1_points.shape[0],
    sigma,
    implicit1_c,
    nu=2.,
    gd=(implicit1_points, implicit1_r),
)

In [None]:
def compute_basis(angles):
    """Build one 3x3 rotation matrix per row of `angles`.

    Args:
        angles: tensor indexed as angles[:, k]; columns 0, 1, 2 are the angles passed to
            the x-, y- and z-axis rotation helpers respectively.

    Returns:
        Batched tensor whose n-th entry is the matrix product R_z[n] @ R_y[n] @ R_x[n].
    """
    about_x = dm.Utilities.rot3d_x_vec(angles[:, 0])
    about_y = dm.Utilities.rot3d_y_vec(angles[:, 1])
    about_z = dm.Utilities.rot3d_z_vec(angles[:, 2])
    # Batched three-way product, contracted over the shared inner indices.
    return torch.einsum('nik, nkl, nlj->nij', about_z, about_y, about_x)

In [None]:
def precompute(init_manifold, modules, parameters):
    """Write the optimised growth parameters into the implicit module before shooting.

    Args:
        init_manifold: list of initial manifolds; entry 2 is assumed to be the implicit
            module's manifold whose gd is (points, R) — TODO confirm the ordering.
        modules: list of deformation modules; entry 2 is assumed to be `implicit1`.
        parameters: dict whose ['growth']['params'] is [angles, growth_constants], as
            registered through `other_parameters` when building the model.
    """
    growth_params = parameters['growth']['params']
    angles = growth_params[0]
    growth = growth_params[1]

    # Keep the points, replace the frames R by rotations built from the current angles.
    init_manifold[2].gd = (init_manifold[2].gd[0], compute_basis(angles))

    # Growth constants drive the first two axes; the third component is fixed to zero.
    # The zero column is derived from `growth` so its length, dtype and device always match.
    # It no longer depends on the global `implicit1_points` or on default-dtype CPU tensors.
    modules[2].C = torch.cat([growth, torch.zeros_like(growth[:, :1])], dim=1)

In [None]:
# Wrap source and target meshes as deformable objects.
deformable_source = dm.Models.DeformableMesh(source_points, source_triangles)
deformable_target = dm.Models.DeformableMesh(target_points, target_triangles)

# Varifold data-attachment term, here with a single kernel scale.
sigmas_varifold = [15.]
attachment = dm.Attachment.VarifoldAttachment(3, sigmas_varifold)

# Registration model. `precompute` is passed as the precompute callback, and the
# angles / growth constants are passed as extra parameters (presumably optimised with the rest).
model = dm.Models.RegistrationModel(
    deformable_source,
    [global_translation, implicit1],
    [attachment],
    lam=10.,
    precompute_callback=precompute,
    other_parameters={'growth': {'params': [angles, growth_constants]}},
)

In [None]:
# Shooting integration settings, reused by the evaluation cell below.
shoot_solver = 'euler'
shoot_it = 10

fit_options = {
    'shoot_solver': shoot_solver,
    'shoot_it': shoot_it,
    'line_search_fn': 'strong_wolfe',
}

costs = {}
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

# 50 optimiser iterations; cost history is written into `costs`.
fitter.fit(deformable_target, 50, costs=costs, options=fit_options)

In [None]:
import time

# Shoot once with the fitted parameters, without recording an autograd graph, and time it.
intermediates = {}
start = time.perf_counter()
with torch.no_grad():
    shot = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)
    deformed = shot[0][0].detach()
print(f"Elapsed={time.perf_counter() - start}")

# Frames of the implicit module corresponding to the optimised angles.
basis = compute_basis(angles.detach())