In [1]:
%load_ext autoreload
%autoreload 2

import sys
import math
import pickle
import random

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pygmsh
import meshio
import pymesh

sys.path.append("../../")

import implicitmodules.torch as dm

dm.Utilities.set_compute_backend('keops')
device = 'cuda:2'

In [2]:
data_folder = "/users/home/lacroixle/data/ferret/ferret/"
source_mesh = meshio.read(data_folder+"F06_P4.ply")
target_mesh = meshio.read(data_folder+"F10_P8.ply")

source_points = torch.tensor(source_mesh.points, dtype=torch.get_default_dtype())
target_points = torch.tensor(target_mesh.points, dtype=torch.get_default_dtype())
source_triangles = torch.tensor(source_mesh.cells_dict['triangle'], dtype=torch.long)
target_triangles = torch.tensor(target_mesh.cells_dict['triangle'], dtype=torch.long)

In [3]:
source_points = source_points - torch.mean(source_points, dim=0)

In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d', proj_type='ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color=(0., 1., 0., 0.2))
dm.Utilities.set_aspect_equal_3d(ax)
ax.plot_trisurf(target_points[:, 0].numpy(), target_points[:, 1].numpy(), target_points[:, 2].numpy(), triangles=target_triangles, color=(1., 0., 0., 1.))
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [4]:
area = lambda x, shape_ext, shape_int: dm.Utilities.area_convex_shape(x, shape=shape_ext) & ~dm.Utilities.area_convex_shape(x, shape=shape_int)

In [5]:
aabb_source = dm.Utilities.AABB.build_from_points(source_points)
implicit1_density = 2.5

implicit1_points = dm.Utilities.fill_area_uniform_density(area, aabb_source, implicit1_density, shape_ext=1.1*source_points, shape_int=0.75*source_points)
implicit1_growth = area(implicit1_points, 1.1*source_points, 0.9*source_points)
implicit1_rigid = area(implicit1_points, 0.9*source_points, 0.75*source_points)
implicit1_growth_count = implicit1_points[implicit1_growth].shape[0]
implicit1_rigid_count = implicit1_points[implicit1_rigid].shape[0]

implicit1_r = torch.empty(implicit1_points.shape[0], 3, 3)
implicit1_c = torch.empty(implicit1_points.shape[0], 3, 1)

growth_constants = torch.ones(implicit1_growth_count, 2, 1, requires_grad=True, device=device)
angles = torch.zeros(implicit1_growth_count, 3, requires_grad=True, device=device)

print(implicit1_points.shape[0])
print(implicit1_growth.shape)
print(implicit1_rigid.shape)

2146
torch.Size([2146])
torch.Size([2146])


In [None]:
%matplotlib qt5
ax = plt.subplot(projection='3d', proj_type='ortho')
ax.plot_trisurf(source_points[:, 0].numpy(), source_points[:, 1].numpy(), source_points[:, 2].numpy(), triangles=source_triangles, color=(0., 1., 0., 0.3))
plt.plot(implicit1_points[:, 0], implicit1_points[:, 1], implicit1_points[:, 2], '.')
dm.Utilities.set_aspect_equal_3d(ax)
plt.show()

In [6]:
sigma = 1.5/implicit1_density**(1/3)
print(sigma)

global_translation = dm.DeformationModules.GlobalTranslation(3)

implicit1 = dm.DeformationModules.ImplicitModule1(3, implicit1_points.shape[0], sigma, implicit1_c, nu=10., gd=(implicit1_points, implicit1_r))

1.105209449592116


In [7]:
def compute_basis(angles):
    rot_x = dm.Utilities.rot3d_x_vec(angles[:, 0])
    rot_y = dm.Utilities.rot3d_y_vec(angles[:, 1])
    rot_z = dm.Utilities.rot3d_z_vec(angles[:, 2])
    return torch.einsum('nik, nkl, nlj->nij', rot_z, rot_y, rot_x)

In [8]:
def precompute(init_manifold, modules, parameters):
    init_manifold[2].gd = (init_manifold[2].gd[0], torch.cat([compute_basis(parameters['growth']['params'][0]), torch.eye(3, device=device).repeat(implicit1_rigid_count, 1, 1)], dim=0))

    modules[2].C = torch.cat([torch.cat([parameters['growth']['params'][1], torch.zeros(implicit1_growth_count, 1, 1, device=device)], dim=1), torch.zeros(implicit1_rigid_count, 3, 1, device=device)], dim=0)


In [9]:
deformable_source = dm.Models.DeformableMesh(source_points, source_triangles.to(device))
deformable_target = dm.Models.DeformableMesh(target_points, target_triangles.to(device))

deformable_source.silent_module.to_(device)
deformable_target.silent_module.to_(device)

global_translation.to_(device)
implicit1.to_(device)
implicit1._ImplicitModule1_KeOps__keops_backend = 'GPU'


sigmas_varifold = [0.5, 5.]
attachment = dm.Attachment.VarifoldAttachment(3, sigmas_varifold)
# attachment = dm.Attachment.EuclideanPointwiseDistanceAttachment()

model = dm.Models.RegistrationModel(deformable_source, [global_translation, implicit1], [attachment], lam=10., precompute_callback=precompute, other_parameters={'growth': {'params': [angles, growth_constants]}})

In [None]:
shoot_solver = 'euler'
shoot_it = 10

costs = {}
fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')

fitter.fit(deformable_target, 50, costs=costs, options={'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'})

Starting optimization with method torch LBFGS
Initial cost={'deformation': tensor(0., device='cuda:2'), 'attach': tensor(497311.5625, device='cuda:2')}


Evaluated model with costs=497311.5625


Evaluated model with costs=496969.5448947996


Evaluated model with costs=487868.08416748047


Evaluated model with costs=202743.22521972656


Evaluated model with costs=89646.51530456543


Evaluated model with costs=89352.97537231445


Evaluated model with costs=89230.49766540527


Evaluated model with costs=485671.0199584961


Evaluated model with costs=89239.44175720215


Evaluated model with costs=89218.3910369873


Evaluated model with costs=77851.50700378418


Evaluated model with costs=59138.6805267334


Evaluated model with costs=55845.73417663574


Evaluated model with costs=55831.81832885742


Evaluated model with costs=55825.9743347168


In [11]:
import time

intermediates = {}
start = time.perf_counter()
with torch.autograd.no_grad():
    deformed = model.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)[0][0].detach()
print("Elapsed={elapsed}".format(elapsed=time.perf_counter()-start))

basis = compute_basis(angles.detach())

Elapsed=18.69181970274076


In [13]:
print(growth_constants)
print(angles)

tensor([[[ 0.3762],
         [ 0.7578]],

        [[ 0.8234],
         [ 0.8225]],

        [[ 1.2963],
         [ 1.0161]],

        [[ 0.8264],
         [ 0.4373]],

        [[-0.0755],
         [ 0.3636]],

        [[ 1.0501],
         [ 0.3886]],

        [[ 1.4648],
         [ 0.6037]],

        [[ 1.8374],
         [ 0.8068]],

        [[ 0.6798],
         [ 0.5140]],

        [[-0.3502],
         [ 0.3516]],

        [[ 0.5362],
         [ 0.1632]],

        [[ 1.1702],
         [ 0.2009]],

        [[ 1.9263],
         [ 0.5369]],

        [[ 0.5133],
         [ 0.7300]],

        [[ 0.0321],
         [ 0.4391]],

        [[ 0.1649],
         [ 0.1549]],

        [[ 0.6104],
         [ 0.1956]],

        [[ 1.5387],
         [ 0.4475]],

        [[ 0.5610],
         [ 0.8943]],

        [[ 0.3611],
         [ 0.4776]],

        [[ 0.7860],
         [ 0.0396]],

        [[ 0.8698],
         [ 0.1641]],

        [[ 1.3597],
         [ 0.4572]],

        [[ 0.6552],
         [ 0.8

In [12]:
meshio.write_points_cells("results_brain_trajectory/deformed_brain.ply", deformed.cpu(), [('triangle', source_triangles)])

  "PLY doesn't support 64-bit integers. Casting down to 32-bit."
