In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.FloatTensor)
dm.Utilities.set_compute_backend('torch')
device = 'cpu'

In [2]:
torch.manual_seed(1337)

# Load the shape collection.
# NOTE(review): pickle.load can execute arbitrary code; herd.pickle must come
# from a trusted source.
# The context manager closes the file handle once the data has been read
# (the former `pickle.load(open(...))` never closed it).
with open("herd.pickle", 'rb') as herd_file:
    data = pickle.load(herd_file)

# Standard deviations of the optional random perturbations applied to the
# ear-tip positions and directions (0. disables them).
sigma_ear_tip_pos = 0.
sigma_ear_tip_dir = 0.

# data[0][1] holds the ear-tip landmarks of the first shape: entry 0 is used as
# the right ear tip and entry 1 as the left one. Each direction is the tip
# position divided by its norm.
# NOTE(review): torch.randn(1).item() draws ONE scalar that is added to every
# coordinate, so the perturbation is along the diagonal rather than isotropic,
# and perturbed directions are not re-normalised. Revisit before using
# non-zero sigmas.
right_ear_tip_pos = data[0][1][0] + sigma_ear_tip_pos*torch.randn(1).item()
left_ear_tip_pos = data[0][1][1] + sigma_ear_tip_pos*torch.randn(1).item()
right_ear_tip_dir = right_ear_tip_pos/torch.norm(right_ear_tip_pos) + sigma_ear_tip_dir*torch.randn(1).item()
left_ear_tip_dir = left_ear_tip_pos/torch.norm(left_ear_tip_pos) + sigma_ear_tip_dir*torch.randn(1).item()

# Template: the curve of the first entry, smoothed by the library's Gaussian
# kernel smoother (parameter 0.1).
template = dm.Utilities.gaussian_kernel_smooth(data[0][0], 0.1)

# Targets: the curves of entries 1 and 2 of the collection (entry 0 supplies
# the template).
herd = list(list(zip(*data))[0])[1:3]

deformable_template = dm.Models.DeformablePoints(template)
deformable_herd = [dm.Models.DeformablePoints(bunny) for bunny in herd]

print(len(herd))

# Sigma passed to both ear oriented-translation modules.
ear_sigma = 0.3

2


In [3]:
%matplotlib qt5
#plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '--', color='black', lw=4.)
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), '+', color='black', lw=4.)
for bunny in herd:
    plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '-', color='C4')
    plt.plot(bunny.numpy()[:, 0], bunny.numpy()[:, 1], '*', color='C4')
# plt.plot(ear_pos_mean[0][0].numpy(), ear_pos_mean[0][1].numpy(), 'X')
# plt.plot(ear_pos_mean[1][0].numpy(), ear_pos_mean[1][1].numpy(), 'X')
plt.quiver(right_ear_tip_pos[0].numpy(), right_ear_tip_pos[1].numpy(), right_ear_tip_dir[0].numpy(), right_ear_tip_dir[1].numpy(), color='red')
plt.quiver(left_ear_tip_pos[0], left_ear_tip_pos[1], left_ear_tip_dir[0], left_ear_tip_dir[1], color='blue')
plt.axis('equal')
plt.show()


In [19]:
# Hyper-parameters of the deformation modules.
sigma_rotation = 15.
coeff_oriented = 1e-2
coeff_rotation = 1e-4
coeff_translation = 1e-1

# Geometric descriptors of each ear translation: (position, direction), each a
# one-row copy of the landmark so the original tensors stay untouched, marked
# as requiring gradients.
right_ear_gd = (right_ear_tip_pos.clone().unsqueeze(0).requires_grad_(),
                right_ear_tip_dir.clone().unsqueeze(0).requires_grad_())
left_ear_gd = (left_ear_tip_pos.clone().unsqueeze(0).requires_grad_(),
               left_ear_tip_dir.clone().unsqueeze(0).requires_grad_())

right_ear_translation = dm.DeformationModules.OrientedTranslations(
    2, 1, ear_sigma, transport='vector', coeff=coeff_oriented,
    gd=right_ear_gd, label='right_ear_translation')
left_ear_translation = dm.DeformationModules.OrientedTranslations(
    2, 1, ear_sigma, transport='vector', coeff=coeff_oriented,
    gd=left_ear_gd, label='left_ear_translation')

# Local rotation whose centre starts at the origin.
local_rotation = dm.DeformationModules.LocalRotation(
    2, sigma_rotation, coeff=coeff_rotation,
    gd=torch.tensor([[0., 0.]], requires_grad=True))
rigid_translation = dm.DeformationModules.GlobalTranslation(2, coeff_translation)

# Multi-scale varifold data term and the shapes the atlas is fitted to.
attachment = dm.Attachment.VarifoldAttachment(2, [0.4, 2., 8.], backend='torch')
targets = herd

In [20]:
# Move every deformation module onto the compute device, in declaration order.
for deformation_module in (right_ear_translation, left_ear_translation, local_rotation, rigid_translation):
    deformation_module.to_(device)

In [21]:
# Atlas over all targets with a fixed (non-optimised) template. The module order
# here determines the init_manifold indices read later (1: right ear,
# 2: left ear, 3: local rotation).
atlas_modules = [right_ear_translation, left_ear_translation, local_rotation, rigid_translation]
atlas = dm.Models.AtlasModel(
    deformable_template, atlas_modules, [attachment], len(targets),
    fit_gd=None, optimise_template=False,
    ht_sigma=0.1, ht_coeff=10., lam=100.)

In [22]:
# Integration settings for shooting: solver name and number of steps.
shoot_solver, shoot_it = 'rk4', 10

In [23]:
# Two-stage fit: 50 iterations of gradient descent, then 50 of L-BFGS with a
# strong-Wolfe line search. Both stages record into the same `costs` container.
# NOTE(review): the saved output of this cell is a SyntaxError on line 3; the
# code as written parses, so that output looks stale.
costs = {}
shoot_options = {'shoot_solver': shoot_solver, 'shoot_it': shoot_it}

fitter_gd = dm.Models.Fitter(atlas, optimizer='gd')
fitter_gd.fit(deformable_herd, 50, costs=costs,
              options={**shoot_options, 'verbose': True})

fitter_lbfgs = dm.Models.Fitter(atlas, optimizer='torch_lbfgs')
fitter_lbfgs.fit(deformable_herd, 50, costs=costs,
                 options={**shoot_options, 'line_search_fn': 'strong_wolfe'})


SyntaxError: invalid syntax (<ipython-input-23-af91bd387c9d>, line 3)

In [19]:
%matplotlib qt5
plt.plot(range(len(costs)), tuple(zip(*costs))[1], color='black', lw=0.5)
plt.xlabel("It")
plt.ylabel("Cost")
plt.grid()
plt.show()

In [14]:
# Print recap
print(atlas)
print("")
print("Fit informations")
print("================")
# NOTE(review): assumes `costs` is a sequence of (deformation, attachment, total)
# tuples, as the labels below imply; the fit cell initialises it as a dict.
# Confirm the format produced by Fitter.fit.
print("Iteration count={it_count}".format(it_count=len(costs)))
print("Start cost={cost}".format(cost=costs[0][2]))
print("  Attach cost={cost}".format(cost=costs[0][1]))
print("  Def cost={cost}".format(cost=costs[0][0]))
print("Final cost={cost}".format(cost=costs[-1][2]))
print("  Attach cost={cost}".format(cost=costs[-1][1]))
print("  Def cost={cost}".format(cost=costs[-1][0]))
# Report the solver actually used for shooting. `shoot_method` is not defined
# anywhere in this notebook (NameError); the setting is `shoot_solver`.
print("Integration scheme={scheme}".format(scheme=shoot_solver))
print("Integration steps={steps}".format(steps=shoot_it))

Atlas
=====
Template nb pts=107
Population count=5
Module count=4
Hypertemplate=False
Attachment=VarifoldAttachment2D_Torch (weight=1.0)
  Sigmas=[4.0]
Lambda=100.0
Fit geometrical descriptors=None
Precompute callback=False
Model precompute callback=False
Other parameters=False

Modules
Oriented translation
  Label=right_ear_translation
  Sigma=0.3
  Coeff=0.01
  Nb pts=1
Oriented translation
  Label=left_ear_translation
  Sigma=0.3
  Coeff=0.01
  Nb pts=1
Local constrained translation module
  Type=Local rotation
  Sigma=15.0
  Coeff=0.0001
Global translation
  Coeff=0.1

Fit informations
Iteration count=30
Start cost=7797.1170654296875
  Attach cost=7797.054382324219
  Def cost=0.06264475124771707
Final cost=376.69342136383057
  Attach cost=375.8628845214844
  Def cost=0.8305222988128662
Integration scheme=rk4
Integration steps=10


In [9]:
# Compare the ear-tip directions given to the modules with those stored in the
# first registration model after fitting (init_manifold 1: right ear,
# 2: left ear; gd[1] is the direction).
first_manifold = atlas.registration_models[0].init_manifold
optimised_right_dir = first_manifold[1].gd[1].detach().flatten().tolist()
optimised_left_dir = first_manifold[2].gd[1].detach().flatten().tolist()

print("Initial right ear tip direction: {dir}".format(dir=right_ear_tip_dir.tolist()))
print("Optimised right ear tip direction: {dir}".format(dir=optimised_right_dir))
print("Initial left ear tip direction: {dir}".format(dir=left_ear_tip_dir.tolist()))
print("Optimised left ear tip direction: {dir}".format(dir=optimised_left_dir))

Initial right ear tip direction: [0.38524430990219116, 0.9228146076202393]
Optimised right ear tip direction: [0.38524430990219116, 0.9228146076202393]
Initial left ear tip direction: [-0.6430988907814026, 0.7657831907272339]
Optimised left ear tip direction: [-0.6430988907814026, 0.7657831907272339]


In [10]:
# Centre of the local rotation (init_manifold 3) in the first registration model.
rotation_center = atlas.registration_models[0].init_manifold[3].gd.tolist()[0]
print("Optimised rotation center: {center}".format(center=rotation_center))

Optimised rotation center: [0.0, 0.0]


In [11]:
# Template computed by the atlas, moved to the CPU so it can be plotted.
optimised_template = atlas.compute_template()[0].cpu()

In [12]:
###############################################################################
# Display the atlas: one panel per target, showing the deformed template next
# to the target shape.
#

intermediates = {}
with torch.autograd.no_grad():
    deformed_templates = atlas.compute_deformed(shoot_solver, shoot_it, intermediates=intermediates)

# Square grid with at least one panel per target.
row_count = math.ceil(math.sqrt(len(herd)))

for panel_index, (deformed, bunny) in enumerate(zip(deformed_templates, herd)):
    plt.subplot(row_count, row_count, panel_index + 1)
    deformed_points = deformed[0].detach()
    plt.plot(deformed_points[:, 0].numpy(), deformed_points[:, 1].numpy())
    plt.plot(bunny[:, 0].numpy(), bunny[:, 1].numpy())
    plt.axis('equal')

plt.show()


In [34]:
# Ear-tip frames stored in the first registration model after fitting
# (init_manifold 1: right ear, 2: left ear; gd = (position, direction)).
fitted_manifold = atlas.registration_models[0].init_manifold
template_right_ear_pos = fitted_manifold[1].gd[0].detach().cpu().flatten()
template_right_ear_dir = fitted_manifold[1].gd[1].detach().cpu().flatten()
template_left_ear_pos = fitted_manifold[2].gd[0].detach().cpu().flatten()
template_left_ear_dir = fitted_manifold[2].gd[1].detach().cpu().flatten()

# Optimised template in grey, initial (smoothed) template in black.
plt.plot(optimised_template[:, 0].numpy(), optimised_template[:, 1].numpy(), '-', color='grey', lw=1.5)
plt.plot(template[:, 0].numpy(), template[:, 1].numpy(), color='black', lw=0.8)

initial_frames = ((right_ear_tip_pos, right_ear_tip_dir),
                  (left_ear_tip_pos, left_ear_tip_dir))
optimised_frames = ((template_right_ear_pos, template_right_ear_dir),
                    (template_left_ear_pos, template_left_ear_dir))

# Quivers for the initial frames first, then for the optimised ones.
for frame_pos, frame_dir in initial_frames + optimised_frames:
    plt.quiver(frame_pos[0].numpy(), frame_pos[1].numpy(),
               frame_dir[0].numpy(), frame_dir[1].numpy(), scale=10.)

# Arrows linking each initial ear-tip position to its optimised counterpart
# (position correspondence), right ear then left ear.
initial_positions = (right_ear_tip_pos, left_ear_tip_pos)
optimised_positions = (template_right_ear_pos, template_left_ear_pos)
for start_pos, end_pos in zip(initial_positions, optimised_positions):
    plt.arrow(start_pos[0], start_pos[1],
              end_pos[0] - start_pos[0], end_pos[1] - start_pos[1],
              width=0.01, length_includes_head=True, head_width=0.08)
plt.axis('equal')
plt.show()

In [33]:
%matplotlib qt5
it_per_snapshot = 2
snapshots = int(shoot_it/it_per_snapshot)

disp_targets = len(targets)

for i in range(disp_targets):
    silent_pos = optimised_template.clone()
    silent_mom = atlas.registration_models[i].init_manifold[0].cotan.detach().cpu().clone()
    right_ear_trans_gd = (atlas.registration_models[i].init_manifold[1].gd[0].detach().cpu().clone(),
                          atlas.registration_models[i].init_manifold[1].gd[1].detach().cpu().clone())
    right_ear_trans_mom = (atlas.registration_models[i].init_manifold[1].cotan[0].detach().cpu().clone(),
                           atlas.registration_models[i].init_manifold[1].cotan[1].detach().cpu().clone())
    left_ear_trans_gd = (atlas.registration_models[i].init_manifold[2].gd[0].detach().cpu().clone(),
                          atlas.registration_models[i].init_manifold[2].gd[1].detach().cpu().clone())
    left_ear_trans_mom = (atlas.registration_models[i].init_manifold[2].cotan[0].detach().cpu().clone(),
                          atlas.registration_models[i].init_manifold[2].cotan[1].detach().cpu().clone())
    local_rotation_gd = atlas.registration_models[i].init_manifold[3].gd.detach().cpu().clone()
    local_rotation_mom = atlas.registration_models[i].init_manifold[3].cotan.detach().cpu().clone()

    silent = dm.DeformationModules.SilentLandmarks(2, silent_pos.shape[0], gd=silent_pos, cotan=silent_mom)
    right_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, coeff=coeff_oriented, gd=right_ear_trans_gd, cotan=right_ear_trans_mom)
    left_ear_trans = dm.DeformationModules.OrientedTranslations(2, 1, ear_sigma, coeff=coeff_oriented, gd=left_ear_trans_gd, cotan=left_ear_trans_mom)
    local_rot = dm.DeformationModules.LocalRotation(2, sigma_rotation, coeff=coeff_rotation, gd=local_rotation_gd, cotan=local_rotation_mom)
    rigid_translation = dm.DeformationModules.GlobalTranslation(2, coeff=coeff_translation)

    h = dm.HamiltonianDynamic.Hamiltonian([silent, right_ear_trans, left_ear_trans, local_rot, rigid_translation])
    intermediate_states, _ = dm.HamiltonianDynamic.shoot(h, shoot_it, shoot_method, intermediates=True)


    for j in range(snapshots):
        pos = intermediate_states[it_per_snapshot*j].gd[0].numpy()
        pos_right_ear_trans = intermediate_states[it_per_snapshot*j].gd[1][0].flatten().numpy()
        pos_left_ear_trans = intermediate_states[it_per_snapshot*j].gd[2][0].flatten().numpy()
        dir_right_ear_trans = intermediate_states[it_per_snapshot*j].gd[1][1].flatten().numpy()
        dir_left_ear_trans = intermediate_states[it_per_snapshot*j].gd[2][1].flatten().numpy()

        plt.subplot(disp_targets, snapshots + 1, i*snapshots + j + i + 1)
        plt.plot(pos[:, 0], pos[:, 1], color='black')
        plt.plot(pos_right_ear_trans[0], pos_right_ear_trans[1], 'x', color='red')
        plt.plot(pos_left_ear_trans[0], pos_left_ear_trans[1], 'x', color='red')
        plt.plot(targets[i].numpy()[:, 0], targets[i].numpy()[:, 1], color='blue')
        plt.quiver(pos_right_ear_trans[0], pos_right_ear_trans[1],
                   dir_right_ear_trans[0], dir_right_ear_trans[1], scale=10.)
        plt.quiver(pos_left_ear_trans[0], pos_left_ear_trans[1],
                   dir_left_ear_trans[0], dir_left_ear_trans[1], scale=10.)

        plt.axis('equal')

    final_pos = intermediate_states[-1].gd[0].numpy()
    plt.subplot(disp_targets, snapshots + 1, i*snapshots + snapshots + i + 1)
    plt.plot(targets[i].numpy()[:, 0], targets[i].numpy()[:, 1], color='blue')
    plt.plot(final_pos[:, 0], final_pos[:, 1], color='black')
    plt.axis('equal')

    print("Target {i}: attachment={attachment}".format(i=i, attachment=attachment(targets[i], torch.tensor(final_pos))))

plt.show()

AssertionError: 