In [1]:
from plyfile import PlyData, PlyElement

In [2]:
import torch
import numpy as np

In [13]:
%load_ext autoreload
%autoreload 2

import sys
import copy
import math
import pickle

sys.path.append("../../")

import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import meshio

import imodal as dm

# Work in single precision throughout. `torch.set_default_tensor_type` is
# deprecated since PyTorch 2.1; `torch.set_default_dtype` is its replacement
# for choosing the default floating-point dtype.
torch.set_default_dtype(torch.float32)

# To run on the KeOps backend / a GPU, uncomment these two lines (and comment
# out the CPU device below).
#dm.Utilities.set_compute_backend('keops')
#device = 'cuda:2'
device = 'cpu'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
from pathlib import Path

# NOTE(review): machine-specific location of the ferret meshes — point DATA_DIR
# at your own copy of the data instead of editing the file names below.
DATA_DIR = Path('/home/gris/Data/ferret/ferret')

plydata_source = PlyData.read(DATA_DIR / 'F07_P4_dec_0_8.ply')
plydata_target = PlyData.read(DATA_DIR / 'F10_P8_dec_0_8.ply')

In [15]:
# Element 0 of each PLY file is the one carrying the 'x'/'y'/'z' vertex fields;
# pull each coordinate column out separately for both meshes.
vertices_source = plydata_source.elements[0]
vertices_target = plydata_target.elements[0]

X_source, Y_source, Z_source = (vertices_source[axis] for axis in ('x', 'y', 'z'))
X_target, Y_target, Z_target = (vertices_target[axis] for axis in ('x', 'y', 'z'))

In [16]:
# Element 1 of each PLY file holds the per-face 'vertex_indices' lists.
faces_source, faces_target = (
    ply.elements[1]['vertex_indices'] for ply in (plydata_source, plydata_target)
)

In [17]:
# Convert each mesh's faces in a single vectorised step instead of building
# one tensor per face. np.stack (like the torch.stack it replaces) requires
# every face to have the same number of vertices — presumably 3 (triangles).
source_triangles = torch.from_numpy(np.stack(list(faces_source)).astype(np.int64))
target_triangles = torch.from_numpy(np.stack(list(faces_target)).astype(np.int64))

In [18]:
# Stack the per-axis coordinate arrays column-wise in NumPy (-> shape (n, 3))
# and convert once. Passing a Python list of ndarrays to torch.tensor goes
# through PyTorch's slow element-wise path (it warns about it) and leaves the
# dtype to inference; here it is pinned to the notebook's default dtype.
source_points = torch.tensor(np.stack([X_source, Y_source, Z_source], axis=1),
                             dtype=torch.get_default_dtype())
target_points = torch.tensor(np.stack([X_target, Y_target, Z_target], axis=1),
                             dtype=torch.get_default_dtype())


In [19]:
# Move the source mesh's vertex centroid to the origin. Only the source is
# centred; the target keeps its original position (presumably the
# GlobalTranslation module absorbs the offset — TODO confirm).
source_points = source_points - source_points.mean(dim=0, keepdim=True)

In [20]:
def area(x, shape_ext, shape_int):
    """Select the points of `x` lying in the band between two convex shapes.

    Combines dm.Utilities.area_convex_shape for the outer shape with the
    negation of it for the inner shape (`&` / `~`, so presumably boolean
    masks over the rows of `x` — TODO confirm).
    """
    inside_outer = dm.Utilities.area_convex_shape(x, shape=shape_ext)
    inside_inner = dm.Utilities.area_convex_shape(x, shape=shape_int)
    return inside_outer & ~inside_inner

In [23]:
aabb_source = dm.Utilities.AABB.build_from_points(source_points)  # NOTE(review): not used below
implicit1_density = 2.5  # NOTE(review): not used below

# Growth points sit on the source vertices; rigid points are the same cloud
# scaled by 0.8 toward the origin (which is the source centroid, since
# source_points was centred earlier).
implicit1_growth = source_points.clone()
# Alternative (disabled) growth-point selection via the `area` mask:
#implicit1_growth = area(implicit1_points, 1.1*source_points, 0.9*source_points)
implicit1_rigid = 0.8 * implicit1_growth
implicit1_points = torch.cat([implicit1_growth, implicit1_rigid])

implicit1_growth_count = implicit1_growth.shape[0]
implicit1_rigid_count = implicit1_rigid.shape[0]
implicit1_count = implicit1_points.shape[0]

# Initial frames / growth constants handed to ImplicitModule1. The
# `precompute` callback replaces both from the trainable `angles` and
# `growth_constants`, but start from well-defined values (identity frames,
# zero constants) rather than uninitialised torch.empty memory, which may hold
# NaN/inf garbage.
implicit1_r = torch.eye(3).repeat(implicit1_count, 1, 1)
implicit1_c = torch.zeros(implicit1_count, 3, 1)

# Trainable parameters, one entry per growth point: 2 growth constants and
# 3 rotation angles (consumed by compute_basis).
growth_constants = torch.ones(implicit1_growth_count, 2, 1, requires_grad=True, device=device)
angles = torch.zeros(implicit1_growth_count, 3, requires_grad=True, device=device)

print(implicit1_points.shape[0])
print(implicit1_growth.shape)
print(implicit1_rigid.shape)

1250
torch.Size([625, 3])
torch.Size([625, 3])


In [24]:
def set_aspect_equal_3d(ax):
    """Give a 3D axes equal scaling along x, y and z.

    Each axis is re-centred on the midpoint of its current limits and given
    the same half-width: the largest half-range found among the three axes.
    Works around 3D axes not honouring an equal aspect ratio on their own.
    """
    limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    centers = [np.mean(lim) for lim in limits]

    radius = max(abs(bound - center)
                 for lim, center in zip(limits, centers)
                 for bound in lim)

    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    for set_lim, center in zip(setters, centers):
        set_lim([center - radius, center + radius])

In [25]:
%matplotlib qt5
fig = plt.figure()
step = 1  # plot every `step`-th point; increase to thin out dense point clouds
ax = fig.add_subplot(111, projection='3d')
ax.scatter(source_points[::step, 0].numpy(), source_points[::step, 1].numpy(), source_points[::step, 2].numpy(),
           color='blue', label='source vertices')
ax.scatter(implicit1_rigid[::step, 0].numpy(), implicit1_rigid[::step, 1].numpy(), implicit1_rigid[::step, 2].numpy(),
           color='red', label='implicit1 rigid points')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('Source mesh and implicit1 rigid points')
ax.legend()

# Equalise the axes once, after everything is drawn: setting 3D limits
# explicitly switches off autoscaling, so doing it between scatters would stop
# later point sets from being taken into account.
set_aspect_equal_3d(ax)
plt.show()

In [26]:
def compute_basis(angles):
    """Build one 3x3 rotation frame per row of `angles`.

    Column k of `angles` (k = 0, 1, 2) is passed to
    dm.Utilities.rot3d_x_vec / rot3d_y_vec / rot3d_z_vec respectively, and
    the three batched matrices are composed per row n as R_z @ R_y @ R_x.

    angles: 2-D tensor with (at least) 3 columns, one row per frame.
    Returns: batched product with subscripts 'nij' — presumably (n, 3, 3);
    TODO confirm the rot3d_*_vec output shape.
    """
    utilities = dm.Utilities
    per_axis = (utilities.rot3d_x_vec, utilities.rot3d_y_vec, utilities.rot3d_z_vec)
    rot_x, rot_y, rot_z = [make_rot(angles[:, axis]) for axis, make_rot in enumerate(per_axis)]
    # n is the batch index; k and l are contracted to chain the three products.
    return torch.einsum('nik, nkl, nlj->nij', rot_z, rot_y, rot_x)

In [27]:
def precompute(init_manifold, modules, parameters):
    """Push the trainable growth parameters into the implicit module.

    Index 2 of `init_manifold` / `modules` is the third entry of the model's
    manifold/module lists — presumably implicit1, after the silent module and
    the global translation (TODO confirm the ordering inside the model).

    - Frames: compute_basis(angles) for the growth points, identity for the
      rigid points; the module's points (gd[0]) are kept as they are.
    - Growth constants: the 2 trainable components plus a zero third one for
      each growth point, all zeros for the rigid points.
    """
    growth_params = parameters['growth']['params']

    points = init_manifold[2].gd[0]
    growth_frames = compute_basis(growth_params[0])
    rigid_frames = torch.eye(3, device=device).repeat(implicit1_rigid_count, 1, 1)
    init_manifold[2].gd = (points, torch.cat([growth_frames, rigid_frames], dim=0))

    growth_c = torch.cat([growth_params[1],
                          torch.zeros(implicit1_growth_count, 1, 1, device=device)], dim=1)
    rigid_c = torch.zeros(implicit1_rigid_count, 3, 1, device=device)
    modules[2].C = torch.cat([growth_c, rigid_c], dim=0)


In [30]:
# Kernel scale shared by the implicit module, and its `nu` hyper-parameter.
sigma = 1.
implicit1_nu = 10.

# Global rigid translation in 3D.
global_translation = dm.DeformationModules.GlobalTranslation(3)

# Implicit module of order 1 carried by the growth + rigid points, starting
# from the initial frames implicit1_r and constants implicit1_c.
implicit1 = dm.DeformationModules.ImplicitModule1(
    3, implicit1_points.shape[0], sigma, implicit1_c,
    nu=implicit1_nu, gd=(implicit1_points, implicit1_r))

In [31]:
# Wrap both meshes as deformable objects (triangles moved to the target device).
deformable_source = dm.Models.DeformableMesh(source_points, source_triangles.to(device))
deformable_target = dm.Models.DeformableMesh(target_points, target_triangles.to(device))

for deformable in (deformable_source, deformable_target):
    deformable.silent_module.to_(device)

for deform_module in (global_translation, implicit1):
    deform_module.to_(device)
#implicit1._ImplicitModule1_KeOps__keops_backend = 'GPU'

# Multi-scale varifold data term (the pointwise Euclidean one is kept as an
# alternative).
sigmas_varifold = [0.5, 5.]
attachment = dm.Attachment.VarifoldAttachment(3, sigmas_varifold)
# attachment = dm.Attachment.EuclideanPointwiseDistanceAttachment()

# The rotation angles and growth constants are optimised alongside the model
# and injected into implicit1 by `precompute`.
growth_parameters = {'growth': {'params': [angles, growth_constants]}}
model = dm.Models.RegistrationModel(
    deformable_source, [global_translation, implicit1], [attachment],
    lam=10., precompute_callback=precompute, other_parameters=growth_parameters)

In [32]:
# Integration scheme and number of steps used when shooting the deformation.
shoot_solver, shoot_it = 'euler', 10

In [None]:

# Fit with L-BFGS (strong-Wolfe line search) for 50 iterations; the fitter
# records cost values into `costs`.
costs = {}
fit_options = {'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'}

fitter = dm.Models.Fitter(model, optimizer='torch_lbfgs')
fitter.fit(deformable_target, 50, costs=costs, options=fit_options)

In [33]:
# Evaluate the fitted model against the target mesh, using the same shooting
# solver and step count as the fit.
model.evaluate(deformable_target, shoot_solver, shoot_it)

KeyboardInterrupt: 