In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded before each cell execution (useful while editing the project package).
%load_ext autoreload
%autoreload 2

import math
import pickle
import random
import sys
import time

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pygmsh
import meshio
import pymesh

# Put the directory two levels above the working directory on the import path
# so that the `implicitmodules` package below can be imported from the checkout
# (presumably the repository root -- confirm). The path is relative, so the
# notebook must be started from its own folder.
sys.path.append("../../")

import implicitmodules.torch as dm

# Select the 'torch' compute backend of implicitmodules.
dm.Utilities.set_compute_backend('torch')

In [2]:
# Read the source and target surface meshes and convert their vertices and
# triangle connectivity into torch tensors.
data_folder = "../../../../data/ferret/simple/"


def _mesh_to_tensors(mesh):
    """Return ``(points, triangles)`` tensors built from a meshio mesh.

    Points use torch's default floating dtype; triangle indices use ``torch.long``.
    """
    vertices = torch.tensor(mesh.points, dtype=torch.get_default_dtype())
    faces = torch.tensor(mesh.cells_dict['triangle'], dtype=torch.long)
    return vertices, faces


source_mesh = meshio.read(data_folder + "F06_P4_simple.stl")
target_mesh = meshio.read(data_folder + "F10_P8_simple.stl")

source_points, source_triangles = _mesh_to_tensors(source_mesh)
target_points, target_triangles = _mesh_to_tensors(target_mesh)

In [3]:
# Translate the source mesh so that its vertex centroid sits at the origin.
# NOTE(review): only the source is centred; the target keeps its original
# coordinates -- confirm that both meshes are meant to share this frame.
source_points = source_points - source_points.mean(dim=0)

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d', proj_type='ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color=(0., 1., 0., 0.2))
dm.Utilities.set_aspect_equal_3d(ax)
ax.plot_trisurf(target_points[:, 0].numpy(), target_points[:, 1].numpy(), target_points[:, 2].numpy(), triangles=target_triangles, color=(1., 0., 0., 1.))
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [4]:
def area(x, shape_ext, shape_int):
    """Select the part of ``x`` lying in ``shape_ext`` but not in ``shape_int``.

    Both shapes are tested with ``dm.Utilities.area_convex_shape`` and the
    results are combined as ``outer & ~inner`` (assumes the helper returns
    boolean membership masks for the points in ``x`` -- confirm).
    """
    inside_outer = dm.Utilities.area_convex_shape(x, shape=shape_ext)
    inside_inner = dm.Utilities.area_convex_shape(x, shape=shape_int)
    return inside_outer & ~inside_inner

In [5]:
# Sample the control points of the implicit module inside the region selected
# by `area`: between the source surface scaled by 1.1 and the same surface
# scaled by 0.75, within the source bounding box.
aabb_source = dm.Utilities.AABB.build_from_points(source_points)
implicit1_density = 0.6

implicit1_points = dm.Utilities.fill_area_uniform_density(area, aabb_source, implicit1_density, shape_ext=1.1*source_points, shape_int=0.75*source_points)
implicit1_count = implicit1_points.shape[0]

# Placeholder local bases (n, 3, 3) and growth constants (n, 3, 1) handed to
# the module constructor. They are initialised to identity matrices and zeros
# rather than torch.empty, which returns uninitialised memory and could leak
# arbitrary values (including NaN/inf) into the module. The `precompute`
# callback assigns the module's basis and constants from `angles` and
# `growth_constants`.
implicit1_r = torch.eye(3).repeat(implicit1_count, 1, 1)
implicit1_c = torch.zeros(implicit1_count, 3, 1)

# Optimised parameters: two growth constants and three rotation angles per
# control point.
growth_constants = torch.ones(implicit1_count, 2, 1, requires_grad=True)
angles = torch.zeros(implicit1_count, 3, requires_grad=True)

print(implicit1_count)

267


In [6]:
%matplotlib qt5
ax = plt.subplot(projection='3d', proj_type='ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color=(0., 1., 0., 0.3))
plt.plot(implicit1_points[:, 0], implicit1_points[:, 1], implicit1_points[:, 2], '.')
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [6]:
# Scale of the implicit module, taken as twice the cube root of the sampling
# density used for the control points.
density_cube_root = implicit1_density ** (1 / 3)
sigma = 2. * density_cube_root
print(sigma)

# Deformation modules in dimension 3: a global translation and the implicit
# module of order 1 built on the sampled points and their local bases.
global_translation = dm.DeformationModules.GlobalTranslation(3)

implicit1 = dm.DeformationModules.ImplicitModule1(
    3,
    implicit1_points.shape[0],
    sigma,
    implicit1_c,
    nu=0.1,
    gd=(implicit1_points, implicit1_r),
)

1.6868653306034984


In [7]:
def compute_basis(angles):
    """Compose per-point rotations about x, y and z into one matrix per point.

    Column 0, 1 and 2 of ``angles`` are passed to the x, y and z rotation
    helpers respectively; the batched product is taken in the order
    z * y * x (assumes each helper returns a batch of 3x3 matrices -- confirm).
    """
    about_x, about_y, about_z = (
        dm.Utilities.rot3d_x_vec(angles[:, 0]),
        dm.Utilities.rot3d_y_vec(angles[:, 1]),
        dm.Utilities.rot3d_z_vec(angles[:, 2]),
    )
    # Batched triple matrix product, contracted over the two inner indices.
    return torch.einsum('nik, nkl, nlj->nij', about_z, about_y, about_x)

In [8]:
def precompute(init_manifold, modules, parameters):
    """Push the current growth parameters into the implicit module.

    Reads ``parameters['growth']['params']`` as ``[angles, growth_constants]``,
    replaces the basis stored in ``init_manifold[2].gd`` with
    ``compute_basis(angles)`` and sets ``modules[2].C`` to the growth constants
    padded with one zero entry along dimension 1.

    Index 2 presumably addresses the implicit module in the model's module
    list -- confirm against how the model orders its modules.
    """
    growth_params = parameters['growth']['params']
    rotation_angles, growth = growth_params[0], growth_params[1]

    control_points = init_manifold[2].gd[0]
    init_manifold[2].gd = (control_points, compute_basis(rotation_angles))

    # Size, dtype and device of the padding follow the growth tensor itself
    # instead of a notebook-level global and torch's defaults.
    padding = growth.new_zeros(growth.shape[0], 1, 1)
    modules[2].C = torch.cat([growth, padding], dim=1)

In [9]:
# Wrap both meshes as deformables and assemble the registration model.
deformable_source = dm.Models.DeformableMesh(source_points, source_triangles)
deformable_target = dm.Models.DeformableMesh(target_points, target_triangles)

# Varifold data attachment in dimension 3 using two kernel scales.
sigmas_varifold = [0.5, 3.]
attachment = dm.Attachment.VarifoldAttachment(3, sigmas_varifold)

# Angles and growth constants are registered as extra parameters so that they
# are handed to the `precompute` callback.
deformation_modules = [global_translation, implicit1]
learnable_parameters = {'growth': {'params': [angles, growth_constants]}}

model = dm.Models.RegistrationModel(
    deformable_source,
    deformation_modules,
    [attachment],
    lam=10.,
    precompute_callback=precompute,
    other_parameters=learnable_parameters,
)

In [10]:
# Fit the model to the target with torch's L-BFGS: 10 Euler shooting steps per
# evaluation, at most 50 iterations, strong Wolfe line search.
# NOTE(review): in the recorded run below the cost never changes and the
# optimiser stops after two evaluations -- check gradients and the relative
# placement of source and target before trusting the result.
shoot_solver = 'euler'
shoot_it = 10

costs = {}
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

fit_options = {
    'shoot_solver': shoot_solver,
    'shoot_it': shoot_it,
    'line_search_fn': 'strong_wolfe',
}
fitter.fit(deformable_target, 50, costs=costs, options=fit_options)

Starting optimization with method torch LBFGS
Initial cost={'deformation': tensor(0.), 'attach': tensor(215398.5625)}


Time: 13.982005583005957
Iteration: 0
Costs
deformation=0.0
attach=215398.5625
Total cost=215398.5625


Time: 27.813811864936724
Iteration: 1
Costs
deformation=0.0
attach=215398.5625
Total cost=215398.5625
Optimisation process exited with message: Convergence achieved.
Final cost=215398.5625
Model evaluation count=2
Time elapsed = 27.814029834931716


In [None]:
# Shoot the fitted model once more to obtain the deformed source vertices,
# timing the call. `time` is imported in the notebook's top import cell.
intermediates = {}
start = time.perf_counter()
with torch.autograd.no_grad():
    # First deformable, first tensor of the model output; detached so it can be
    # used outside autograd.
    deformed = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)[0][0].detach()
print("Elapsed={elapsed}".format(elapsed=time.perf_counter()-start))

# Local bases corresponding to the current angle values.
basis = compute_basis(angles.detach())

In [None]:
# Inspect the growth constants and rotation angles, in that order.
for parameter in (growth_constants, angles):
    print(parameter)

In [None]:
# Save the deformed vertices with the source triangulation as a PLY file in
# the current working directory.
deformed_cells = [('triangle', source_triangles)]
meshio.write_points_cells("deformed_brain.ply", deformed, deformed_cells)