In [1]:
%load_ext autoreload
%autoreload 2

import sys
import math
import pickle
import random

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pymesh

sys.path.append("../../")

import implicitmodules.torch as dm

In [2]:
seed = 1337
refinement_order = 2
sphere_radius = 0.5
# Standard deviation of the random per-sphere scale factor.
# NOTE: `sigma` is rebound in a later cell to a different quantity (implicit-module scale).
sigma = 2.
sphere_count = 5
spheres = []
spheres_data = []

random.seed(seed)

# Every sphere and the template start from the same icosphere, so build it once
# instead of regenerating an identical mesh on every loop iteration.
template_mesh = pymesh.generate_icosphere(sphere_radius, [0., 0., 0.], refinement_order)

for i in range(sphere_count):
    # Only the first axis is stretched (a random axis, random.randint(0, 2), was tried before).
    axis = 0
    # Drawn from N(0, sigma): the factor can be negative (see the printed spheres_data) or
    # close to zero. Assuming linear_transform applies scale_matrix to each vertex, a negative
    # factor mirrors the sphere along `axis` and a near-zero one flattens it.
    scale = random.gauss(0., sigma)
    scale_matrix = torch.eye(3)
    scale_matrix[axis, axis] = scale

    # Fresh tensors per sphere so no two DeformableMesh objects share storage.
    vertices = torch.tensor(template_mesh.vertices, dtype=torch.get_default_dtype())
    faces = torch.tensor(template_mesh.faces, dtype=torch.long)
    deformed_vertices = dm.Utilities.linear_transform(vertices, scale_matrix)
    spheres.append(dm.Models.DeformableMesh(deformed_vertices, faces))
    spheres_data.append({'axis': axis, 'scale': scale})

template = dm.Models.DeformableMesh(torch.tensor(template_mesh.vertices, dtype=torch.get_default_dtype()), torch.tensor(template_mesh.faces, dtype=torch.long))

# Only the deformed meshes are persisted; spheres_data (axis/scale per sphere) is not saved.
with open("spheres.pickle", 'wb') as f:
    pickle.dump(spheres, f)

print("Vertices count={vertices_count}".format(vertices_count=template_mesh.vertices.shape[0]))

Vertices count=162


In [3]:
# Show the (axis, scale) pair used to deform each generated sphere.
print(str(spheres_data))

[{'axis': 0, 'scale': -1.8235065865953255}, {'axis': 0, 'scale': -1.6645668007019474}, {'axis': 0, 'scale': -1.7666219715494316}, {'axis': 0, 'scale': 1.9824361786675635}, {'axis': 0, 'scale': 1.8849967465189708}]


In [4]:
# The registration target is the last generated sphere.
target = spheres[len(spheres) - 1]

In [5]:
# Scale parameter passed to ImplicitModule1 (rebinds `sigma`, which earlier held the
# standard deviation of the random sphere scales).
sigma = 0.5
# Build the tensor with the same dtype as the meshes: torch.tensor on pymesh's (presumably
# float64) vertex array would otherwise infer float64, unlike the DeformableMesh tensors.
aabb = dm.Utilities.AABB.build_from_points(torch.tensor(template_mesh.vertices, dtype=torch.get_default_dtype()))

# Control points filling the template's bounding box, each carrying an identity 3x3 frame.
implicit1_points = aabb.fill_uniform_density(50)
implicit1_rot = torch.eye(3).repeat(implicit1_points.shape[0], 1, 1)

# Growth constants of shape (N, 3, 1): a single direction with weight 1 on the first axis,
# matching axis 0, the only axis along which the target spheres were scaled.
implicit1_c = torch.zeros(implicit1_points.shape[0], 3, 1)
implicit1_c[:, 0, 0] = 1.

In [6]:
# Rigid translation of the whole shape; 3 is presumably the ambient dimension.
global_translation = dm.DeformationModules.GlobalTranslation(3)

# Order-1 implicit module driven by the control points, frames and growth constants above.
implicit1 = dm.DeformationModules.ImplicitModule1(
    3,
    implicit1_points.shape[0],
    sigma,
    implicit1_c,
    nu=0.1,
    gd=(implicit1_points, implicit1_rot),
)

In [7]:
# Varifold data-attachment term, evaluated at a single kernel scale.
sigmas_varifold = [0.5]
attachment = dm.Attachment.VarifoldAttachment(3, sigmas_varifold)

# Registration of the template using translation + implicit module; `lam` presumably
# weights the attachment term against the deformation cost — confirm in RegistrationModel.
atlas = dm.Models.RegistrationModel(
    [template],
    [global_translation, implicit1],
    [attachment],
    lam=100.,
)

In [8]:
# Shooting settings, also reused when computing the deformed template afterwards.
shoot_solver = 'euler'
shoot_it = 10

fit_options = {
    'shoot_solver': shoot_solver,
    'shoot_it': shoot_it,
    'line_search_fn': 'strong_wolfe',
}

# Passed to fit(); presumably populated with the cost terms of the run.
costs = {}
fitter = dm.Models.Fitter(atlas, optimizer='torch_lbfgs')

# 50 is presumably the iteration budget of the L-BFGS run.
fitter.fit(target, 50, costs=costs, options=fit_options)

Starting optimization with method torch LBFGS
Initial cost={'deformation': tensor(0.), 'attach': tensor(69.9232)}
Time: 56.898127212014515
Iteration: 0
Costs
deformation=0.2017395794391632
attach=0.14252662658691406
Total cost=0.34426620602607727
Time: 73.22745961300097
Iteration: 1
Costs
deformation=0.20173963904380798
attach=0.14252662658691406
Total cost=0.34426626563072205
Optimisation process exited with message: Convergence achieved.
Final cost=0.34426626563072205
Model evaluation count=31
Time elapsed = 73.22765529499156


In [None]:
# Apply the fitted deformation to the template. The registration model built and fitted
# above is bound to `atlas`; `model` is never defined in this notebook.
deformed = atlas.compute_deformed(shoot_solver, shoot_it)

In [None]:
%matplotlib qt5
ax = plt.subplot()
ax.plot_trisurf(

In [2]:
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
# This rebinds `spheres`, replacing the list of DeformableMesh objects generated earlier.
deformed_sphere_path = "../../data/deformed_sphere.pickle"
with open(deformed_sphere_path, 'rb') as pickle_file:
    spheres = pickle.load(pickle_file)

In [3]:
# Unpack the stored entries by position; `triangles` is the triangulation used for every
# surface plotted from these point sets.
template_points = spheres[0]
deformed = spheres[1]
target_points = spheres[2]
triangles = spheres[3]

In [9]:
%matplotlib qt5
ax = plt.subplot(projection='3d')
ax.plot_trisurf(template_points[:, 0].numpy(), template_points[:, 1].numpy(), template_points[:, 2].numpy(), triangles=triangles)
#ax.plot_trisurf(target_points[:, 0].numpy(), target_points[:, 1].numpy(), target_points[:, 2].numpy(), triangles=triangles)
#ax.plot_trisurf(deformed[:, 0].numpy(), deformed[:, 1].numpy(), deformed[:, 2].numpy(), triangles=triangles)
#ax.plot(deformed[:, 0].numpy(), deformed[:, 1].numpy(), deformed[:, 2].numpy(), '.b')
ax.plot_trisurf(target_points[:, 0].numpy(), target_points[:, 1].numpy(), target_points[:, 2].numpy(), triangles=triangles, color= (0,1,0,0.2), edgecolor=(1, 1, 0, 0.2))
ax.plot_trisurf(deformed[:, 0].numpy(), deformed[:, 1].numpy(), deformed[:, 2].numpy(), triangles=triangles, color= (0,1,0,0.2), edgecolor=(0, 1, 0, 0.2))
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()