In [None]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not installed automatically yet, so we add its path manually.
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

# Make 32-bit float CPU tensors (torch.FloatTensor) the default tensor type.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch releases in
# favour of torch.set_default_dtype / torch.set_default_device - consider migrating.
torch.set_default_tensor_type(torch.FloatTensor)

In [None]:
# Load the herd from disk.
# NOTE: unpickling can execute arbitrary code, so "herd.pickle" must come from a trusted source.
# The context manager guarantees the file handle is closed once loading is done.
with open("herd.pickle", 'rb') as herd_file:
    data = pickle.load(herd_file)
print(data[0])

# Transpose the records once: field 0 of every record is used below as a shape
# (template / targets) and field 1 is indexed as a pair of ear landmark tensors.
record_fields = list(zip(*data))
ear_landmarks = list(zip(*record_fields[1]))

# Per-ear statistics over the whole herd. Index 0 is later used for the right ear
# and index 1 for the left ear.
ear_pos_mean = [torch.mean(torch.stack(ear_landmarks[i], dim=0), dim=0) for i in range(2)]
ear_pos_stdvar = [torch.sqrt(torch.mean(torch.var(torch.stack(ear_landmarks[i], dim=0), dim=0)))
                  for i in range(2)]

# We need a better template than that
template = data[0][0]
# All shapes except the last record.
# NOTE(review): the reason the last record is dropped is not visible here - confirm.
herd = list(record_fields[0])[:-1]

# Scale passed to the ear translation modules.
ear_sigma = 1.

In [None]:
%matplotlib inline
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '--', color='black', lw=4.)
for bunny in herd:
    plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '-', color='C4')
plt.plot(ear_pos_mean[0][0].numpy(), ear_pos_mean[0][1].numpy(), 'X')
plt.plot(ear_pos_mean[1][0].numpy(), ear_pos_mean[1][1].numpy(), 'X')
plt.axis('equal')
plt.show()


In [None]:
# Each ear gets one oriented translation whose geometrical descriptor is a pair
# (position, direction): the position is the mean ear landmark and the direction is that
# same position scaled to unit length.
right_ear_pos = ear_pos_mean[0].unsqueeze(0)
right_ear_dir = right_ear_pos / torch.norm(ear_pos_mean[0])
right_ear_translation = dm.DeformationModules.OrientedTranslations(
    2, 1, ear_sigma,
    gd=(right_ear_pos.requires_grad_(), right_ear_dir.requires_grad_()),
    label='right_ear_translation')

# The left ear direction must be derived from the left ear position (index 1); mixing in the
# right ear position would give a vector that is neither unit length nor pointing at this ear.
left_ear_pos = ear_pos_mean[1].unsqueeze(0)
left_ear_dir = left_ear_pos / torch.norm(ear_pos_mean[1])
left_ear_translation = dm.DeformationModules.OrientedTranslations(
    2, 1, ear_sigma,
    gd=(left_ear_pos.requires_grad_(), left_ear_dir.requires_grad_()),
    label='left_ear_translation')

# Disabled alternative: a rigid rotation module.
#rigid_rotation = dm.DeformationModules.LinearDeformation.build(torch.tensor([[0., -1.], [1., 0.]]), gd=torch.tensor([[0., 0.]], requires_grad=True))
rigid_translation = dm.DeformationModules.GlobalTranslation(2)

targets = herd

# One atlas over the whole herd with a varifold attachment term.
# NOTE(review): fit_gd carries four flags for three modules - presumably the extra entry
# covers the template landmarks; confirm against dm.Models.Atlas.
atlas = dm.Models.Atlas(
    template,
    [right_ear_translation, left_ear_translation, rigid_translation],
    [dm.Attachment.VarifoldAttachment(2, [0.3])],
    len(targets),
    fit_gd=[True, True, True, False],
    lam=100.)

In [None]:
# Wrap the atlas in the project's fitter (SciPy-backed, judging by its name).
# NOTE(review): the meaning of the second argument (1.) is not visible here - confirm
# against ModelFittingScipy.
fitter = dm.Models.ModelFittingScipy(atlas, 1.)
# Fit the atlas to the targets; 500 is presumably the iteration budget - TODO confirm.
# The returned `costs` sequence is plotted in the following cell.
costs = fitter.fit(targets, 500)

In [None]:
%matplotlib qt5
plt.plot(range(len(costs)), costs)
plt.show()

In [None]:
# For each ear (right first, then left) print the mean landmark position followed by the
# geometrical descriptor currently held by the matching translation module.
ear_modules = (right_ear_translation, left_ear_translation)
for ear_mean, ear_module in zip(ear_pos_mean, ear_modules):
    print(ear_mean)
    print(ear_module.manifold.gd)


In [None]:
%matplotlib qt5
shoot_it = 10
shoot_method = 'euler'
it_per_snapshot = 2
snapshots = int(shoot_it/it_per_snapshot)

for i in range(len(targets)):
    silent_pos = atlas.models[i].init_manifold[0].gd.detach().clone()
    silent_mom = atlas.models[i].init_manifold[0].cotan.detach().clone()
    right_ear_trans_gd = (atlas.models[i].init_manifold[1].gd[0].detach().clone(),
                          atlas.models[i].init_manifold[1].gd[1].detach().clone())
    right_ear_trans_mom = (atlas.models[i].init_manifold[1].cotan[0].detach().clone(),
                           atlas.models[i].init_manifold[1].cotan[1].detach().clone())
    left_ear_trans_gd = (atlas.models[i].init_manifold[2].gd[0].detach().clone(),
                          atlas.models[i].init_manifold[2].gd[1].detach().clone())
    left_ear_trans_mom = (atlas.models[i].init_manifold[2].cotan[0].detach().clone(),
                           atlas.models[i].init_manifold[2].cotan[1].detach().clone())

    silent = dm.DeformationModules.SilentLandmarks(2, silent_pos.shape[0], gd=silent_pos, cotan=silent_mom)
    right_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, gd=right_ear_trans_gd, cotan=right_ear_trans_mom)
    left_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, gd=left_ear_trans_gd, cotan=left_ear_trans_mom)
    global_trans = dm.DeformationModules.GlobalTranslation(2)

    h = dm.HamiltonianDynamic.Hamiltonian([silent, right_ear_trans, left_ear_trans, global_trans])
    intermediate_states, _ = dm.HamiltonianDynamic.shoot(h, shoot_it, shoot_method, intermediates=True)

    for j in range(snapshots):
        pos = intermediate_states[it_per_snapshot*j].gd[0].numpy()
        pos_right_ear_trans = intermediate_states[it_per_snapshot*j].gd[1][0].numpy()
        pos_left_ear_trans = intermediate_states[it_per_snapshot*j].gd[2][0].numpy()

        plt.subplot(len(targets), snapshots + 1, i*snapshots + j + i + 1)
        plt.plot(pos[:, 0], pos[:, 1])
        plt.plot(pos_right_ear_trans[:, 0], pos_right_ear_trans[:, 1], 'x')
        plt.plot(pos_left_ear_trans[:, 0], pos_left_ear_trans[:, 1], 'x') 
        plt.axis('equal')


    final_pos = intermediate_states[-1].gd[0].numpy()
    plt.subplot(len(targets), snapshots + 1, i*snapshots + snapshots + i + 1)
    plt.plot(targets[i].numpy()[:, 0], targets[i].numpy()[:, 1])
    plt.plot(final_pos[:, 0], final_pos[:, 1])
    plt.axis('equal')

plt.show()