In [None]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not installed as a package yet, so we add its path to sys.path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

# Use 32-bit floats for every tensor created without an explicit dtype.
# (`torch.set_default_tensor_type` is deprecated since PyTorch 2.1; setting the
# default dtype is the supported equivalent when the default device stays CPU.)
torch.set_default_dtype(torch.float32)

In [None]:
# Let's load all peanuts.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("../../data/peanuts.pickle", "rb") as peanuts_file:
    raw_data = pickle.load(peanuts_file)

# The first entry of the archive holds the contours; each becomes a tensor of
# 2-D points in the default dtype.
raw_pickles = raw_data[0]
peanuts = [torch.tensor(raw, dtype=torch.get_default_dtype()) for raw in raw_pickles]

n_peanuts = 10
# The last point of every contour is dropped (presumably it repeats the first
# point to close the curve -- TODO confirm against the data). The template is
# the first contour; the targets are the last `n_peanuts` contours.
# NOTE(review): if the archive has <= n_peanuts contours, the template is also a target.
template = peanuts[0][:-1]
peanuts = [p[:-1] for p in peanuts[-n_peanuts:]]

# Bounding box of all target points, used to place the implicit module grid.
aabb_trans = dm.Utilities.AABB.build_from_points(torch.cat(peanuts))

In [None]:
# Regular `length` x `length` grid spanning the bounding box of the targets;
# these points become the initial positions of the `implicit0` module below.
# NOTE(review): `torch.meshgrid` defaults to 'ij' indexing; recent PyTorch
# versions warn when `indexing` is not passed explicitly.
length = 15
grid_axes = [
    torch.linspace(aabb_trans.xmin, aabb_trans.xmax, length),
    torch.linspace(aabb_trans.ymin, aabb_trans.ymax, length),
]
implicit0_pos_x, implicit0_pos_y = torch.meshgrid(grid_axes)
implicit0_pos = dm.Utilities.grid2vec(implicit0_pos_x, implicit0_pos_y)

In [None]:
# Overlay the target contours (solid), the template (dashed) and the initial
# implicit-module grid (dots). Equal axis scaling keeps the shapes undistorted,
# matching the shooting visualisation further down.
for peanut in peanuts:
    peanut_np = peanut.numpy()
    plt.plot(peanut_np[:, 0], peanut_np[:, 1])

template_np = template.numpy()
plt.plot(template_np[:, 0], template_np[:, 1], "--")
plt.plot(implicit0_pos[:, 0].numpy(), implicit0_pos[:, 1].numpy(), ".")

plt.axis("equal")
plt.title("Targets, template (dashed) and implicit module points")
plt.show()

In [None]:
# Implicit module built on the grid points. The grid is cloned before being
# flattened, so `implicit0_pos` itself is not the tensor handed to the module.
# NOTE(review): sigma0 / nu0 are presumably the kernel scale and a
# regularisation weight, and `2` the ambient dimension -- TODO confirm in implicitmodules.
sigma0 = 0.5
nu0 = 0.01
implicit0_gd = implicit0_pos.clone().view(-1).requires_grad_()
implicit0 = dm.DeformationModules.ImplicitModule0.build_from_points(
    2, implicit0_pos.shape[0], sigma0, nu0, gd=implicit0_gd)

In [None]:
# Implicit module of order 1 supported on the template points.
# NOTE(review): `implicit1` is not passed to the atlas below (only `implicit0` is).
sigma1 = 0.5
nu1 = 0.01
coeff1 = 0.001
# One angle per template point, all zero here; each angle is turned into a
# 2-D rotation matrix by `rot2d`.
th = 0. * math.pi * torch.ones(template.shape[0])
R = torch.stack([dm.Utilities.rot2d(t) for t in th])
C_init = torch.ones(template.shape[0], 2, 1)
# Clone before flattening so the module's descriptors do not share storage with
# `template` / `R` (a plain `.view(-1)` would alias them, and any in-place update
# of the descriptors would silently change the template). This mirrors how
# `implicit0` clones its grid.
implicit1_gd = (template.clone().view(-1).requires_grad_(),
                R.clone().view(-1).requires_grad_())
implicit1 = dm.DeformationModules.ImplicitModule1(
    dm.Manifolds.Stiefel(2, template.shape[0], gd=implicit1_gd),
    C_init, sigma1, nu1, coeff1)

In [None]:
# Atlas over all targets, starting from `template`, deformed by `implicit0`
# only, and presumably matched to each target through a varifold attachment.
# NOTE(review): 0.3 (atlas) and 1. / 50000. (fitter) are passed positionally;
# their meaning is defined by implicitmodules -- TODO give them named constants.
atlas_attachments = [dm.Attachment.VarifoldAttachement([0.5])]
my_atlas = dm.Models.Atlas(
    template, [implicit0], atlas_attachments, len(peanuts), 0.3, fit_gd=[True])

my_fitter = dm.Models.ModelFittingScipy(my_atlas, 1., 50000.)

In [None]:
# Shooting settings; `shoot_it` and `shoot_method` are reused when the fitted
# deformations are visualised.
shoot_it = 10
shoot_method = "euler"
fit_options = {"shoot_method": shoot_method, "shoot_it": shoot_it}
# NOTE(review): 150 is presumably the maximum number of optimiser iterations -- TODO confirm.
costs = my_fitter.fit(peanuts, 150, log_interval=1, options=fit_options)

In [None]:
%matplotlib qt5
it_per_snapshot = 1
snapshots = int(shoot_it/it_per_snapshot)

ht = my_atlas.compute_template().detach().view(-1, 2)

for i in range(len(peanuts)):
    implicit0_pos = my_atlas.models[0].init_manifold[1].gd.detach().clone().view(-1, 2)
    implicit0 = dm.DeformationModules.ImplicitModule0.build_from_points(2, implicit0_pos.shape[0], sigma0, nu0, gd= implicit0_pos.view(-1).requires_grad_(), cotan=my_atlas.models[i].init_manifold[1].cotan)
    #implicit1 = dm.DeformationModules.ImplicitModule1(dm.Manifolds.Stiefel(2, template.shape[0], gd=(template.view(-1).requires_grad_(), R.view(-1).requires_grad_()), cotan=my_atlas.models[i].init_manifold[1].cotan), C_init, sigma1, nu1, coeff1)
    silent = dm.DeformationModules.SilentLandmarks.build_from_points(ht)
    silent.manifold.fill_cotan(my_atlas.models[i].init_manifold[0].cotan)
    h = dm.HamiltonianDynamic.Hamiltonian([silent, implicit0])
    intermediate_states, _ = dm.HamiltonianDynamic.shoot(h, shoot_it, shoot_method, intermediates=True)

    for j in range(snapshots):
        pos = intermediate_states[it_per_snapshot*j].gd[0].view(-1, 2).numpy()
        pos_impl1 = intermediate_states[it_per_snapshot*j].gd[1].view(-1, 2).numpy()
        plt.subplot(len(peanuts), snapshots + 1, i*snapshots + j + i + 1)
        plt.plot(pos[:, 0], pos[:, 1])
        plt.plot(pos_impl1[:, 0], pos_impl1[:, 1], '.')
        plt.axis("equal")


    plt.subplot(len(peanuts), snapshots + 1, i*snapshots + snapshots + i + 1)
    plt.plot(peanuts[i].numpy()[:, 0], peanuts[i].numpy()[:, 1])
    plt.axis("equal")