In [None]:
# Reload edited project modules automatically so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import pickle
# NOTE(review): `math` is not used in any of the visible cells.
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

# NOTE(review): `np` is not used in any of the visible cells.
import numpy as np
import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

# Make newly created floating-point tensors default to torch.FloatTensor (32-bit, CPU).
torch.set_default_tensor_type(torch.FloatTensor)

In [None]:
# Load the bending dataset.
# Entry 0 provides the template curve (index 0), the reference implicit-module points (index 1)
# and the reference C (index 2); every following entry is one bending record whose first item
# is the deformed curve.
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open("bendings.pickle", 'rb') as f:
    data = pickle.load(f)

# torch.unique_consecutive(dim=0) collapses runs of identical consecutive vertices.
template = torch.unique_consecutive(data[0][0], dim=0)

true_implicit1_points = data[0][1]
true_C = data[0][2]
bendings = data[1:]

# Keep only the curve (first item) of each bending record.
dataset = [torch.unique_consecutive(bending[0], dim=0) for bending in bendings]

# Only the first n_targets curves are used to build the atlas.
n_targets = 15
targets = dataset[:n_targets]

print("Dataset size: {size}".format(size=len(targets)))

In [None]:
%matplotlib qt5
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '--', color='black', lw=2.)
for target in targets:
    plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), color='grey', lw=0.5)
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '--', color='black', lw=2.)
plt.plot(true_implicit1_points[:, 0].numpy(), true_implicit1_points[:, 1].numpy(), 'x')
plt.axis('equal')
plt.show()

In [None]:
# Scale parameter of the order-1 implicit deformation module.
sigma_implicit1 = 1.5

# The module is initialised at the reference points with identity 2x2 frames (R).
implicit1_points = true_implicit1_points
n_implicit1 = implicit1_points.shape[0]
implicit1_R = torch.eye(2).repeat(n_implicit1, 1, 1)

# Initial C of shape (n_implicit1, 2, 1), all ones.
# NOTE(review): callback_compute_c presumably replaces this with the polynomial model during fitting.
C_init = torch.ones(n_implicit1, 2, 1)
implicit1 = dm.DeformationModules.ImplicitModule1(
    2, n_implicit1, sigma_implicit1, C_init, nu=0.1,
    gd=(implicit1_points.clone().requires_grad_(), implicit1_R.clone().requires_grad_()))

# Coefficients of C(x, y) = a + b*x + c*y: row 0 is a, row 1 is b, row 2 is c, and each column is
# one of the two components of C. Starting from a = (1, 1), b = c = 0 makes C identically one,
# which matches C_init above.
abc_init = torch.zeros(3, 2)
abc_init[0] = torch.ones(2)

In [None]:
def pol_order_1(pos, a, b, c):
    """Evaluate the first-order polynomial ``a + b*x + c*y`` at every row of ``pos``.

    ``x`` is ``pos[:, 0]`` and ``y`` is ``pos[:, 1]``. The coefficients may be scalars or
    tensors that broadcast against those columns (e.g. ``(k, 1)`` tensors give a ``(k, N)``
    result).
    """
    x_coord = pos[:, 0]
    y_coord = pos[:, 1]
    return a + b * x_coord + c * y_coord

def callback_compute_c(init_manifold, modules, parameters):
    """Model precompute callback: rebuild C of the implicit module from the abc coefficients.

    The last entry of ``parameters`` holds the coefficient rows (a, b, c). Each row is turned
    into a column vector, ``pol_order_1`` is evaluated at the global ``implicit1_points``, and
    the result is transposed and given a trailing singleton axis (points x components x 1).
    ``init_manifold`` is not used.

    NOTE(review): the value is written straight into the name-mangled private attribute
    ``_ImplicitModule1Base__C`` of ``modules[1]``. This relies on library internals and on the
    implicit module being at index 1 -- confirm against the Atlas implementation.
    """
    coefficients = parameters[-1]
    a, b, c = (coefficients[row].unsqueeze(1) for row in range(3))
    poly_values = pol_order_1(implicit1_points, a, b, c)
    modules[1]._ImplicitModule1Base__C = poly_values.t().unsqueeze(-1)

In [None]:
# Atlas built from the template, the implicit module, an L2-norm attachment and one entry per
# target. The abc coefficients are passed as an extra optimisable parameter and are consumed by
# callback_compute_c, registered as the model precompute callback.
atlas = dm.Models.Atlas(
    template,
    [implicit1],
    [dm.Attachment.L2NormAttachment()],
    len(targets),
    lam=10000.,
    other_parameters=[abc_init.clone().requires_grad_()],
    model_precompute_callback=callback_compute_c,
)

In [None]:
# NOTE(review): the meaning of the second argument (1.) is not visible here -- presumably a step
# size or scaling factor of the SciPy-based fitter; confirm in ModelFittingScipy.
fitter = dm.Models.ModelFittingScipy(atlas, 1.)

# Fit settings: iteration count passed to fit() and shooting options (RK4 integration;
# 'shoot_it' is presumably the number of integration steps).
n_fit_iterations = 150
shoot_options = {'shoot_method': 'rk4', 'shoot_it': 10}
costs = fitter.fit(targets, n_fit_iterations, options=shoot_options)

In [None]:
# Cost history returned by fitter.fit, indexed by its position in the returned sequence.
fig, ax = plt.subplots()
ax.plot(range(len(costs)), costs)
ax.set_title("Atlas fitting cost")
ax.set_xlabel("Iteration")
ax.set_ylabel("Cost")
plt.show()

In [None]:
# Fitted polynomial coefficients (last atlas parameter; rows are a, b, c as in abc_init).
abc_fit = atlas.parameters[-1].detach()
print(abc_fit)

# Evaluate C at the implicit-module points exactly as callback_compute_c does during fitting.
abc_columns = abc_fit.unsqueeze(2)  # row i becomes a (components, 1) column vector
C_fit = pol_order_1(implicit1_points, abc_columns[0], abc_columns[1], abc_columns[2]).t().unsqueeze(2)

# Cosine similarity of the flattened reference and fitted C. It is scale-invariant, and 1 means
# the two agree up to a positive scale factor.
true_C_unit = (true_C / torch.norm(true_C)).flatten()
C_fit_unit = (C_fit / torch.norm(C_fit)).flatten()
similarity = torch.dot(true_C_unit, C_fit_unit)
print("Cosine similarity between true C and fitted C: {:.4f}".format(similarity.item()))