In [None]:
%load_ext autoreload
%autoreload 2

import sys
import copy
import math
import pickle

sys.path.append("../../")

import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import meshio

import implicitmodules.torch as dm

# Use 32-bit floats as the default dtype for newly created floating-point tensors.
# torch.set_default_tensor_type is deprecated (PyTorch >= 2.1); set_default_dtype
# is the supported way to get the same float32 default.
torch.set_default_dtype(torch.float32)

# Ask implicitmodules to use its 'torch' compute backend.
# NOTE(review): other backend names are not visible in this notebook — see the library docs.
dm.Utilities.set_compute_backend('torch')

In [None]:
# Read the sphere mesh from disk.
mesh = meshio.read("../../data/sphere.stl")

# Convert the mesh points to a tensor of the current default dtype and keep
# every fifth point (made contiguous so it can later be flattened with .view).
source = torch.tensor(mesh.points, dtype=torch.get_default_dtype())[::5].contiguous()

# The target is the source stretched by a factor 2 along the last (z) coordinate.
axis_scale = torch.tensor([1., 1., 2.])
target = source * axis_scale

In [None]:
def set_aspect_equal_3d(ax):
    """Give a 3D axes the same extent along x, y and z.

    Reads the current x/y/z limits of ``ax``, keeps the midpoint of each
    axis, and resets all three ranges to ``midpoint +/- r`` where ``r`` is
    the largest distance between any current limit and its axis midpoint.

    Parameters
    ----------
    ax : 3D axes-like object
        Must provide ``get_{x,y,z}lim3d`` and ``set_{x,y,z}lim3d``.

    Returns
    -------
    None
        The limits of ``ax`` are modified in place.
    """
    limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    x_mid, y_mid, z_mid = (np.mean(bounds) for bounds in limits)

    # Half-width of the common cube: the farthest any limit sits from its own midpoint.
    plot_radius = max(
        abs(bound - mid)
        for bounds, mid in zip(limits, (x_mid, y_mid, z_mid))
        for bound in bounds
    )

    ax.set_xlim3d([x_mid - plot_radius, x_mid + plot_radius])
    ax.set_ylim3d([y_mid - plot_radius, y_mid + plot_radius])
    ax.set_zlim3d([z_mid - plot_radius, z_mid + plot_radius])

In [None]:
%matplotlib qt5
fig = plt.figure()

ax = fig.add_subplot(121, projection='3d')
ax.scatter(source[:, 0].numpy(), source[:, 1].numpy(), source[:, 2].numpy(), marker='o')
set_aspect_equal_3d(ax)

ax = fig.add_subplot(122, projection='3d')
ax.scatter(target[:, 0].numpy(), target[:, 1].numpy(), target[:, 2].numpy(), marker='o')
set_aspect_equal_3d(ax)

plt.show()

In [None]:
device = 'cpu'

# Put both point clouds on the chosen device.
source, target = source.to(device=device), target.to(device=device)

nu = 0.01

# Order-0 implicit module: its geometric descriptor is the flattened source cloud.
sigma0 = 0.2
implicit0 = dm.DeformationModules.create_deformation_module(
    'implicit_order_0',
    dim=3,
    nb_pts=source.shape[0],
    sigma=sigma0,
    nu=nu,
    coeff=1.,
    gd=source.view(-1).requires_grad_(),
)

# Regular nb_pts x nb_pts x nb_pts grid over [mini, maxi]^3 carrying the order-1 module.
mini, maxi = -1.2, 1.2
nb_pts = 5

grid_xyz = torch.meshgrid([torch.linspace(mini, maxi, nb_pts) for _ in range(3)])

pts_implicit1 = dm.Utilities.grid2vec(*grid_xyz).to(device=device)
nb_pts_implicit1 = pts_implicit1.shape[0]

# One 3x3 identity matrix per grid point.
R = torch.eye(3).repeat(nb_pts_implicit1, 1, 1).to(device=device)
# One (3, 1) column [0, 0, 1]^T per grid point.
C = torch.tensor([[0., 0., 1.]]).repeat(nb_pts_implicit1, 1).unsqueeze(2).to(device=device)

sigma1 = 0.7

# Order-1 implicit module: geometric descriptor is the pair (flattened points, flattened R).
implicit1 = dm.DeformationModules.create_deformation_module(
    'implicit_order_1',
    dim=3,
    nb_pts=nb_pts_implicit1,
    C=C,
    sigma=sigma1,
    nu=nu,
    gd=(pts_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_()),
)

In [None]:
# Points-registration model built from the source cloud, the order-1 module and an
# energy attachment term. Note that only implicit1 is passed here; implicit0 is not
# part of this model.
attachment = dm.Attachment.EnergyAttachement()
model = dm.Models.ModelPointsRegistration([source], [implicit1], [attachment])

# NOTE(review): the meaning of the positional 10. (fitter) and 50 (fit) is defined by
# the implicitmodules API and is not visible here — confirm before tuning.
fitter = dm.Models.ModelFittingScipy(model, 10.)

costs = fitter.fit([target], 50, log_interval=1)

In [None]:
trans = model.modules[0].manifold.gd.to('cpu').detach().view(-1, 3)
target = target.to('cpu')    

%matplotlib qt5
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(trans[:, 0].numpy(), trans[:, 1].numpy(), trans[:, 2].numpy(), marker='o', color='b')
ax.scatter(target[:, 0].numpy(), target[:, 1].numpy(), target[:, 2].numpy(), marker='o', color='r')
set_aspect_equal_3d(ax)
plt.show()
