In [1]:
from plyfile import PlyData, PlyElement

In [2]:
import torch
import numpy as np

In [4]:
%load_ext autoreload
%autoreload 2

import sys
import copy
import math
import pickle

sys.path.append("../../")

import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import meshio

import imodal as dm

torch.set_default_tensor_type(torch.FloatTensor)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
plydata_source = PlyData.read('/home/gris/Data/ferret/ferret/F07_P4_dec_0_8.ply')
plydata_target = PlyData.read('/home/gris/Data/ferret/ferret/F10_P8_dec_0_8.ply')

In [6]:
X_source = plydata_source.elements[0]['x']
Y_source = plydata_source.elements[0]['y']
Z_source = plydata_source.elements[0]['z']

X_target = plydata_target.elements[0]['x']
Y_target = plydata_target.elements[0]['y']
Z_target = plydata_target.elements[0]['z']

In [7]:
faces_source = plydata_source.elements[1]['vertex_indices']
faces_target = plydata_target.elements[1]['vertex_indices']

In [8]:
source_pts = torch.tensor([X_source, Y_source, Z_source]).t()
target_pts = torch.tensor([X_target, Y_target, Z_target]).t()


In [9]:
source_pts.shape

torch.Size([625, 3])

In [10]:
target_pts.shape

torch.Size([1187, 3])

In [11]:
def set_aspect_equal_3d(ax):
    """Fix equal aspect bug for 3D plots."""

    xlim = ax.get_xlim3d()
    ylim = ax.get_ylim3d()
    zlim = ax.get_zlim3d()

    from numpy import mean
    xmean = mean(xlim)
    ymean = mean(ylim)
    zmean = mean(zlim)

    plot_radius = max([abs(lim - mean_)
                       for lims, mean_ in ((xlim, xmean),
                                           (ylim, ymean),
                                           (zlim, zmean))
                       for lim in lims])

    ax.set_xlim3d([xmean - plot_radius, xmean + plot_radius])
    ax.set_ylim3d([ymean - plot_radius, ymean + plot_radius])
    ax.set_zlim3d([zmean - plot_radius, zmean + plot_radius])

In [12]:
%matplotlib qt5
fig = plt.figure()
step = 1
ax = fig.add_subplot(121, projection='3d')
ax.scatter(source_pts[::step, 0].numpy(), source_pts[::step, 1].numpy(), source_pts[::step, 2].numpy(), color='blue')
set_aspect_equal_3d(ax)

#ax = fig.add_subplot(122, projection='3d')
ax.scatter(target_pts[:, 0].numpy(), target_pts[:, 1].numpy(), target_pts[:, 2].numpy(), color='red')
set_aspect_equal_3d(ax)
plt.show()

In [12]:
from evtk.hl import pointsToVTK


In [13]:
source_pts.dtype

torch.float32

In [74]:
pointsToVTK("/home/gris/Data/ferret/ferret/test", source_pts[:,0].numpy(),source_pts[:,1].numpy(), source_pts[:,2].numpy())

'/home/gris/Data/ferret/ferret/test.vtk.vtu'

In [14]:
xmin = min(source_pts[:,0])
ymin = min(source_pts[:,1])
zmin = min(source_pts[:,2])
xmax = max(source_pts[:,0])
ymax = max(source_pts[:,1])
zmax = max(source_pts[:,2])

In [15]:
AABB = dm.Utilities.AABB(xmin, xmax, ymin, ymax, zmin, zmax)

In [16]:
density = 0.5
grid = AABB.fill_uniform_density(density)

In [17]:
grid.shape

torch.Size([504, 3])

In [18]:
%matplotlib qt5
fig = plt.figure()
step = 1
ax = fig.add_subplot(121, projection='3d')
ax.scatter(source_pts[::step, 0].numpy(), source_pts[::step, 1].numpy(), source_pts[::step, 2].numpy(), color='blue')

stepgrid = 1
ax.scatter(grid[::stepgrid, 0].numpy(), grid[::stepgrid, 1].numpy(), grid[::stepgrid, 2].numpy(), color='red')
set_aspect_equal_3d(ax)

In [19]:
sigma_trans = 2.
coefftrans = 1.
translations = dm.DeformationModules.ImplicitModule0(3, grid.shape[0], sigma_trans, nu=0.1, gd=grid.clone().requires_grad_(), coeff=coefftrans)

In [24]:
points_growth = grid.clone().requires_grad_()
C = torch.zeros(points_growth.shape[0], 3, 1)
C[:, 0, 0] = 1.
C[:, 1, 0] = 1.
rot_growth = torch.stack([dm.Utilities.rot2d(0.)]*points_growth.shape[0], axis=0)


AttributeError: module 'implicitmodules.torch.Utilities' has no attribute 'rot3d'

In [None]:
scale_growth = 2.
growth = dm.DeformationModules.ImplicitModule1(
    2, points_growth.shape[0], scale_growth, C, coeff=coeff_growth, nu=nu,
    gd=(points_growth, rot_growth))


In [20]:
sigmas_varifold = [5.]
attachment = dm.Attachment.VarifoldAttachment(3, sigmas_varifold)

In [21]:
trig_source = []
for i in range(faces_source.shape[0]):
    trig_source.append(torch.tensor(faces_source[i], dtype = torch.long))
trig_source = torch.stack(trig_source)

trig_target = []
for i in range(faces_target.shape[0]):
    trig_target.append(torch.tensor(faces_target[i], dtype = torch.long))
trig_target = torch.stack(trig_target)

In [22]:
source_deformable = dm.Models.DeformableMesh(source_pts,trig_source)
target_deformable = dm.Models.DeformableMesh(target_pts, trig_target)

In [205]:
lam = 50.
modelgrowth = dm.Models.RegistrationModel([source_deformable], [translations], [attachment], fit_gd=[False], lam=lam)
#modelgrowth = dm.Models.RegistrationModel([source_deformable], [global_translation, rotation, growth], [attachment], fit_gd=[False], lam=10.)
#modelgrowth = dm.Models.RegistrationModel([source_deformable], [global_translation, rotation, growth], [attachment], fit_gd=[False], lam=10.,precompute_callback=precompute, other_parameters={'angles': {'params': [angles]}})

In [208]:
shoot_solver = 'euler'
shoot_it = 5

costs = {}
fitter = dm.Models.Fitter(modelgrowth, optimizer='torch_lbfgs')
# fitter = dm.Models.Fitter(model, optimizer='scipy_l-bfgs-b')
fitter.fit([target_deformable], 5, costs=costs, options={'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe'})


Starting optimization with method torch LBFGS
Initial cost={'deformation': tensor(65.7550), 'attach': tensor(2013.8672)}


KeyboardInterrupt: 

In [209]:
intermediates = {}
with torch.autograd.no_grad():
    deformed_source = modelgrowth.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)[0][0]


In [1]:
for i in range(shoot_it + 1):
    pts_t = np.ascontiguousarray(intermediates['states'][i].gd[0].detach().numpy())
    pointsToVTK("/home/gris/Results/ferret/translations_t_" + str(i), np.ascontiguousarray(pts_t[:,0]),np.ascontiguousarray(pts_t[:,1]), np.ascontiguousarray(pts_t[:,2]))

NameError: name 'shoot_it' is not defined

In [216]:
pts_t

array([[-12.4932  ,   1.98044 ,   2.53529 ],
       [-12.8487  ,   2.06904 ,   2.6987  ],
       [-12.5895  ,   1.76939 ,   2.80287 ],
       ...,
       [-18.0287  ,   0.996781,  -1.06306 ],
       [-12.2925  ,  -1.37977 ,   1.58686 ],
       [-12.2925  ,  -1.37977 ,   1.58686 ]], dtype=float32)