In [None]:
import os
#os.environ['OMP_NUM_THREADS'] = "6"
#os.environ['OMP_PLACES'] = "{0:6:1}"
#os.environ['KMP_AFFINITY'] = "granularity=fine,compact,1,0"

import sys
import copy
import math
import pickle

sys.path.append("../../")

import numpy as np
import torch
import matplotlib.pyplot as plt
#torch.set_num_threads(6)
import implicitmodules.torch as dm

In [None]:
# Load the deformed peanut outlines.
# NOTE(security): pickle can execute arbitrary code while unpickling -- only load
# this file if it comes from a trusted source.
# The context manager guarantees the file handle is closed once the data is read.
with open("../../data/peanuts.pickle", 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

# Every entry of dataset[0] becomes one tensor; its last point is dropped
# (presumably it repeats the first point to close the curve -- TODO confirm).
peanuts = [torch.tensor(p[:-1], dtype=torch.get_default_dtype()) for p in dataset[0]]
#peanuts = [p - torch.mean(p, dim=0) for p in peanuts]

# The first shape is used as the template, the remaining ones are the targets.
template = peanuts[0]
peanuts = peanuts[1:]
print("Number of peanuts:", len(peanuts))

In [None]:
# Overview of the data: the template is drawn dashed, each target with the default line style.
template_x, template_y = template[:, 0].numpy(), template[:, 1].numpy()
plt.plot(template_x, template_y, '--')
for peanut in peanuts:
    plt.plot(peanut[:, 0].numpy(), peanut[:, 1].numpy())
plt.axis('equal')
plt.show()

In [None]:
def pol(pos, a, b, c, d, e, f, g, h, i, j):
    """Evaluate a full bivariate polynomial of degree 3 at every row of ``pos``.

    With ``x = pos[:, 0]`` and ``y = pos[:, 1]`` the value is::

        a + b*x + c*y + d*x**2 + e*y**2 + f*x*y
          + g*x**3 + h*y**3 + i*x*y**2 + j*y*x**2

    Args:
        pos: object indexable as ``pos[:, 0]`` and ``pos[:, 1]`` (e.g. a tensor of 2D points).
        a, b, c, d, e, f, g, h, i, j: the ten coefficients; scalars or anything that
            broadcasts against a column of ``pos``.

    Returns:
        The polynomial values, with whatever shape results from broadcasting the
        coefficients against ``pos[:, 0]``.
    """
    x = pos[:, 0]
    y = pos[:, 1]

    # The terms are accumulated left to right, starting from the constant term.
    monomials = (
        b*x,
        c*y,
        d*x**2,
        e*y**2,
        f*x*y,
        g*x**3,
        h*y**3,
        i*x*y**2,
        j*y*x**2,
    )
    total = a
    for monomial in monomials:
        total = total + monomial
    return total

def sigmoid(pos, scale0, scale1, trans0, trans1, lambda0, lambda1, offset0, offset1):
    """Build one 2x2 block per point from two sigmoid profiles of the first coordinate.

    With ``x = pos[:, 0]``, both rows of the block share the same values::

        C[:, 0, 0] = C[:, 1, 0] = scale0 * sigmoid(lambda0 * (x + trans0)) + offset0
        C[:, 0, 1] = C[:, 1, 1] = scale1 * sigmoid(lambda1 * (x + trans1)) + offset1

    The second coordinate ``pos[:, 1]`` is not read.

    Args:
        pos: tensor of points, indexed as ``pos[:, 0]``; ``pos.shape[0]`` sets the
            number of blocks.
        scale0, scale1: multiplicative factors of the two profiles.
        trans0, trans1: shifts added to ``x`` before the sigmoid.
        lambda0, lambda1: slopes applied inside the sigmoid.
        offset0, offset1: constants added to the two profiles.

    Returns:
        A tensor of shape ``(pos.shape[0], 2, 2)`` allocated with ``torch.zeros``.
    """
    x = pos[:, 0]
    profiles = (
        scale0*torch.sigmoid(lambda0*(x + trans0)) + offset0,
        scale1*torch.sigmoid(lambda1*(x + trans1)) + offset1,
    )

    C = torch.zeros(pos.shape[0], 2, 2)
    for col, profile in enumerate(profiles):
        # The same profile fills both rows of the column.
        for row in range(2):
            C[:, row, col] = profile
    return C

# def myParametricModel(init_manifold, modules, parameters):
#     abc = parameters[-1]
#     a = abc[0].unsqueeze(1)
#     b = abc[1].unsqueeze(1)
#     c = abc[2].unsqueeze(1)
#     d = abc[3].unsqueeze(1)
#     e = abc[4].unsqueeze(1)
#     f = abc[5].unsqueeze(1)
#     g = abc[6].unsqueeze(1)
#     h = abc[7].unsqueeze(1)
#     i = abc[8].unsqueeze(1)
#     j = abc[9].unsqueeze(1)
#     pos = modules[1].manifold.gd[0].detach().view(-1, 2)

#     modules[1]._ImplicitModule1__C = pol(pos, a, b, c, d, e, f, g, h, i, j).transpose(0, 1).unsqueeze(2)

def myParametricModel(init_manifold, modules, parameters):
    """Recompute the C attribute of ``modules[1]`` from the sigmoid coefficients.

    Used as ``model_precompute_callback`` when the atlas is built.

    Args:
        init_manifold: not used by this callback.
        modules: sequence of modules; only ``modules[1]`` is touched (presumably the
            ImplicitModule1 instance -- TODO confirm the ordering used by the model).
        parameters: sequence whose last entry holds the coefficients, indexed as
            row 0 = scales, row 1 = translations, row 2 = lambdas, row 3 = offsets,
            each with two entries.

    Returns:
        None. The result is stored on ``modules[1]._ImplicitModule1__C``.
    """
    coefficients = parameters[-1]
    target_module = modules[1]

    # Positions are detached, so no gradient flows through the point coordinates.
    positions = target_module.manifold.gd[0].detach().view(-1, 2)

    # Row-major flattening gives exactly the argument order of `sigmoid`:
    # scale0, scale1, trans0, trans1, lambda0, lambda1, offset0, offset1.
    sigmoid_args = [coefficients[row][col] for row in range(4) for col in range(2)]
    target_module._ImplicitModule1__C = sigmoid(positions, *sigmoid_args)

In [None]:
# Bounding boxes built from the template points (then passed through scale(2.))
# and from all target points concatenated.
aabb_template = dm.Utilities.AABB.build_from_points(template)
aabb_template.scale(2.)
aabb_total = dm.Utilities.AABB.build_from_points(torch.cat(peanuts))

density = 10
sigma = 2./math.sqrt(density)


def area(x, **kwargs):
    """Union of two selections: `area_shape` on the template and an outline band of width `sigma`.

    Extra keyword arguments are accepted and ignored.
    """
    in_shape = dm.Utilities.area_shape(x, shape=template, side=1)
    near_outline = dm.Utilities.area_polyline_outline(x, polyline=template, width=sigma)
    return in_shape | near_outline


pts_implicit1 = dm.Utilities.fill_area_uniform_density(area, aabb_template, density, shape=template, side=1)
print(pts_implicit1.shape)

# One angle per sampled point, all equal to 0 * pi, and the matching stack of rot2d outputs.
angles = torch.full((pts_implicit1.shape[0],), 0. * math.pi)
R = torch.stack([dm.Utilities.rot2d(angle) for angle in angles])

In [None]:
# Template outline with the sampled points drawn on top as dots.
outline_x, outline_y = template[:, 0].numpy(), template[:, 1].numpy()
plt.plot(outline_x, outline_y)
samples_x, samples_y = pts_implicit1[:, 0].numpy(), pts_implicit1[:, 1].numpy()
plt.plot(samples_x, samples_y, '.')
plt.show()

In [None]:
# Initial C: one 2x2 block of ones per sampled point.
C = torch.ones(pts_implicit1.shape[0], 2, 2)
# Disabled alternative. NOTE(review): this call passes 6 values after `pts_implicit1`,
# while `sigmoid` takes 8 (the two offsets are missing), so it would fail if re-enabled as is.
#C = sigmoid(pts_implicit1, 1., 1., 0., 0., 20., -20.)
implicit1 = dm.DeformationModules.ImplicitModule1.build_and_fill(
    2,
    pts_implicit1.shape[0],
    C,
    sigma,
    0.001,
    coeff=0.5,
    gd=(pts_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_()),
)

In [None]:
%matplotlib qt5

# One panel per c_index (0 on the left, 1 on the right): template outline plus the C arrows.
for c_index in (0, 1):
    ax = plt.subplot(1, 2, c_index + 1)
    plt.plot(template[:, 0].numpy(), template[:, 1].numpy())
    dm.Utilities.plot_C_arrow(ax, pts_implicit1, C.detach().numpy(), c_index=c_index, scale=0.25)
    plt.axis('equal')
plt.show()

In [None]:
# Module built directly on the template points.
nu0 = 0.01
sigma0 = 0.2
implicit0 = dm.DeformationModules.ImplicitModule0.build_from_points(
    2,
    template.shape[0],
    sigma0,
    nu0,
    coeff=5.,
    gd=template.view(-1).requires_grad_(),
)

# Initial sigmoid coefficients, read row by row in myParametricModel:
# row 0 = scales, row 1 = translations, row 2 = lambdas, row 3 = offsets.
abc = torch.tensor([
    [1., 1.],
    [0., 0.],
    [2., -2.],
    [0., 0.],
])
abc.requires_grad_()

In [None]:
# Atlas over all targets: the template is kept fixed (optimise_template=False), `abc` is
# handed over as an extra parameter, and myParametricModel is registered as precompute callback.
my_atlas = dm.Models.Atlas(
    template,
    [implicit1, implicit0, dm.DeformationModules.GlobalTranslation(2)],
    [dm.Attachment.VarifoldAttachement([0.1, 0.5, 1.2])],
    len(peanuts),
    lam=100.,
    model_precompute_callback=myParametricModel,
    other_parameters=[abc],
    optimise_template=False,
)

In [None]:
# Shooting settings, reused by the fitting options and by the snapshot visualisation.
shoot_method = 'euler'
shoot_it = 20

In [None]:
# Fit the atlas to the targets; `costs` collects what `fit` returns and is plotted later.
fitter = dm.Models.ModelFittingScipy(my_atlas, 1.)

costs = fitter.fit(
    peanuts,
    400,
    options=dict(shoot_it=shoot_it, shoot_method=shoot_method),
    log_interval=1,
)

In [None]:
# Print the last entry of the atlas parameter list
# (presumably the sigmoid coefficients `abc` after optimisation -- TODO confirm).
last_atlas_parameter = my_atlas.parameters[-1]
print(last_atlas_parameter)

In [None]:
# Values returned by the fit, plotted against their index.
cost_indices = range(len(costs))
plt.plot(cost_indices, costs)
plt.show()

In [None]:
# Hand-built reference tensor of shape (n_points, 2, 1), filled with ones except for
# component [:, 0, 0]. It is only referenced by plotting calls that are commented out.
first_coordinate = pts_implicit1[:, 0]
C_gt = torch.ones(pts_implicit1.shape[0], 2, 1)
# NOTE(review): the minimum is *added* here; if the intent was to shift the values so
# that they start at zero it should be subtracted -- confirm before using C_gt.
C_gt[:, 0, 0] = 1.*(first_coordinate + first_coordinate.min())

In [None]:
%matplotlib qt5
# One panel per c_index (0 on the left, 1 on the right): template outline plus arrows
# for the C currently held by implicit1.
for c_index in (0, 1):
    ax = plt.subplot(1, 2, c_index + 1)
    plt.plot(template[:, 0].numpy(), template[:, 1].numpy())
    dm.Utilities.plot_C_arrow(ax, pts_implicit1, implicit1.C.detach(), c_index=c_index, alpha=0.3, scale=0.05, color='blue', mutation_scale=5.)
    plt.axis('equal')
    # Disabled C_gt overlays (the first was on the left panel, the second on the right panel):
    #dm.Utilities.plot_C_arrow(ax, pts_implicit1, C_gt, alpha=0.3, scale=0.3, color='red', mutation_scale=5.)
    #dm.Utilities.plot_C_ellipse(ax, pts_implicit1, C_gt, c_index=1, alpha=0.1, scale=0.01, color='C2')
plt.show()

In [None]:
%matplotlib qt5
# Grid of deformation snapshots: one row per subject, one column per shooting snapshot,
# plus a last column showing the corresponding target.
it_per_snapshot = 1
snapshots = shoot_it // it_per_snapshot

#ht = my_atlas.compute_template().detach().view(-1, 2)

# Show at most 10 subjects, and never more than there are targets: both
# `my_atlas.models[i]` and `peanuts[i]` are indexed below, so a larger N would raise IndexError.
N = min(10, len(peanuts))

for i in range(N):
    # Rebuild a compound module from the modules of subject i, fill it with that
    # subject's initial manifold, and shoot while keeping the intermediate states.
    modules = dm.DeformationModules.CompoundModule(my_atlas.models[i].modules)
    modules.manifold.fill(my_atlas.models[i].init_manifold, copy=True)
    h = dm.HamiltonianDynamic.Hamiltonian(modules)
    intermediate_states, _ = dm.HamiltonianDynamic.shoot(h, shoot_it, shoot_method, intermediates=True)

    # Each row holds `snapshots + 1` panels; subplot indices are 1-based.
    row_offset = i*(snapshots + 1)

    for j in range(snapshots):
        # detach() makes the numpy conversion valid even if the state tracks gradients.
        pos = intermediate_states[it_per_snapshot*j].gd[0].detach().view(-1, 2).numpy()
        #pos_impl1 = intermediate_states[it_per_snapshot*j].gd[1][0].detach().view(-1, 2).numpy()
        plt.subplot(N, snapshots + 1, row_offset + j + 1)
        plt.plot(pos[:, 0], pos[:, 1])
        #plt.plot(pos_impl1[:, 0], pos_impl1[:, 1], '.')
        plt.axis("equal")

    # Last panel of the row: the target shape of subject i.
    target = peanuts[i].numpy()
    plt.subplot(N, snapshots + 1, row_offset + snapshots + 1)
    plt.plot(target[:, 0], target[:, 1])
    plt.axis("equal")