In [None]:
%load_ext autoreload
%autoreload 2

import sys
import math
import pickle

import torch
import matplotlib.pyplot as plt

sys.path.append("../../")

import implicitmodules.torch as dm

In [None]:
# Use single precision for every tensor created below.
torch.set_default_dtype(torch.float32)

# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# this file must come from a trusted source.
# The context manager closes the file handle as soon as loading is done; the
# previous bare open() call left the handle open until garbage collection.
with open("../../data/peanuts.pickle", 'rb') as peanuts_file:
    data = pickle.load(peanuts_file)

peanuts_count = 6
# Keep entries 1..peanuts_count of data[0]; each curve loses its last point
# (presumably a duplicate of the first point of a closed curve -- TODO confirm).
peanuts = [torch.tensor(peanut[:-1], dtype=torch.get_default_dtype()) for peanut in data[0][1:peanuts_count+1]]

# Initial template: generate_unit_circle(200), passed through the linear map
# diag(1.3, 0.5), then closed with close_shape.
template = dm.Utilities.generate_unit_circle(200)
template = dm.Utilities.linear_transform(template, torch.tensor([[1.3, 0.], [0., 0.5]]))
template = dm.Utilities.close_shape(template)

# The deformable template works on a clone with gradients disabled, so the
# `template` tensor itself is left untouched for plotting.
deformable_template = dm.Models.DeformablePoints(template.clone().requires_grad_(False))
deformable_peanuts = [dm.Models.DeformablePoints(peanut) for peanut in peanuts]

# Two reference points at (-1, 0) and (1, 0), drawn as crosses next to the
# template in the overview plot.
point_left_scale = torch.tensor([[-1., 0.]])
point_right_scale = torch.tensor([[1., 0.]])

In [None]:
%matplotlib qt5

# Overview figure: initial template (dashed), the two scale points (crosses)
# and every target peanut (thin light-blue curves).
template_xy = template.numpy()
plt.plot(template_xy[:, 0], template_xy[:, 1], '--', color='xkcd:blue')

for scale_point in (point_left_scale, point_right_scale):
    plt.plot(scale_point[0, 0].numpy(), scale_point[0, 1].numpy(), 'x', color='xkcd:blue')

for peanut in peanuts:
    peanut_xy = peanut.numpy()
    plt.plot(peanut_xy[:, 0], peanut_xy[:, 1], lw=0.4, color='xkcd:light blue')

plt.axis('equal')
plt.show()

In [None]:
def generate_implicit1_gd():
    """Build the point set used as geometrical descriptor of the implicit module.

    Reads the notebook globals ``template`` and ``template_aabb``; both must be
    defined before this function is called.

    Returns:
        The result of ``dm.Utilities.fill_area_uniform_density`` called on the
        template bounding box scaled by 1.3, with the value 40. as third
        argument (presumably a point density -- TODO confirm), restricted by
        the ``area`` predicate defined below.
    """
    def area(x, **kwargs):
        # A location is kept when either predicate accepts it: area_shape or
        # area_polyline_outline (the latter receives width=0.2 via kwargs).
        inside_shape = dm.Utilities.area_shape(x, **kwargs)
        near_outline = dm.Utilities.area_polyline_outline(x, **kwargs)
        return inside_shape | near_outline

    sampling_box = template_aabb.scale(1.3)
    return dm.Utilities.fill_area_uniform_density(area, sampling_box, 40., shape=template, polyline=template, width=0.2)

# Axis-aligned bounding box of the template; generate_implicit1_gd() reads it.
template_aabb = dm.Utilities.AABB.build_from_points(template)

# Alternative point layouts, kept for reference:
#implicit_gd = dm.Utilities.fill_area_uniform_density(dm.Utilities.area_shape, template_aabb, 40., shape=template)
# implicit_gd = template_aabb.scale([1.5, 2.]).fill_uniform_density(40.)
implicit_gd = generate_implicit1_gd()

# One rot2d(0.) matrix per point of implicit_gd.
implicit_r = dm.Utilities.rot2d(0.).repeat(implicit_gd.shape[0], 1, 1)

# Seed the global RNG before drawing the initial C. Without it every
# Restart & Run All started the optimisation from a different random C, so
# the fitted atlas and the plots below were not reproducible.
torch.manual_seed(0)

# Initial C: one 2x2 block per point, standard normal noise shifted by 1.
implicit_c = torch.randn(implicit_gd.shape[0], 2, 2) + 1.
print(implicit_c.shape)

In [None]:
# Visual check: generated implicit-module points (dots) over the template (dashed).
template_xy = template.numpy()
gd_xy = implicit_gd.numpy()

plt.plot(template_xy[:, 0], template_xy[:, 1], '--', color='xkcd:blue')
plt.plot(gd_xy[:, 0], gd_xy[:, 1], 'o')
plt.axis('equal')
plt.show()

In [None]:
implicit1_scale = 0.35

# The module receives a gradient-enabled clone, so `implicit_c` keeps the
# initial values for the before/after comparison made later on.
initial_c = implicit_c.clone().requires_grad_()

implicit1 = dm.DeformationModules.ImplicitModule1(
    2,
    implicit_gd.shape[0],
    implicit1_scale,
    initial_c,
    nu=0.01,
    gd=(implicit_gd, implicit_r),
)

global_translation = dm.DeformationModules.GlobalTranslation(2)

In [None]:
# Varifold attachment built from two sigma values; the first argument (2)
# matches the dimension passed to the deformation modules.
sigmas_varifold = [0.4, 2.5]
attachment = dm.Attachment.VarifoldAttachment(
    2,
    sigmas_varifold,
)

def precompute(init_manifold, modules, parameters):
    """Copy the optimised C parameter into the module at index 2.

    Passed to ``AtlasModel`` as ``model_precompute_callback``.

    Args:
        init_manifold: Unused here.
        modules: Indexable collection of modules; ``modules[2]`` gets its ``C``
            attribute overwritten (presumably the ImplicitModule1 instance of
            the registration model -- TODO confirm the index).
        parameters: Parameter dictionary; ``parameters['C']['params'][0]`` is
            the tensor assigned to ``modules[2].C``.
    """
    shared_c = parameters['C']['params'][0]
    target_module = modules[2]
    target_module.C = shared_c

# Atlas over all peanuts. The template is optimised together with the module
# parameter C, which is registered under other_parameters['C'] and pushed back
# into the modules by the `precompute` callback.
atlas = dm.Models.AtlasModel(
    deformable_template,
    [global_translation, implicit1],
    [attachment],
    len(peanuts),
    lam=100.,
    optimise_template=True,
    ht_sigma=0.4,
    ht_it=10,
    ht_coeff=.5,
    ht_nu=0.05,
    fit_gd=None,
    other_parameters={'C': {'params': [implicit1.C]}},
    model_precompute_callback=precompute,
)


In [None]:
# Shooting settings, reused later when computing the deformed templates.
shoot_solver = 'euler'
shoot_it = 10

# Debugging hint: to check that the C tensor is shared, compare
# hex(id(implicit1.C)), hex(id(atlas.parameters['C']['params'][0])) and
# hex(id(atlas.registration_models[k].modules[2].C)) for k = 0, 1, 2.

costs = {}
fit_options = {
    'shoot_solver': shoot_solver,
    'shoot_it': shoot_it,
    'line_search_fn': 'strong_wolfe',
}

fitter = dm.Models.Fitter(atlas, optimizer='torch_lbfgs')
fitter.fit(deformable_peanuts, 20, costs=costs, options=fit_options)

In [None]:
# Template after fitting and the fitted C, both detached from autograd so they
# can be converted to numpy / compared below.
ht = atlas.compute_template()[0].detach()
learned_c = implicit1.C.detach()

# Initial template (dashed) against the template computed by the atlas (solid).
template_xy = template.numpy()
ht_xy = ht.numpy()

plt.plot(template_xy[:, 0], template_xy[:, 1], '--')
plt.plot(ht_xy[:, 0], ht_xy[:, 1])
plt.axis('equal')
plt.show()

In [None]:
# Element-wise difference between the initial and the learned C, relative to
# the learned value; the printed number is the mean of its absolute value.
c_difference = implicit_c - learned_c
var_c = c_difference / learned_c
print(var_c.abs().mean())

In [None]:
# Compute the deformed templates with the same solver settings as the fit; no
# gradients are needed for plotting.
intermediates = {}
with torch.autograd.no_grad():
    deformed_templates = atlas.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)

# Square grid just large enough to hold one panel per peanut.
row_count = math.ceil(math.sqrt(len(peanuts)))
ht_xy = ht.numpy()

for i, (deformed, peanut) in enumerate(zip(deformed_templates, peanuts)):
    plt.subplot(row_count, row_count, i + 1)

    # Fitted template (thin light blue), its deformation (black), then the
    # target peanut in the default colour.
    plt.plot(ht_xy[:, 0], ht_xy[:, 1], color='xkcd:light blue', lw=0.5)

    deformed_xy = deformed[0].detach().numpy()
    plt.plot(deformed_xy[:, 0], deformed_xy[:, 1], color='black')

    peanut_xy = peanut.numpy()
    plt.plot(peanut_xy[:, 0], peanut_xy[:, 1])

    plt.axis('equal')

plt.show()

In [None]:
# 2x2 comparison of C: top row shows the initial values, bottom row the
# learned ones; the left column uses c_index=0, the right column c_index=1.
c_panels = [
    (implicit_c, 0),
    (implicit_c, 1),
    (learned_c, 0),
    (learned_c, 1),
]
gd_xy = implicit_gd.numpy()

for panel_number, (c_values, c_index) in enumerate(c_panels, start=1):
    ax = plt.subplot(2, 2, panel_number)
    plt.plot(gd_xy[:, 0], gd_xy[:, 1], '.')
    dm.Utilities.plot_C_arrows(ax, implicit_gd, c_values, c_index=c_index, color='blue', mutation_scale=10., scale=0.1)
    plt.axis('equal')

plt.show()