In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.FloatTensor)
dm.Utilities.set_compute_backend('torch')

In [2]:
pwd

'/home/gris/algos/implicitmodules/script/torch'

In [3]:
torch.manual_seed(1337)
data = pickle.load(open("../../data/herd.pickle", 'rb'))

ear_pos_mean = list(torch.mean(torch.stack(list(zip(*list(zip(*data))[1]))[i], dim=0), dim=0) for i in range(2))
ear_pos_stdvar = list(torch.sqrt(torch.mean(torch.var(torch.stack(list(zip(*list(zip(*data))[1]))[i], dim=0), dim=0))) for i in range(2))

sigma_ear_tip_pos = 0.
sigma_ear_tip_dir = 0.

right_ear_tip_pos = data[0][1][0] + sigma_ear_tip_pos*torch.randn(1).item()
left_ear_tip_pos = data[0][1][1] + sigma_ear_tip_pos*torch.randn(1).item()
right_ear_tip_dir = right_ear_tip_pos/torch.norm(right_ear_tip_pos) + sigma_ear_tip_dir*torch.randn(1).item()
left_ear_tip_dir = left_ear_tip_pos/torch.norm(left_ear_tip_pos) + sigma_ear_tip_dir*torch.randn(1).item()

# We need a better template than that
template = data[0][0]
template = dm.Utilities.gaussian_kernel_smooth(template, 0.1)
herd = list(list(zip(*data))[0])[1:8]

print(len(herd))

ear_sigma = 0.3

7


In [19]:
%matplotlib qt5
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '--', color='black', lw=4.)
for bunny in herd:
    plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '-', color='C4')
# plt.plot(ear_pos_mean[0][0].numpy(), ear_pos_mean[0][1].numpy(), 'X')
# plt.plot(ear_pos_mean[1][0].numpy(), ear_pos_mean[1][1].numpy(), 'X')
plt.quiver(right_ear_tip_pos[0].numpy(), right_ear_tip_pos[1].numpy(), right_ear_tip_dir[0].numpy(), right_ear_tip_dir[1].numpy(), color='red')
plt.quiver(left_ear_tip_pos[0], left_ear_tip_pos[1], left_ear_tip_dir[0], left_ear_tip_dir[1], color='blue')
plt.axis('equal')
plt.show()


In [5]:
%matplotlib qt5
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '--', color='black', lw=4.)
i0 = 3
j0 = 3
bunny = herd[i0]
plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '-', color='C4')
bunny = herd[j0]
plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '-', color='C4')
# plt.plot(ear_pos_mean[0][0].numpy(), ear_pos_mean[0][1].numpy(), 'X')
# plt.plot(ear_pos_mean[1][0].numpy(), ear_pos_mean[1][1].numpy(), 'X')
plt.quiver(right_ear_tip_pos[0].numpy(), right_ear_tip_pos[1].numpy(), right_ear_tip_dir[0].numpy(), right_ear_tip_dir[1].numpy(), color='red')
plt.quiver(left_ear_tip_pos[0], left_ear_tip_pos[1], left_ear_tip_dir[0], left_ear_tip_dir[1], color='blue')
plt.axis('equal')
plt.show()


In [6]:
herd_match = [herd[i0].clone().detach().requires_grad_()]#, herd[j0].clone().detach().requires_grad_()]

In [7]:
sigma_rotation = 15.
coeff_oriented = 1e0
coeff_rotation = 1e-3
coeff_translation = 1e1
right_ear_translation = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, transport='vector', coeff=coeff_oriented, gd=(right_ear_tip_pos.clone().unsqueeze(0).requires_grad_(), right_ear_tip_dir.clone().unsqueeze(0).requires_grad_()), label='right_ear_translation')
left_ear_translation = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, transport='vector', coeff=coeff_oriented, gd=(left_ear_tip_pos.clone().unsqueeze(0).requires_grad_(), left_ear_tip_dir.clone().unsqueeze(0).requires_grad_()), label='left_ear_translation')
rigid_translation = dm.DeformationModules.GlobalTranslation(2, coeff_translation)
local_rotation = dm.DeformationModules.LocalRotation(2, sigma_rotation, coeff=coeff_rotation, gd=torch.tensor([[0., 0.]], requires_grad=True))

attachment = dm.Attachment.VarifoldAttachment(2, [1., 4.])
targets = herd_match

In [8]:
atlas = dm.Models.Atlas(template, [right_ear_translation, left_ear_translation, local_rotation, rigid_translation], [attachment], len(targets), fit_gd=[False, False, False, False], optimise_template=False, ht_sigma=0.5, ht_coeff=100., lam=1000.)

In [9]:
shoot_method = 'rk4'
shoot_it = 10

In [25]:
fitter = dm.Models.ModelFittingScipy(atlas)
costs = fitter.fit(targets, 80, options={'shoot_method': shoot_method, 'shoot_it': shoot_it})

Initial energy = 86522.844
{'deformation_cost': 0.0, 'attach_cost': 86522.84375, 'cost': 86522.84375}
Time: 42.274682356000994
Iteration: 1 
Total energy = 84905.0234375 
Attach cost = 84905.0 
Deformation cost = 0.021238170564174652
Time: 58.79576369000097
Iteration: 2 
Total energy = 84104.5234375 
Attach cost = 84104.4375 
Deformation cost = 0.08259734511375427
Time: 76.02494323999963
Iteration: 3 
Total energy = 83013.09375 
Attach cost = 83012.96875 
Deformation cost = 0.12747718393802643
Time: 84.97539439499997
Iteration: 4 
Total energy = 78264.3125 
Attach cost = 78263.71875 
Deformation cost = 0.5908206105232239
Time: 93.52443664299972
Iteration: 5 
Total energy = 74809.859375 
Attach cost = 74808.5078125 
Deformation cost = 1.3536404371261597
Time: 135.8463346389999
Iteration: 6 
Total energy = 69922.15625 
Attach cost = 69918.65625 
Deformation cost = 3.4961183071136475
Time: 196.16202748700016
Iteration: 7 
Total energy = 46407.97265625 
Attach cost = 46353.64453125 
Deform

Time: 547.5484805600008
Iteration: 39 
Total energy = 25341.25 
Attach cost = 25274.7421875 
Deformation cost = 66.5076675415039
Time: 556.2799279150004
Iteration: 40 
Total energy = 25240.361328125 
Attach cost = 25172.1953125 
Deformation cost = 68.16565704345703
Time: 564.633277373001
Iteration: 41 
Total energy = 25189.8984375 
Attach cost = 25121.109375 
Deformation cost = 68.78929138183594
Time: 576.4252020740005
Iteration: 42 
Total energy = 25120.423828125 
Attach cost = 25052.306640625 
Deformation cost = 68.11730194091797
Time: 590.3130228190003
Iteration: 43 
Total energy = 24989.748046875 
Attach cost = 24924.072265625 
Deformation cost = 65.67607116699219


KeyboardInterrupt: 

In [11]:
%matplotlib qt5
plt.plot(range(len(costs)), tuple(zip(*costs))[1], color='black', lw=0.5)
plt.xlabel("It")
plt.ylabel("Cost")
plt.grid()
plt.show()

In [5]:
# Print recap
print(atlas)
print("")
print("Fit informations")
print("================")
print("Iteration count={it_count}".format(it_count=len(costs)))
print("Start cost={cost}".format(cost=costs[0][2]))
print("  Attach cost={cost}".format(cost=costs[0][1]))
print("  Def cost={cost}".format(cost=costs[0][0]))
print("Final cost={cost}".format(cost=costs[-1][2]))
print("  Attach cost={cost}".format(cost=costs[-1][1]))
print("  Def cost={cost}".format(cost=costs[-1][0]))
print("Integration scheme={scheme}".format(scheme=shoot_method))
print("Integration steps={steps}".format(steps=shoot_it))

NameError: name 'atlas' is not defined

In [None]:
print(atlas.models[0].init_manifold[1].gd)
print(atlas.models[0].init_manifold[2].gd)
print(atlas.models[0].init_manifold[3].gd)

In [None]:
print("Initial right ear tip direction: {dir}".format(dir=right_ear_tip_dir.tolist()))
print("Optimised right ear tip direction: {dir}".format(dir=atlas.models[0].init_manifold[1].gd[1].detach().flatten().tolist()))
print("Initial left ear tip direction: {dir}".format(dir=left_ear_tip_dir.tolist()))
print("Optimised left ear tip direction: {dir}".format(dir=atlas.models[0].init_manifold[2].gd[1].detach().flatten().tolist()))

In [None]:
print("Optimised rotation center: {center}".format(center=atlas.models[0].init_manifold[3].gd))

In [26]:
# Compute optimised template
optimised_template = atlas.compute_template()

In [14]:
template_right_ear_pos, template_right_ear_dir = atlas.models[0].init_manifold[1].gd[0].detach().flatten(), atlas.models[0].init_manifold[1].gd[1].detach().flatten()
template_left_ear_pos, template_left_ear_dir = atlas.models[0].init_manifold[2].gd[0].detach().flatten(), atlas.models[0].init_manifold[2].gd[1].detach().flatten()

plt.plot(optimised_template[:, 0].numpy(), optimised_template[:, 1].numpy(), '-', color='grey', lw=1.5)
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), color='black', lw=0.8)

# Plot initial positions and directions of the oriented translations
plt.quiver(right_ear_tip_pos[0].numpy(), right_ear_tip_pos[1].numpy(),
           right_ear_tip_dir[0].numpy(), right_ear_tip_dir[1].numpy(), scale=10.)
plt.quiver(left_ear_tip_pos[0].numpy(), left_ear_tip_pos[1].numpy(),
           left_ear_tip_dir[0].numpy(), left_ear_tip_dir[1].numpy(), scale=10.)

# Plot optimised positions and directions of the oriented translations
plt.quiver(template_right_ear_pos[0].numpy(), template_right_ear_pos[1].numpy(),
           template_right_ear_dir[0].numpy(), template_right_ear_dir[1].numpy(), scale=10.)
plt.quiver(template_left_ear_pos[0].numpy(), template_left_ear_pos[1].numpy(),
           template_left_ear_dir[0].numpy(), template_left_ear_dir[1].numpy(), scale=10.)

# Plot position correspondance
plt.arrow(right_ear_tip_pos[0], right_ear_tip_pos[1], template_right_ear_pos[0] - right_ear_tip_pos[0], template_right_ear_pos[1] - right_ear_tip_pos[1], width=0.01, length_includes_head=True, head_width=0.08)
plt.arrow(left_ear_tip_pos[0], left_ear_tip_pos[1], template_left_ear_pos[0] - left_ear_tip_pos[0], template_left_ear_pos[1] - left_ear_tip_pos[1], width=0.01, length_includes_head=True, head_width=0.08)
plt.axis('equal')
plt.show()

ValueError: 'vertices' must be a 2D list or array with shape Nx2

In [10]:
%matplotlib qt5
it_per_snapshot = 1
snapshots = int(shoot_it/it_per_snapshot)

disp_targets = len(targets)

for i in range(disp_targets):
    #silent_pos = atlas.models[i].init_manifold[0].gd.detach().clone()
    silent_pos = optimised_template.clone()
    silent_mom = atlas.models[i].init_manifold[0].cotan.detach().clone()
    right_ear_trans_gd = (atlas.models[i].init_manifold[1].gd[0].detach().clone(),
                          atlas.models[i].init_manifold[1].gd[1].detach().clone())
    right_ear_trans_mom = (atlas.models[i].init_manifold[1].cotan[0].detach().clone(),
                           atlas.models[i].init_manifold[1].cotan[1].detach().clone())
    left_ear_trans_gd = (atlas.models[i].init_manifold[2].gd[0].detach().clone(),
                          atlas.models[i].init_manifold[2].gd[1].detach().clone())
    left_ear_trans_mom = (atlas.models[i].init_manifold[2].cotan[0].detach().clone(),
                          atlas.models[i].init_manifold[2].cotan[1].detach().clone())
    local_rotation_gd = atlas.models[i].init_manifold[3].gd.detach().clone()
    local_rotation_mom = atlas.models[i].init_manifold[3].cotan.detach().clone()

    silent = dm.DeformationModules.SilentLandmarks(2, silent_pos.shape[0], gd=silent_pos, cotan=silent_mom)
    right_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, coeff=coeff_oriented, gd=right_ear_trans_gd, cotan=right_ear_trans_mom)
    left_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, coeff=coeff_oriented, gd=left_ear_trans_gd, cotan=left_ear_trans_mom)
    local_rot = dm.DeformationModules.LocalRotation(2, sigma_rotation, coeff=coeff_rotation, gd=local_rotation_gd, cotan=local_rotation_mom)
    rigid_translation = dm.DeformationModules.GlobalTranslation(2, coeff=coeff_translation)

    h = dm.HamiltonianDynamic.Hamiltonian([silent, right_ear_trans, left_ear_trans, local_rot, rigid_translation])
    intermediate_states, intermediate_controls = dm.HamiltonianDynamic.shoot(h, shoot_it, shoot_method, intermediates=True)


    for j in range(snapshots):
        pos = intermediate_states[it_per_snapshot*j].gd[0].numpy()
        pos_right_ear_trans = intermediate_states[it_per_snapshot*j].gd[1][0].flatten().numpy()
        pos_left_ear_trans = intermediate_states[it_per_snapshot*j].gd[2][0].flatten().numpy()
        dir_right_ear_trans = intermediate_states[it_per_snapshot*j].gd[1][1].flatten().numpy()
        dir_left_ear_trans = intermediate_states[it_per_snapshot*j].gd[2][1].flatten().numpy()

        plt.subplot(disp_targets, snapshots + 1, i*snapshots + j + i + 1)
        plt.plot(pos[:, 0], pos[:, 1], color='black')
        plt.plot(pos_right_ear_trans[0], pos_right_ear_trans[1], 'x', color='red')
        plt.plot(pos_left_ear_trans[0], pos_left_ear_trans[1], 'x', color='red')
        plt.plot(targets[i].detach().numpy()[:, 0], targets[i].detach().numpy()[:, 1], color='blue')
        plt.quiver(pos_right_ear_trans[0], pos_right_ear_trans[1],
                   dir_right_ear_trans[0], dir_right_ear_trans[1], scale=10.)
        plt.quiver(pos_left_ear_trans[0], pos_left_ear_trans[1],
                   dir_left_ear_trans[0], dir_left_ear_trans[1], scale=10.)

        plt.axis('equal')

    final_pos = intermediate_states[-1].gd[0].numpy()
    plt.subplot(disp_targets, snapshots + 1, i*snapshots + snapshots + i + 1)
    plt.plot(targets[i].detach().numpy()[:, 0], targets[i].detach().numpy()[:, 1], color='blue')
    plt.plot(final_pos[:, 0], final_pos[:, 1], color='black')
    plt.axis('equal')

    print("Target {i}: attachment={attachment}".format(i=i, attachment=attachment(targets[i], torch.tensor(final_pos))))

plt.show()

NameError: name 'optimised_template' is not defined

In [15]:
intermediate_controls

[[tensor([]),
  tensor([0.8856]),
  tensor([0.8499]),
  tensor(125.7613),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8834]),
  tensor([0.8463]),
  tensor(127.9718),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8942]),
  tensor([0.8459]),
  tensor(122.0743),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8869]),
  tensor([0.8374]),
  tensor(128.9594),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8922]),
  tensor([0.8399]),
  tensor(125.7858),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8930]),
  tensor([0.8346]),
  tensor(128.0275),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8991]),
  tensor([0.8374]),
  tensor(122.0464),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8966]),
  tensor([0.8280]),
  tensor(128.9665),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8986]),
  tensor([0.8327]),
  tensor(125.7973),
  tensor([0.8324, 0.9059])],
 [tensor([]),
  tensor([0.8982]),
  tensor([0.8310]),
  tensor(128.0818),