In [1]:
%load_ext autoreload
%autoreload 2

import sys
import copy
import math
import pickle

sys.path.append("../../")

import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import meshio

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.FloatTensor)

dm.Utilities.set_compute_backend('keops')
device = 'cuda:0'

In [2]:
source_mesh = meshio.read("../../data/sphere.stl")
target_mesh = meshio.read("../../data/unit_cube.stl")

source_pts = torch.tensor(source_mesh.points).to(dtype=torch.get_default_dtype())

source_faces = torch.tensor(source_mesh.cells['triangle'])

target_pts = torch.tensor(target_mesh.points).to(dtype=torch.get_default_dtype()).abs()
target_pts = 2.*(target_pts - 0.5)
target_faces = torch.tensor(target_mesh.cells['triangle'])

TypeError: list indices must be integers or slices, not str

In [None]:
def set_aspect_equal_3d(ax):
    """Fix equal aspect bug for 3D plots."""

    xlim = ax.get_xlim3d()
    ylim = ax.get_ylim3d()
    zlim = ax.get_zlim3d()

    from numpy import mean
    xmean = mean(xlim)
    ymean = mean(ylim)
    zmean = mean(zlim)

    plot_radius = max([abs(lim - mean_)
                       for lims, mean_ in ((xlim, xmean),
                                           (ylim, ymean),
                                           (zlim, zmean))
                       for lim in lims])

    ax.set_xlim3d([xmean - plot_radius, xmean + plot_radius])
    ax.set_ylim3d([ymean - plot_radius, ymean + plot_radius])
    ax.set_zlim3d([zmean - plot_radius, zmean + plot_radius])

In [None]:
%matplotlib qt5
fig = plt.figure()

ax = fig.add_subplot(121, projection='3d')
ax.scatter(source_pts[:, 0].numpy(), source_pts[:, 1].numpy(), source_pts[:, 2].numpy(), marker='o')
set_aspect_equal_3d(ax)

ax = fig.add_subplot(122, projection='3d')
ax.scatter(target_pts[:, 0].numpy(), target_pts[:, 1].numpy(), target_pts[:, 2].numpy(), marker='o')
set_aspect_equal_3d(ax)

plt.show()

In [None]:
source_pts = source_pts.to(device=device)
target_pts = target_pts.to(device=device)
                          
source_faces = source_faces.to(device=device)                          
target_faces = target_faces.to(device=device)

nu = 0.01

sigma0 = 0.2

mini, maxi = -1.1, 1.1
nb_pts = 20

grid_xyz = torch.meshgrid([torch.linspace(mini, maxi, nb_pts), torch.linspace(mini, maxi, nb_pts), torch.linspace(mini, maxi, nb_pts)])

pts_implicit0 = dm.Utilities.grid2vec(grid_xyz[0], grid_xyz[1], grid_xyz[2]).to(device=device)

implicit0 = dm.DeformationModules.ImplicitModule0(3, pts_implicit0.shape[0], sigma0, nu=nu, coeff=1., gd=pts_implicit0.requires_grad_())

implicit0.to_(device)

In [None]:
model = dm.Models.ModelPointsRegistration([(source_pts, source_faces)], [implicit0], [dm.Attachment.VarifoldAttachment(3, [1.])], lam=100.)
fitter = dm.Models.ModelFittingScipy(model)

In [None]:
costs = fitter.fit([(target_pts, target_faces)], 30, log_interval=1)

In [None]:
trans = model.modules[0].manifold.gd.to('cpu').detach().view(-1, 3)
target = target_pts.to('cpu')

%matplotlib qt5
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
#ax.scatter(trans[:, 0].numpy(), trans[:, 1].numpy(), trans[:, 2].numpy(), marker='o', color='b')
#ax.scatter(target[:, 0].numpy(), target[:, 1].numpy(), target[:, 2].numpy(), marker='o', color='r')
ax.plot_trisurf(trans[:, 0].numpy(), trans[:, 1].numpy(), trans[:, 2].numpy(), triangles=source_faces.cpu())
ax.plot_trisurf(target[:, 0].numpy(), target[:, 1].numpy(), target[:, 2].numpy(), triangles=target_faces.cpu(), color= (0,1,0,0.2), edgecolor=(0, 1, 0, 0.2))
set_aspect_equal_3d(ax)
plt.show()
