In [None]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The implicitmodules library is not installed as a package yet, so the directory containing it must be added to sys.path manually.
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

# Default to single-precision floats. This replaces the deprecated
# torch.set_default_tensor_type(torch.FloatTensor), which set the same dtype.
torch.set_default_dtype(torch.float32)
dm.Utilities.set_compute_backend('torch')

# Seed torch's global RNG so the random initial conditions drawn below are
# reproducible. The trailing semicolon hides the returned Generator's repr,
# which would otherwise be displayed as the cell output.
torch.manual_seed(1);

In [None]:
# Problem size and kernel parameter of the oriented translations.
dim, nb_pts, sigma = 2, 3, 0.01

# Random initial geometric descriptors (positions, then directions), each (nb_pts, dim).
# The generators are consumed left to right, so the RNG draw order is
# gd_pos, gd_dir, mom_pos, mom_dir.
gd_pos, gd_dir = (0.1*torch.randn(nb_pts, dim) for _ in range(2))
# Random initial momenta paired with those descriptors.
mom_pos, mom_dir = (0.5*torch.randn(nb_pts, dim) for _ in range(2))

# The module receives gradient-tracking clones, so the raw tensors above stay
# untouched and can be reused to rebuild an identical module later.
init_gd = tuple(t.clone().requires_grad_() for t in (gd_pos, gd_dir))
init_cotan = tuple(t.clone().requires_grad_() for t in (mom_pos, mom_dir))
oriented = dm.DeformationModules.OrientedTranslations(dim, nb_pts, sigma, 'surface', gd=init_gd, cotan=init_cotan)

print(oriented.manifold.gd)
print(oriented.manifold.cotan)

In [None]:
# Shoot the Hamiltonian built from the oriented-translations module alone,
# using 100 iterations of the 'euler' scheme (presumably the number of time
# steps), and keep the intermediate states and controls for plotting.
# NOTE(review): shooting appears to advance `oriented`'s manifold, which is why
# a later cell rebuilds the module before shooting again. Confirm this.
hamiltonian_oriented = dm.HamiltonianDynamic.Hamiltonian([oriented])
intermediate_states, intermediate_controls = dm.HamiltonianDynamic.shoot(
    hamiltonian_oriented, 100, 'euler', intermediates=True)

In [None]:
%matplotlib qt5
trajectories = [torch.stack(trajectory) for trajectory in list(zip(*(state[0].gd[0] for state in intermediate_states)))]

trajectories_dir = [torch.stack(trajectory_dir) for trajectory_dir in list(zip(*(state[0].gd[1] for state in intermediate_states)))]

In [None]:
# Rebuild the oriented-translations module from the untouched initial tensors,
# so the next shot starts from the same initial conditions as the trajectory shot.
init_gd = (gd_pos.clone().requires_grad_(), gd_dir.clone().requires_grad_())
init_cotan = (mom_pos.clone().requires_grad_(), mom_dir.clone().requires_grad_())
oriented = dm.DeformationModules.OrientedTranslations(dim, nb_pts, sigma, 'surface', gd=init_gd, cotan=init_cotan)

# Bounding box of every trajectory point (scaled by 1.5), discretised with
# fill([4, 5]) (presumably a 4 x 5 lattice). The result is tracked as silent landmarks.
all_trajectory_points = torch.cat(trajectories)
aabb = dm.Utilities.AABB.build_from_points(all_trajectory_points).scale(1.5)
gd_grid = aabb.fill([4, 5])
nb_pts_silent = gd_grid.shape[0]
grid = dm.DeformationModules.SilentLandmarks(dim, nb_pts_silent, gd=gd_grid.requires_grad_())

In [None]:
# Transport the grid alongside the oriented translations. Use the same scheme
# ('euler') and the same number of iterations (100) as the trajectory shot.
# The deformed grid then corresponds to the trajectories it is plotted with.
# A single step of a different scheme would not reproduce those paths.
dm.HamiltonianDynamic.shoot(dm.HamiltonianDynamic.Hamiltonian([oriented, grid]), 100, 'euler')

In [None]:
# Plot each point's trajectory with its direction vectors, over the deformed grid.
fig, ax = plt.subplots()

for trajectory, trajectory_dir in zip(trajectories, trajectories_dir):
    # detach() so matplotlib can convert the tensors to numpy even when they track gradients.
    path = trajectory.detach()
    path_dir = trajectory_dir.detach()
    ax.plot(path[:, 0], path[:, 1], '-')
    ax.quiver(path[:, 0], path[:, 1], path_dir[:, 0], path_dir[:, 1], scale=20.)

# vec2grid must be given the resolution the grid was built with: aabb.fill([4, 5])
# presumably yields a 4 x 5 lattice (20 landmarks). A (17, 32) reshape does not
# match that count and fails.
defgrid_x, defgrid_y = dm.Utilities.vec2grid(grid.manifold.gd.detach(), 4, 5)
dm.Utilities.plot_grid(ax, defgrid_x, defgrid_y, color='blue')

# Equal aspect so the grid deformation is not visually distorted.
ax.set_aspect('equal')
ax.set_title("Oriented translations: trajectories and deformed grid")

plt.show()