In [None]:
import os
# Optional OpenMP thread pinning, kept disabled. If enabled, these must stay above the
# torch import so the OpenMP runtime presumably sees them at initialisation.
#os.environ['OMP_NUM_THREADS'] = "6"
#os.environ['OMP_PLACES'] = "{0:6:1}"
#os.environ['KMP_AFFINITY'] = "granularity=fine,compact,1,0"

import sys
import copy  # NOTE(review): not used in the visible cells
import math
import pickle

# Make the in-repository `implicitmodules` package importable. The path is relative to the
# notebook's working directory (presumably two levels below the repository root — TODO confirm).
sys.path.append("../../")

import numpy as np  # NOTE(review): not used in the visible cells
import torch
import matplotlib.pyplot as plt
# Cap torch's CPU intra-op parallelism at 6 threads.
torch.set_num_threads(6)
import implicitmodules.torch as dm

In [None]:
# Load the dataset of deformed peanut curves: the first curve becomes the atlas template,
# the remaining ones are the targets the atlas is fitted to.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted data files.
with open("../../data/deformed_ellipses.pkl", 'rb') as pkl_file:
    peanuts_raw = pickle.load(pkl_file)

# Drop the last point of every curve (presumably a duplicate of the first point that closes
# the contour — TODO confirm).
peanuts = [torch.tensor(p[:-1]) for p in peanuts_raw['dataset']]

# Alternative template: a unit circle sampled with as many points as a peanut.
# template = torch.stack([torch.cos(torch.linspace(0., 2.*math.pi, len(peanuts[0]))),
#                        torch.sin(torch.linspace(0., 2.*math.pi, len(peanuts[0])))], dim=1)

template = peanuts[0]
peanuts = peanuts[1:]
print("Number of peanuts:", len(peanuts))

In [None]:
# Overlay the template (dashed) and every target peanut on equal-aspect axes.
template_np = template.numpy()
plt.plot(template_np[:, 0], template_np[:, 1], '--')
for target in peanuts:
    target_np = target.numpy()
    plt.plot(target_np[:, 0], target_np[:, 1])
plt.axis('equal')
plt.show()

In [None]:
# Bounding box of the template; the implicit-module control points are sampled inside it.
aabb_peanut = dm.Utilities.AABB.build_from_points(template)

density = 20
# Kernel scale tied to the sampling density.
sigma = 3. / math.sqrt(density)

pts_implicit1 = dm.Utilities.fill_aabb(aabb_peanut, density)
# Every control point starts with a zero rotation angle; R stacks rot2d(0) once per point.
angles = torch.zeros(pts_implicit1.shape[0])
R = torch.stack([dm.Utilities.rot2d(angle) for angle in angles])

In [None]:
# Growth constants of the order-1 implicit module: one (2, 1) matrix per control point,
# initialised to ones and tracked by autograd (it is also handed to the Atlas as an extra
# parameter via other_parameters).
n_implicit1 = pts_implicit1.shape[0]
C = torch.ones(n_implicit1, 2, 1).requires_grad_()
implicit1_gd = (pts_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_())
implicit1 = dm.DeformationModules.ImplicitModule1.build_and_fill(
    2, n_implicit1, C, sigma, 0.001, gd=implicit1_gd)

In [None]:
# Atlas over the template with the implicit module as the only deformation; shapes are
# compared with a varifold attachment (the list presumably holds its kernel scales), and
# the growth constants C are optimised alongside the atlas.
atlas_attachments = [dm.Attachment.VarifoldAttachement([0.1, 1., 5.])]
my_atlas = dm.Models.Atlas(template, [implicit1], atlas_attachments, len(peanuts),
                           other_parameters=[implicit1.C], use_hypertemplate=False)

In [None]:
# Integration settings shared by the fitting and the trajectory visualisation:
# number of integration steps and integration scheme.
shoot_it, shoot_method = 10, 'euler'

In [None]:
# Fit the atlas to the targets with the SciPy-based fitter, logging the cost at every
# logged step (log_interval=1). The 100 passed to fit() is presumably the iteration budget.
fit_options = {'shoot_it': shoot_it, 'shoot_method': shoot_method}
fitter = dm.Models.ModelFittingScipy(my_atlas, 1., 100.)
costs = fitter.fit(peanuts, 100, options=fit_options, log_interval=1)

In [None]:
# Cost curve of the atlas fit; axis labels make the figure readable on its own.
plt.plot(range(len(costs)), costs)
plt.xlabel("logged iteration")
plt.ylabel("cost")
plt.title("Atlas fitting cost")
plt.show()

In [None]:
# Reference growth constants used only for visual comparison with the fitted C:
# one (2, 1) matrix per control point, the first component ramping linearly from 0 to 10
# across the points and the second component fixed at 1.
# Alternative piecewise reference (two columns, split on the sign of x):
# C_gt = torch.zeros(pts_implicit1.shape[0], 2, 2)
# C_gt[pts_implicit1[:, 0] >= 0, 0, 0] = -10.
# C_gt[pts_implicit1[:, 0] >= 0, 1, 0] = 5.
# C_gt[pts_implicit1[:, 0] < 0, 0, 1] = 20.
# C_gt[pts_implicit1[:, 0] < 0, 1, 1] = 8.
n_gt = pts_implicit1.shape[0]
C_gt = torch.stack([torch.linspace(0., 10., n_gt), torch.ones(n_gt)], dim=1).unsqueeze(-1)

In [None]:
#template = my_atlas.compute_template()

%matplotlib qt5
ax = plt.subplot()
plt.plot(template[:, 0].numpy(), template[:, 1].numpy())
dm.Utilities.plot_C_ellipse(ax, pts_implicit1, implicit1.C.detach(), scale=0.1)
dm.Utilities.plot_C_ellipse(ax, pts_implicit1, C_gt, alpha=0.1, scale=0.03, color='C3')
#dm.Utilities.plot_C_ellipse(ax, pts_implicit1, C_gt, c_index=1, alpha=0.1, scale=0.01, color='C2')
plt.axis('equal')
plt.show()

In [None]:
%matplotlib qt5
it_per_snapshot = 2
snapshots = int(shoot_it/it_per_snapshot)

#ht = my_atlas.compute_template().detach().view(-1, 2)

for i in range(len(peanuts)):
    #implicit0_pos = my_atlas.models[0].init_manifold[1].gd.detach().clone().view(-1, 2)
    #implicit0 = dm.DeformationModules.ImplicitModule0.build_from_points(2, implicit0_pos.shape[0], sigma0, nu0, gd= implicit0_pos.view(-1).requires_grad_(), cotan=my_atlas.models[i].init_manifold[1].cotan)
    #implicit1 = dm.DeformationModules.ImplicitModule1(dm.Manifolds.Stiefel(2, template.shape[0], gd=(template.view(-1).requires_grad_(), R.view(-1).requires_grad_()), cotan=my_atlas.models[i].init_manifold[1].cotan), C_init, sigma1, nu1, coeff1)
    #silent = dm.DeformationModules.SilentLandmarks.build_from_points(ht)
    #silent.manifold.fill_cotan(my_atlas.models[i].init_manifold[0].cotan)
    modules = dm.DeformationModules.CompoundModule(my_atlas.models[i].modules)
    modules.manifold.fill(my_atlas.models[i].init_manifold, copy=True)
    #print(modules.manifold.cotan)
    h = dm.HamiltonianDynamic.Hamiltonian(modules)
    intermediate_states, _ = dm.HamiltonianDynamic.shoot(h, shoot_it, shoot_method, intermediates=True)

    for j in range(snapshots):
        pos = intermediate_states[it_per_snapshot*j].gd[0].view(-1, 2).numpy()
        pos_impl1 = intermediate_states[it_per_snapshot*j].gd[1][0].view(-1, 2).numpy()
        plt.subplot(len(peanuts), snapshots + 1, i*snapshots + j + i + 1)
        plt.plot(pos[:, 0], pos[:, 1])
        #plt.plot(pos_impl1[:, 0], pos_impl1[:, 1], '.')
        plt.axis("equal")


    plt.subplot(len(peanuts), snapshots + 1, i*snapshots + snapshots + i + 1)
    plt.plot(peanuts[i].numpy()[:, 0], peanuts[i].numpy()[:, 1])
    plt.axis("equal")