In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.FloatTensor)
dm.Utilities.set_compute_backend('torch')
device = 'cpu'

In [3]:
torch.manual_seed(1337)

# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# The context manager closes the file handle once the data is loaded.
with open("herd.pickle", 'rb') as herd_file:
    data = pickle.load(herd_file)

# Standard deviations of the optional noise added to the initial ear tip
# positions and directions (0 disables the perturbation).
sigma_ear_tip_pos = 0.
sigma_ear_tip_dir = 0.

# The first sample provides the ear tips: data[0][1][0] is the right ear tip
# position and data[0][1][1] the left one.
# NOTE(review): torch.randn(1).item() draws a single scalar, so the same offset
# is added to every coordinate — confirm this is intended rather than
# per-coordinate noise.
right_ear_tip_pos = data[0][1][0] + sigma_ear_tip_pos*torch.randn(1).item()
left_ear_tip_pos = data[0][1][1] + sigma_ear_tip_pos*torch.randn(1).item()
# Initial directions point from the origin towards each ear tip (unit vectors
# before noise; the noisy direction is not renormalised).
right_ear_tip_dir = right_ear_tip_pos/torch.norm(right_ear_tip_pos) + sigma_ear_tip_dir*torch.randn(1).item()
left_ear_tip_dir = left_ear_tip_pos/torch.norm(left_ear_tip_pos) + sigma_ear_tip_dir*torch.randn(1).item()

# TODO: we need a better template than a smoothed copy of the first sample.
template = data[0][0]
template = dm.Utilities.gaussian_kernel_smooth(template, 0.1)

# Fit on the shapes of samples 1..5 (sample 0 already serves as the template).
# To fit a single sample instead: herd = [sample[0] for sample in data[idx:idx+1]]
idx = 4
herd = [sample[0] for sample in data[1:6]]

print(len(herd))

# Kernel scale of the ear oriented-translation modules.
ear_sigma = 0.3

5


In [4]:
%matplotlib qt5
#plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '--', color='black', lw=4.)
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '+', color='black', lw=4.)
for bunny in herd:
    plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '-', color='C4')
    plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '*', color='C4')
# plt.plot(ear_pos_mean[0][0].numpy(), ear_pos_mean[0][1].numpy(), 'X')
# plt.plot(ear_pos_mean[1][0].numpy(), ear_pos_mean[1][1].numpy(), 'X')
plt.quiver(right_ear_tip_pos[0].numpy(), right_ear_tip_pos[1].numpy(), right_ear_tip_dir[0].numpy(), right_ear_tip_dir[1].numpy(), color='red')
plt.quiver(left_ear_tip_pos[0], left_ear_tip_pos[1], left_ear_tip_dir[0], left_ear_tip_dir[1], color='blue')
plt.axis('equal')
plt.show()


In [5]:
# Kernel scale of the local rotation and cost weights of each deformation module.
sigma_rotation = 15.
coeff_oriented = 1e-2
coeff_rotation = 1e-4
coeff_translation = 1e-1


def _ear_translation(tip_pos, tip_dir, label):
    """Build a one-point oriented translation module for a single ear.

    Its geometrical descriptors are fresh copies of `tip_pos` and `tip_dir`,
    each given a leading dimension of size 1 and marked as requiring grad.
    """
    gd_pos = tip_pos.clone().unsqueeze(0).requires_grad_()
    gd_dir = tip_dir.clone().unsqueeze(0).requires_grad_()
    return dm.DeformationModules.OrientedTranslations(
        2, 1, ear_sigma, transport='vector', coeff=coeff_oriented,
        gd=(gd_pos, gd_dir), label=label)


right_ear_translation = _ear_translation(right_ear_tip_pos, right_ear_tip_dir, 'right_ear_translation')
left_ear_translation = _ear_translation(left_ear_tip_pos, left_ear_tip_dir, 'left_ear_translation')
# Rotation whose centre starts at the origin; the centre is a grad-tracked
# descriptor (reported later as the optimised rotation center).
local_rotation = dm.DeformationModules.LocalRotation(
    2, sigma_rotation, coeff=coeff_rotation,
    gd=torch.tensor([[0., 0.]], requires_grad=True))
rigid_translation = dm.DeformationModules.GlobalTranslation(2, coeff_translation)

# Varifold data attachment with a single kernel scale.
attachment = dm.Attachment.VarifoldAttachment(2, [4.], backend='torch')
targets = herd

In [6]:
# Move every deformation module to the selected device.
for module in (right_ear_translation, left_ear_translation, local_rotation, rigid_translation):
    module.to_(device)

In [7]:
# Atlas over len(targets) subjects built on a copy of the template, with the
# four deformation modules and the varifold attachment. The template itself is
# not optimised; ht_sigma/ht_coeff presumably configure the hypertemplate
# (reported as disabled in the recap) — TODO confirm.
atlas = dm.Models.Atlas(
    template.clone().to(device=device),
    [right_ear_translation, left_ear_translation, local_rotation, rigid_translation],
    [attachment],
    len(targets),
    fit_gd=None,
    optimise_template=False,
    ht_sigma=0.1,
    ht_coeff=10.,
    lam=100.,
)

In [8]:
# Integration scheme and number of steps used to shoot the Hamiltonian dynamics.
shoot_method, shoot_it = 'rk4', 10

In [9]:
# Fit the atlas with the SciPy-based fitter for 30 iterations. A plain
# gradient-descent fitter (dm.Models.ModelFittingGradientDescent(atlas, 1e-6))
# can be swapped in instead.
fitter = dm.Models.ModelFittingScipy(atlas)
fit_targets = [target.clone().to(device=device) for target in targets]
fit_options = {'shoot_method': shoot_method, 'shoot_it': shoot_it}
costs = fitter.fit(fit_targets, 30, options=fit_options)

Initial energy = 12133.427


Time: 455.1054595459718
Iteration: 1 
Total energy = 7797.1170654296875 
Attach cost = 7797.054382324219 
Deformation cost = 0.06264475124771707


Time: 499.33288658503443
Iteration: 2 
Total energy = 6815.008361816406 
Attach cost = 6814.840087890625 
Deformation cost = 0.1683010500855744


Time: 541.5488243190339
Iteration: 3 
Total energy = 4858.4876708984375 
Attach cost = 4858.32666015625 
Deformation cost = 0.16105230525135994


Time: 584.8166736649582
Iteration: 4 
Total energy = 3937.5992431640625 
Attach cost = 3937.4237670898438 
Deformation cost = 0.17544239945709705


Time: 625.464414609014
Iteration: 5 
Total energy = 3592.122344970703 
Attach cost = 3591.9815368652344 
Deformation cost = 0.1407538536004722


Time: 665.5624688040698
Iteration: 6 
Total energy = 3295.8382415771484 
Attach cost = 3295.6558532714844 
Deformation cost = 0.18240362871438265


Time: 706.5873696730705
Iteration: 7 
Total energy = 3098.2706604003906 
Attach cost = 3098.0774536132812 
Deformation cost = 0.19322240725159645


Time: 747.5706295539858
Iteration: 8 
Total energy = 2774.182815551758 
Attach cost = 2773.9592895507812 
Deformation cost = 0.22355254832655191


Time: 788.9144839890068
Iteration: 9 
Total energy = 2487.418914794922 
Attach cost = 2487.171142578125 
Deformation cost = 0.24781416356563568


Time: 831.6794243879849
Iteration: 10 
Total energy = 1667.777572631836 
Attach cost = 1667.4095153808594 
Deformation cost = 0.36809711158275604


Time: 873.5324785100529
Iteration: 11 
Total energy = 1276.6454620361328 
Attach cost = 1276.0841369628906 
Deformation cost = 0.56130800396204


Time: 917.4223027139669
Iteration: 12 
Total energy = 871.2779502868652 
Attach cost = 870.6657409667969 
Deformation cost = 0.6121805533766747


Time: 960.0317562359851
Iteration: 13 
Total energy = 749.665958404541 
Attach cost = 748.9799499511719 
Deformation cost = 0.6860087811946869


Time: 1003.2070664230268
Iteration: 14 
Total energy = 645.3773803710938 
Attach cost = 644.6266174316406 
Deformation cost = 0.750745102763176


Time: 1044.314694953966
Iteration: 15 
Total energy = 535.4862298965454 
Attach cost = 534.6771240234375 
Deformation cost = 0.8091173842549324


Time: 1085.3439060050296
Iteration: 16 
Total energy = 505.98374032974243 
Attach cost = 505.1727294921875 
Deformation cost = 0.8110020682215691


Time: 1128.8658771910705
Iteration: 17 
Total energy = 467.7916316986084 
Attach cost = 466.99066162109375 
Deformation cost = 0.8009741604328156


Time: 1171.736308563035
Iteration: 18 
Total energy = 458.587438583374 
Attach cost = 457.8056335449219 
Deformation cost = 0.7817901596426964


Time: 1221.1299123690696
Iteration: 19 
Total energy = 437.59773302078247 
Attach cost = 436.8278503417969 
Deformation cost = 0.7698923051357269


Time: 1271.4807433380047
Iteration: 20 
Total energy = 421.3923897743225 
Attach cost = 420.6214904785156 
Deformation cost = 0.7708883509039879


Time: 1315.9651905240025
Iteration: 21 
Total energy = 405.6784896850586 
Attach cost = 404.8896789550781 
Deformation cost = 0.7888107523322105


Time: 1358.4565771080088
Iteration: 22 
Total energy = 395.44306230545044 
Attach cost = 394.6311950683594 
Deformation cost = 0.8118790686130524


Time: 1399.1795114220586
Iteration: 23 
Total energy = 389.7890009880066 
Attach cost = 388.9640808105469 
Deformation cost = 0.8249124810099602


Time: 1446.0456208949909
Iteration: 24 
Total energy = 386.804297208786 
Attach cost = 385.97869873046875 
Deformation cost = 0.8256122916936874


Time: 1487.0358001539716
Iteration: 25 
Total energy = 385.42831540107727 
Attach cost = 384.60693359375 
Deformation cost = 0.8213727474212646


Time: 1528.176120297052
Iteration: 26 
Total energy = 383.56805181503296 
Attach cost = 382.7507019042969 
Deformation cost = 0.8173646330833435


Time: 1569.4292115640128
Iteration: 27 
Total energy = 381.0067881345749 
Attach cost = 380.1910400390625 
Deformation cost = 0.8157493770122528


Time: 1609.942417990067
Iteration: 28 
Total energy = 378.06095933914185 
Attach cost = 377.239990234375 
Deformation cost = 0.8209761008620262


Time: 1649.510955822072
Iteration: 29 
Total energy = 376.9082564711571 
Attach cost = 376.07879638671875 
Deformation cost = 0.8294700533151627


Time: 1690.544258306967
Iteration: 30 
Total energy = 376.69342136383057 
Attach cost = 375.8628845214844 
Deformation cost = 0.8305222988128662
Optimisation process exited with message: b'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
Final energy = 376.69342136383057
Closure evaluations = 40
Time elapsed = 1690.5452673999825


In [24]:
%matplotlib qt5
plt.plot(range(len(costs)), tuple(zip(*costs))[1], color='black', lw=0.5)
plt.xlabel("It")
plt.ylabel("Cost")
plt.grid()
plt.show()

In [14]:
# Print a recap of the atlas and of the fit.
# NOTE(review): costs[0] is recorded after the first iteration, so the
# "Start cost" below differs from the "Initial energy" printed by the fitter.


def _print_cost_breakdown(title, cost):
    """Print a (deformation, attachment, total) cost entry, total first."""
    deformation, attach, total = cost[0], cost[1], cost[2]
    print(f"{title}={total}")
    print(f"  Attach cost={attach}")
    print(f"  Def cost={deformation}")


print(atlas)
print("")
print("Fit informations")
print("================")
print(f"Iteration count={len(costs)}")
_print_cost_breakdown("Start cost", costs[0])
_print_cost_breakdown("Final cost", costs[-1])
print(f"Integration scheme={shoot_method}")
print(f"Integration steps={shoot_it}")

Atlas
=====
Template nb pts=107
Population count=5
Module count=4
Hypertemplate=False
Attachment=VarifoldAttachment2D_Torch (weight=1.0)
  Sigmas=[4.0]
Lambda=100.0
Fit geometrical descriptors=None
Precompute callback=False
Model precompute callback=False
Other parameters=False

Modules
Oriented translation
  Label=right_ear_translation
  Sigma=0.3
  Coeff=0.01
  Nb pts=1
Oriented translation
  Label=left_ear_translation
  Sigma=0.3
  Coeff=0.01
  Nb pts=1
Local constrained translation module
  Type=Local rotation
  Sigma=15.0
  Coeff=0.0001
Global translation
  Coeff=0.1

Fit informations
Iteration count=30
Start cost=7797.1170654296875
  Attach cost=7797.054382324219
  Def cost=0.06264475124771707
Final cost=376.69342136383057
  Attach cost=375.8628845214844
  Def cost=0.8305222988128662
Integration scheme=rk4
Integration steps=10


In [None]:
# Compare each ear's initial tip direction with the one held by the first model
# after fitting (init_manifold[1] is the right ear module, [2] the left one).
ear_directions = (("right", 1, right_ear_tip_dir), ("left", 2, left_ear_tip_dir))
for side, manifold_idx, initial_dir in ear_directions:
    print(f"Initial {side} ear tip direction: {initial_dir.tolist()}")
    optimised_dir = atlas.models[0].init_manifold[manifold_idx].gd[1].detach().flatten()
    print(f"Optimised {side} ear tip direction: {optimised_dir.tolist()}")

In [None]:
# Centre of the local rotation module (init_manifold[3]) of the first model.
rotation_center = atlas.models[0].init_manifold[3].gd.detach()[0].tolist()
print(f"Optimised rotation center: {rotation_center}")

In [10]:
# Template produced by the atlas, moved back to the CPU for plotting; the
# second value returned by compute_template is not used here.
template_result, _ = atlas.compute_template()
optimised_template = template_result.cpu()

In [26]:
def _ear_gd(model, manifold_idx):
    """Return (position, direction) of an ear module's descriptors as flat, detached CPU tensors."""
    gd = model.init_manifold[manifold_idx].gd
    return gd[0].detach().cpu().flatten(), gd[1].detach().cpu().flatten()


def _quiver_tensor(pos, direction, **kwargs):
    """Draw one arrow at `pos` pointing along `direction` (both indexable as [0], [1])."""
    plt.quiver(pos[0].numpy(), pos[1].numpy(),
               direction[0].numpy(), direction[1].numpy(), **kwargs)


def _arrow_between(start, end, **kwargs):
    """Draw an arrow from `start` to `end`, converting tensor components to floats."""
    plt.arrow(float(start[0]), float(start[1]),
              float(end[0] - start[0]), float(end[1] - start[1]), **kwargs)


template_right_ear_pos, template_right_ear_dir = _ear_gd(atlas.models[0], 1)
template_left_ear_pos, template_left_ear_dir = _ear_gd(atlas.models[0], 2)

plt.plot(optimised_template[:, 0].numpy(), optimised_template[:, 1].numpy(), '-',
         color='grey', lw=1.5, label='optimised template')
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), color='black', lw=0.8,
         label='initial template')

# Initial ear tip positions/directions in black, optimised ones in red so the
# two sets of arrows can be told apart.
_quiver_tensor(right_ear_tip_pos, right_ear_tip_dir, scale=10.)
_quiver_tensor(left_ear_tip_pos, left_ear_tip_dir, scale=10.)
_quiver_tensor(template_right_ear_pos, template_right_ear_dir, scale=10., color='red')
_quiver_tensor(template_left_ear_pos, template_left_ear_dir, scale=10., color='red')

# Displacement of each ear tip from its initial to its optimised position.
arrow_style = dict(width=0.01, length_includes_head=True, head_width=0.08)
_arrow_between(right_ear_tip_pos, template_right_ear_pos, **arrow_style)
_arrow_between(left_ear_tip_pos, template_left_ear_pos, **arrow_style)
plt.legend()
plt.axis('equal')
plt.show()

In [13]:
%matplotlib qt5
it_per_snapshot = 2
snapshots = int(shoot_it/it_per_snapshot)

disp_targets = len(targets)

for i in range(disp_targets):
    silent_pos = optimised_template.clone()
    silent_mom = atlas.models[i].init_manifold[0].cotan.detach().cpu().clone()
    right_ear_trans_gd = (atlas.models[i].init_manifold[1].gd[0].detach().cpu().clone(),
                          atlas.models[i].init_manifold[1].gd[1].detach().cpu().clone())
    right_ear_trans_mom = (atlas.models[i].init_manifold[1].cotan[0].detach().cpu().clone(),
                           atlas.models[i].init_manifold[1].cotan[1].detach().cpu().clone())
    left_ear_trans_gd = (atlas.models[i].init_manifold[2].gd[0].detach().cpu().clone(),
                          atlas.models[i].init_manifold[2].gd[1].detach().cpu().clone())
    left_ear_trans_mom = (atlas.models[i].init_manifold[2].cotan[0].detach().cpu().clone(),
                          atlas.models[i].init_manifold[2].cotan[1].detach().cpu().clone())
    local_rotation_gd = atlas.models[i].init_manifold[3].gd.detach().cpu().clone()
    local_rotation_mom = atlas.models[i].init_manifold[3].cotan.detach().cpu().clone()

    silent = dm.DeformationModules.SilentLandmarks(2, silent_pos.shape[0], gd=silent_pos, cotan=silent_mom)
    right_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, coeff=coeff_oriented, gd=right_ear_trans_gd, cotan=right_ear_trans_mom)
    left_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, coeff=coeff_oriented, gd=left_ear_trans_gd, cotan=left_ear_trans_mom)
    local_rot = dm.DeformationModules.LocalRotation(2, sigma_rotation, coeff=coeff_rotation, gd=local_rotation_gd, cotan=local_rotation_mom)
    rigid_translation = dm.DeformationModules.GlobalTranslation(2, coeff=coeff_translation)

    h = dm.HamiltonianDynamic.Hamiltonian([silent, right_ear_trans, left_ear_trans, local_rot, rigid_translation])
    intermediate_states, _ = dm.HamiltonianDynamic.shoot(h, shoot_it, shoot_method, intermediates=True)


    for j in range(snapshots):
        pos = intermediate_states[it_per_snapshot*j].gd[0].numpy()
        pos_right_ear_trans = intermediate_states[it_per_snapshot*j].gd[1][0].flatten().numpy()
        pos_left_ear_trans = intermediate_states[it_per_snapshot*j].gd[2][0].flatten().numpy()
        dir_right_ear_trans = intermediate_states[it_per_snapshot*j].gd[1][1].flatten().numpy()
        dir_left_ear_trans = intermediate_states[it_per_snapshot*j].gd[2][1].flatten().numpy()

        plt.subplot(disp_targets, snapshots + 1, i*snapshots + j + i + 1)
        plt.plot(pos[:, 0], pos[:, 1], color='black')
        plt.plot(pos_right_ear_trans[0], pos_right_ear_trans[1], 'x', color='red')
        plt.plot(pos_left_ear_trans[0], pos_left_ear_trans[1], 'x', color='red')
        plt.plot(targets[i].numpy()[:, 0], targets[i].numpy()[:, 1], color='blue')
        plt.quiver(pos_right_ear_trans[0], pos_right_ear_trans[1],
                   dir_right_ear_trans[0], dir_right_ear_trans[1], scale=10.)
        plt.quiver(pos_left_ear_trans[0], pos_left_ear_trans[1],
                   dir_left_ear_trans[0], dir_left_ear_trans[1], scale=10.)

        plt.axis('equal')

    final_pos = intermediate_states[-1].gd[0].numpy()
    plt.subplot(disp_targets, snapshots + 1, i*snapshots + snapshots + i + 1)
    plt.plot(targets[i].numpy()[:, 0], targets[i].numpy()[:, 1], color='blue')
    plt.plot(final_pos[:, 0], final_pos[:, 1], color='black')
    plt.axis('equal')

    print("Target {i}: attachment={attachment}".format(i=i, attachment=attachment(targets[i], torch.tensor(final_pos))))

plt.show()

Target 0: attachment=0.0128936767578125


Target 1: attachment=0.0106658935546875


Target 2: attachment=0.01139068603515625


Target 3: attachment=0.0084075927734375


Target 4: attachment=3.715301513671875
