In [None]:
%load_ext autoreload
%autoreload 2

import sys
import math
import pickle
import random

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pymesh

sys.path.append("../../")

import implicitmodules.torch as dm

In [None]:
seed = 1337
refinement_order = 2
# Standard deviation of the random scale factor applied to the target. Kept under
# its own name so it is not confused with the kernel `sigma` of the implicit module.
scale_sigma = 2.

# Seed every RNG the notebook may use so Restart & Run All reproduces the same target.
random.seed(seed)
torch.manual_seed(seed)

# Icosphere of radius 0.5 centered at the origin, converted to torch tensors.
sphere_mesh = pymesh.generate_icosphere(0.5, [0., 0., 0.], refinement_order)
sphere_points = torch.tensor(sphere_mesh.vertices, dtype=torch.get_default_dtype())
sphere_triangles = torch.tensor(sphere_mesh.faces, dtype=torch.long)

# Target = sphere scaled along one axis, then rotated.
# The axis is fixed to 0; use `random.randint(0, 2)` to draw it at random.
axis = 0
# NOTE(review): random.gauss may return a negative or near-zero scale, which mirrors
# or flattens the target mesh — confirm this is intended.
scale = random.gauss(0., scale_sigma)
scale_matrix = torch.eye(3)
scale_matrix[axis, axis] = scale
rot_matrix = dm.Utilities.rot3d_z(math.pi/3.) @ dm.Utilities.rot3d_x(2*math.pi/3.)
trans_matrix = rot_matrix @ scale_matrix

source_points = sphere_points
source_triangles = sphere_triangles
target_points = dm.Utilities.linear_transform(sphere_points, trans_matrix)
target_triangles = sphere_triangles

source_deformable = dm.Models.DeformableMesh(source_points, source_triangles)
target_deformable = dm.Models.DeformableMesh(target_points, target_triangles)

print("Vertices count={vertices_count}".format(vertices_count=sphere_points.shape[0]))

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles)
ax.plot_trisurf(target_points[:, 0].numpy(), target_points[:, 1].numpy(), target_points[:, 2].numpy(), triangles=target_triangles, color= (0,1,0,0.2), edgecolor=(0, 1, 0, 0.2))
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [None]:
# Passed as `sigma` to ImplicitModule1 (presumably its kernel width — confirm).
sigma = 0.3
points_density = 200
# Keep only points inside this ball: slightly larger than the sphere (radius 0.5)
# so the surface itself is covered.
implicit1_radius = 0.55

# Build the box from `sphere_points`, which already has the default dtype; wrapping
# the raw numpy vertices again would produce a float64 tensor and mix dtypes downstream.
aabb = dm.Utilities.AABB.build_from_points(sphere_points).scale(1.2)

implicit1_points = aabb.fill_uniform_density(points_density)
implicit1_points = implicit1_points[torch.norm(implicit1_points, dim=1) < implicit1_radius]
print(implicit1_points.shape)

# One 3x3 frame per point, initialised to the identity.
implicit1_rot = torch.eye(3).repeat(implicit1_points.shape[0], 1, 1)

# Per-point constants of shape (n, 3, 1) with a single non-zero entry on the first
# frame axis (presumably the module's growth constants — confirm).
implicit1_c = torch.zeros(implicit1_points.shape[0], 3, 1)
implicit1_c[:, 0, 0] = 1.

# Per-point (x, y, z) rotation angles, optimised with the model; `precompute` turns
# them into the frames of the implicit module.
angles = torch.zeros(implicit1_points.shape[0], 3, requires_grad=True)

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d')
ax.set_proj_type('ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color= (1,1,0,0.1))
plt.plot(implicit1_points[:, 0].numpy(), implicit1_points[:, 1].numpy(), implicit1_points[:, 2].numpy(), '.')
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [None]:
# Deformation modules in 3D: a global translation plus the implicit module.
global_translation = dm.DeformationModules.GlobalTranslation(3)

implicit1 = dm.DeformationModules.ImplicitModule1(
    3,
    len(implicit1_points),
    sigma,
    implicit1_c,
    nu=0.1,
    gd=(implicit1_points, implicit1_rot),
)

In [None]:
def precompute(init_manifold, modules, parameters):
    """Rebuild the implicit module's frames from the optimised rotation angles.

    Registered as ``precompute_callback`` of the registration model. Reads the
    ``(n, 3)`` angle tensor from ``parameters['angles']['params'][0]``. It composes
    one rotation per point as ``Rz @ Ry @ Rx`` and stores the result as the second
    component of ``init_manifold[2].gd``; the first component is left unchanged.
    ``modules`` is accepted for the callback signature but not used.
    """
    euler = parameters['angles']['params'][0]

    # Batched per-axis rotations, one matrix per point.
    rot_x = dm.Utilities.rot3d_x_vec(euler[:, 0])
    rot_y = dm.Utilities.rot3d_y_vec(euler[:, 1])
    rot_z = dm.Utilities.rot3d_z_vec(euler[:, 2])

    frames = torch.einsum('nik, nkl, nlj->nij', rot_z, rot_y, rot_x)

    # Index 2 is presumably the implicit module (0: mesh, 1: global translation) —
    # verify against the model's manifold ordering.
    init_manifold[2].gd = (init_manifold[2].gd[0], frames)


# Varifold data attachment in 3D with a single kernel scale.
sigmas_varifold = [0.5]
attachment = dm.Attachment.VarifoldAttachment(3, sigmas_varifold)

# The per-point angles are optimised as extra parameters and mapped to frames by `precompute`.
model = dm.Models.RegistrationModel(
    [source_deformable],
    [global_translation, implicit1],
    [attachment],
    lam=10000.,
    precompute_callback=precompute,
    other_parameters={'angles': {'params': [angles]}},
)

In [None]:
# Shooting integrator and its number of steps, reused below by `compute_deformed`.
shoot_solver = 'euler'
shoot_it = 10

costs = dict()
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

# 50 is presumably the maximum number of optimiser iterations — confirm.
fitter.fit(
    target_deformable,
    50,
    costs=costs,
    options=dict(shoot_solver=shoot_solver, shoot_it=shoot_it, line_search_fn='strong_wolfe'),
)

In [None]:
# Shoot the fitted model; `intermediates` is filled by `compute_deformed`.
intermediates = dict()
deformed = (
    model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)
    [0][0]
    .detach()
)

In [None]:
# Frames of the implicit module after fitting; keep a detached copy for analysis and plotting.
fitted_frames = model.init_manifold[2].gd[1]
print(fitted_frames)
basis = fitted_frames.detach()

In [None]:
# Every fitted frame is a product of rotations, so B @ B^T should be the identity.
# Report the worst deviation instead of leaving the product unused.
orth = torch.einsum('nik, njk->nij', basis, basis)
orth_error = (orth - torch.eye(3, dtype=orth.dtype)).abs().max()
print("Max deviation from orthonormality: {:.3e}".format(orth_error.item()))

In [None]:
# Per-point (x, y, z) rotation angles fed to `precompute` (presumably updated in place by the fitter).
print(angles)

In [None]:
# Controls recorded in `intermediates` by the last `compute_deformed` call.
print(intermediates['controls'])

In [None]:
# Frame of the first implicit point, then its first two axes (columns) separately.
first_frame = basis[0]
print(first_frame)
for column in range(2):
    print(first_frame[:, column])


In [None]:
# Ground-truth stretch direction: the target is `rot_matrix @ scale_matrix` with the
# scale on `axis`, so the sphere is stretched along `rot_matrix[:, axis]`.
alpha_true = rot_matrix[:, axis]
print(alpha_true)

# A stretch direction is only defined up to sign (v and -v give the same strain),
# so measure the distance to the closer of the two.
fitted_dirs = basis[:, :, 0]
n = torch.min(torch.norm(fitted_dirs - alpha_true, dim=1),
              torch.norm(fitted_dirs + alpha_true, dim=1))

# Points lying inside the sphere (radius 0.5 plus a small tolerance). The threshold must
# be tighter than the 0.55 radius used to select `implicit1_points`; otherwise every
# point counts as inside.
inside_radius = 0.51
inside = torch.norm(implicit1_points, dim=1) <= inside_radius
indices = torch.where(inside)[0].tolist()
inside_mask = inside.to(n.dtype)
print(torch.stack([n, inside_mask], dim=1))

In [None]:
def plot_basis(ax, points, basis, column=0, **kwords):
    """Draw one axis of each local frame as an arrow anchored at its point.

    Parameters
    ----------
    ax : matplotlib 3D axes
        Axes that receive the ``quiver`` call.
    points : torch.Tensor, shape (n, 3)
        Arrow origins.
    basis : torch.Tensor, shape (n, 3, 3)
        Per-point frames; column ``column`` of each frame is drawn.
    column : int, optional
        Frame axis to draw (0, 1 or 2). Defaults to 0.
    **kwords
        Forwarded to ``ax.quiver``. ``color`` defaults to ``'blue'`` and may be overridden.
    """
    # setdefault lets callers override the colour without a duplicate-keyword TypeError.
    kwords.setdefault('color', 'blue')
    # detach() so tensors that still track gradients can be converted to numpy.
    origins = points.detach()
    directions = basis[:, :, column].detach()
    ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(), origins[:, 2].numpy(),
              directions[:, 0].numpy(), directions[:, 1].numpy(), directions[:, 2].numpy(),
              **kwords)

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color= (1,1,0,0.1))
#ax.plot_trisurf(deformed[:, 0].numpy(), deformed[:, 1].numpy(), deformed[:, 2].numpy(), triangles=source_triangles, color= (0,1,0,0.1), edgecolor=(0, 1, 0, 0.1))

ax.quiver(implicit1_points[indices, 0].numpy(), implicit1_points[indices, 1].numpy(), implicit1_points[indices, 2].numpy(), basis[indices, 0, 0].numpy(), basis[indices, 1, 0].numpy(), basis[indices, 2, 0].numpy(), length=0.1, color='blue')

ax.quiver(implicit1_points[indices, 0].numpy(), implicit1_points[indices, 1].numpy(), implicit1_points[indices, 2].numpy(), torch.ones(len(indices)).numpy(), torch.zeros(len(indices)).numpy(), torch.zeros(len(indices)).numpy(), length=0.1, color='red')

ax.quiver(implicit1_points[indices, 0].numpy(), implicit1_points[indices, 1].numpy(), implicit1_points[indices, 2].numpy(), alpha_true[0]*torch.ones(len(indices)).numpy(), alpha_true[1]*torch.ones(len(indices)).numpy(), alpha_true[2]*torch.ones(len(indices)).numpy(), length=0.1, color='black')
#ax.quiver(implicit1_points[:, 0].numpy(), implicit1_points[:, 1].numpy(), implicit1_points[:, 2].numpy(), basis[:, 0, 1].numpy(), basis[:, 1, 1].numpy(), basis[:, 2, 1].numpy(), length=0.1, color='red')
#ax.quiver(implicit1_points[:, 0].numpy(), implicit1_points[:, 1].numpy(), implicit1_points[:, 2].numpy(), basis[:, 0, 2].numpy(), basis[:, 1, 2].numpy(), basis[:, 2, 2].numpy(), length=0.1, color='green')
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()