In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import math
import copy

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.FloatTensor)

dm.Utilities.set_compute_backend('torch')

In [2]:
# Load the source and target leaf data.
# TODO (WIP): merge the two .pkl files into a single dictionary.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
# The files are opened with `with` so their handles are closed once loaded.
with open("../../data/leafbasi.pkl", 'rb') as source_file:
    data_source = pickle.load(source_file)
with open("../../data/leafbasit.pkl", 'rb') as target_file:
    data_target = pickle.load(target_file)

# Offset applied to the normalised source, and the y-extent each shape is rescaled to.
Dx = 0.
Dy = 0.
height_source = 38.
height_target = 100.

# data_*[1] holds point rows: columns 0-1 are (x, y), column 2 is a label
# (2 = shape points, 1 = implicit-module points, as selected below).
source = torch.tensor(data_source[1]).type(torch.get_default_dtype())
target = torch.tensor(data_target[1]).type(torch.get_default_dtype())

# Rescale the source so its y-extent equals height_source, flip y so the original
# maximum maps to Dy, and center x on its mean (shifted by Dx). Both axes share one scale.
smin, smax = torch.min(source[:, 1]), torch.max(source[:, 1])
sscale = height_source / (smax - smin)
source[:, 1] = Dy - sscale * (source[:, 1] - smax)
source[:, 0] = Dx + sscale * (source[:, 0] - torch.mean(source[:, 0]))

# Same normalisation for the target (height_target, no offset).
tmin, tmax = torch.min(target[:, 1]), torch.max(target[:, 1])
tscale = height_target / (tmax - tmin)
target[:, 1] = - tscale * (target[:, 1] - tmax)
target[:, 0] = tscale * (target[:, 0] - torch.mean(target[:, 0]))

# Shape points (label 2); the point at index 3 is removed from the source shape.
# NOTE(review): the reason for dropping index 3 is not documented here -- confirm it is intended.
pos_source = source[source[:, 2] == 2, 0:2]
pos_source = torch.tensor(np.delete(pos_source.numpy(), 3, axis=0))
# Both implicit modules start on the same label-1 points; boolean-mask indexing
# returns independent copies, so the two tensors do not alias each other.
pos_implicit0 = source[source[:, 2] == 1, 0:2]
pos_implicit1 = source[source[:, 2] == 1, 0:2]
pos_target = target[target[:, 2] == 2, 0:2]

# Plotting bounds built from the target shape.
aabb = dm.Utilities.AABB.build_from_points(pos_target)
# NOTE(review): unclear from here whether squared() mutates `aabb` in place or only
# returns a new box (its return value is only displayed) -- confirm.
aabb.squared()

<implicitmodules.torch.Utilities.aabb.AABB at 0x7fd2c2a64b70>

In [3]:
# Some plots
%matplotlib qt5

plt.subplot(2, 2, 1)
plt.axis(aabb.totuple())
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')
plt.axis('equal')

plt.subplot(2, 2, 2)
plt.axis(aabb.totuple())
plt.title('Target')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.axis('equal')

plt.subplot(2, 2, 3)
plt.imshow(data_source[0])

plt.subplot(2, 2, 4)
plt.imshow(data_target[0])

plt.show()

In [4]:
# Setting up the deformation modules: a global translation, a local translation
# module (ImplicitModule0) and an elastic module (ImplicitModule1).

# Global translation module, in 2D.
global_translation = dm.DeformationModules.GlobalTranslation(2)

# Local translation module, with its geometric descriptors on pos_implicit0.
sigma0, nu0, coeff0 = 15., 0.001, 100.
implicit0 = dm.DeformationModules.ImplicitModule0(
    2, pos_implicit0.shape[0], sigma0, nu0, coeff0,
    gd=pos_implicit0.clone().requires_grad_())

# Elastic module on pos_implicit1.
sigma1, nu1, coeff1 = 30., 0.001, 0.01
n_implicit1 = pos_implicit1.shape[0]

# C: one column per point (shape (n, 2, 1)), following a cubic profile in `depth`.
# With a = -2/L^3 and b = 3/L^2, K * (a*d^3 + b*d^2) goes from 0 at d = 0 to K at d = L.
K, L = 10, height_source
a, b = -2 / L ** 3, 3 / L ** 2
# depth is 0 where y = Dy + L and L where y = Dy.
depth = L - pos_implicit1[:, 1] + Dy
C = torch.zeros(n_implicit1, 2, 1)
C[:, 1, 0] = (K * (a * depth ** 3 + b * depth ** 2))
C[:, 0, 0] = C[:, 1, 0]  # x and y components share the same profile

# Initial frame angles, all zero, turned into per-point 2D rotation matrices.
th = torch.zeros(n_implicit1)
R = torch.stack([dm.Utilities.rot2d(angle) for angle in th])

implicit1 = dm.DeformationModules.ImplicitModule1(
    2, n_implicit1, sigma1, C, nu=nu1, coeff=coeff1,
    gd=(pos_implicit1.clone().requires_grad_(), R.clone().requires_grad_()))

In [5]:
# Setting up the model (source points deformed by the three modules, compared to the
# target through a varifold attachment) and the scipy-driven fitter.
# NOTE(review): [20., 60.] are presumably the varifold kernel scales and lam the
# attachment weight -- confirm against the implicitmodules API.
varifold = dm.Attachment.VarifoldAttachment(2, [20., 60.], backend='torch')
model = dm.Models.ModelPointsRegistration(
    [pos_source.clone()],
    [global_translation, implicit0, implicit1],
    [varifold],
    lam=100.)
fitter = dm.Models.ModelFittingScipy(model)

In [6]:
# Integration scheme and number of steps used for every shooting below.
shoot_method, shoot_it = 'euler', 10

In [7]:
# Fit the model to the target for 40 iterations, logging every iteration.
shoot_options = {'shoot_method': shoot_method, 'shoot_it': shoot_it}
costs = fitter.fit([pos_target], 40, log_interval=1, options=shoot_options)

Initial energy = 1813580.500


Time: 4.175846422091126
Iteration: 1 
Total energy = 1170181.375 
Attach cost = 1170137.25 
Deformation cost = 44.0951042175293


Time: 7.013536680024117
Iteration: 2 
Total energy = 906520.1875 
Attach cost = 906370.3125 
Deformation cost = 149.8910369873047


Time: 8.419079469982535
Iteration: 3 
Total energy = 349651.1875 
Attach cost = 349108.0 
Deformation cost = 543.180908203125


Time: 11.170904641970992
Iteration: 4 
Total energy = 60782.1484375 
Attach cost = 60062.890625 
Deformation cost = 719.2579956054688


Time: 12.651352376909927
Iteration: 5 
Total energy = 49618.5703125 
Attach cost = 48848.046875 
Deformation cost = 770.5241088867188


Time: 14.178781975060701
Iteration: 6 
Total energy = 47663.83203125 
Attach cost = 46920.1171875 
Deformation cost = 743.7149658203125


Time: 15.612261892063543
Iteration: 7 
Total energy = 47656.6015625 
Attach cost = 46910.15625 
Deformation cost = 746.4454956054688


Time: 18.483120071003214
Iteration: 8 
Total energy = 47655.91015625 
Attach cost = 46909.5703125 
Deformation cost = 746.33984375


Time: 21.30247601890005
Iteration: 9 
Total energy = 47652.90234375 
Attach cost = 46907.03125 
Deformation cost = 745.8729248046875


Time: 22.781809905078262
Iteration: 10 
Total energy = 47644.66796875 
Attach cost = 46899.8046875 
Deformation cost = 744.864501953125


Time: 24.2295765040908
Iteration: 11 
Total energy = 47617.46875 
Attach cost = 46874.609375 
Deformation cost = 742.86083984375


Time: 25.636205493006855
Iteration: 12 
Total energy = 47550.8125 
Attach cost = 46810.9375 
Deformation cost = 739.8742065429688


Time: 27.015999217052013
Iteration: 13 
Total energy = 47365.94921875 
Attach cost = 46631.0546875 
Deformation cost = 734.8960571289062


Time: 28.486450403928757
Iteration: 14 
Total energy = 46799.01171875 
Attach cost = 46072.4609375 
Deformation cost = 726.5508422851562


Time: 32.72883751196787
Iteration: 15 
Total energy = 44275.9296875 
Attach cost = 43565.625 
Deformation cost = 710.3045654296875


Time: 35.60085224104114
Iteration: 16 
Total energy = 43674.94140625 
Attach cost = 42959.1796875 
Deformation cost = 715.7617797851562


Time: 37.02270320197567
Iteration: 17 
Total energy = 40041.05078125 
Attach cost = 39273.046875 
Deformation cost = 768.0025024414062


Time: 38.491498704999685
Iteration: 18 
Total energy = 39563.06640625 
Attach cost = 38774.0234375 
Deformation cost = 789.04443359375


Time: 39.92067113891244
Iteration: 19 
Total energy = 38988.109375 
Attach cost = 38193.9453125 
Deformation cost = 794.1624755859375


Time: 41.331283688079566
Iteration: 20 
Total energy = 38686.0703125 
Attach cost = 37890.625 
Deformation cost = 795.4434204101562


Time: 42.752117621013895
Iteration: 21 
Total energy = 37730.51953125 
Attach cost = 36935.15625 
Deformation cost = 795.3619384765625


Time: 44.2771763720084
Iteration: 22 
Total energy = 36674.59375 
Attach cost = 35883.7890625 
Deformation cost = 790.802734375


Time: 45.80205652094446
Iteration: 23 
Total energy = 34975.71484375 
Attach cost = 34162.5 
Deformation cost = 813.2137451171875


Time: 47.29344388493337
Iteration: 24 
Total energy = 33524.79296875 
Attach cost = 32686.5234375 
Deformation cost = 838.27001953125


Time: 50.16196626191959
Iteration: 25 
Total energy = 32585.82421875 
Attach cost = 31730.2734375 
Deformation cost = 855.5514526367188


Time: 53.152602853951976
Iteration: 26 
Total energy = 32272.759765625 
Attach cost = 31410.3515625 
Deformation cost = 862.40869140625


Time: 54.58075379091315
Iteration: 27 
Total energy = 31879.533203125 
Attach cost = 31023.046875 
Deformation cost = 856.4860229492188


Time: 55.99422756792046
Iteration: 28 
Total energy = 31184.666015625 
Attach cost = 30319.53125 
Deformation cost = 865.1345825195312


Time: 57.39692919002846
Iteration: 29 
Total energy = 30548.310546875 
Attach cost = 29671.6796875 
Deformation cost = 876.6310424804688


Time: 60.31729691289365
Iteration: 30 
Total energy = 30319.61328125 
Attach cost = 29427.9296875 
Deformation cost = 891.6835327148438


Time: 61.7612081610132
Iteration: 31 
Total energy = 29959.26953125 
Attach cost = 29049.0234375 
Deformation cost = 910.2451171875


Time: 63.124896321911365
Iteration: 32 
Total energy = 29597.84375 
Attach cost = 28674.8046875 
Deformation cost = 923.0383911132812


Time: 64.54627177794464
Iteration: 33 
Total energy = 29526.40625 
Attach cost = 28597.4609375 
Deformation cost = 928.9461669921875


Time: 66.0295654849615
Iteration: 34 
Total energy = 29458.775390625 
Attach cost = 28511.328125 
Deformation cost = 947.44775390625


Time: 67.49563656491227
Iteration: 35 
Total energy = 29432.845703125 
Attach cost = 28493.1640625 
Deformation cost = 939.6817626953125


Time: 68.91312305489555
Iteration: 36 
Total energy = 29427.439453125 
Attach cost = 28485.15625 
Deformation cost = 942.28271484375


Time: 70.47218667599373
Iteration: 37 
Total energy = 29424.97265625 
Attach cost = 28480.6640625 
Deformation cost = 944.3080444335938


Time: 71.97469987301156
Iteration: 38 
Total energy = 29423.185546875 
Attach cost = 28478.515625 
Deformation cost = 944.6697998046875


Time: 73.41382108000107
Iteration: 39 
Total energy = 29416.64453125 
Attach cost = 28470.3125 
Deformation cost = 946.332275390625


Time: 74.84998879302293
Iteration: 40 
Total energy = 29406.802734375 
Attach cost = 28458.59375 
Deformation cost = 948.2095947265625
Optimisation process exited with message: b'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
Final energy = 29406.802734375
Closure evaluations = 52
Time elapsed = 74.85131214605644


In [8]:
# Results
intermediate_states, _ = model.compute_deformed(shoot_method, shoot_it, intermediates=True)

deformed_source = intermediate_states[-1][0].gd
deformed_implicit0 = intermediate_states[-1][2].gd
deformed_implicit1 = intermediate_states[-1][3].gd[0]

%matplotlib qt5
plt.subplot(1, 3, 1)
plt.title("Source")
plt.axis(aabb.totuple())
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')
plt.axis('equal')

plt.subplot(1, 3, 2)
plt.title("Deformed source")
plt.axis(aabb.totuple())
plt.plot(deformed_source[:, 0], deformed_source[:, 1], '-')
plt.plot(deformed_implicit0[:, 0], deformed_implicit0[:, 1], 'x')
plt.plot(deformed_implicit1[:, 0], deformed_implicit1[:, 1], '.')
plt.axis('equal')

plt.subplot(1, 3, 3)
plt.title("Deformed source and target")
plt.axis(aabb.totuple())
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.plot(deformed_source[:, 0], deformed_source[:, 1], '-')
plt.axis('equal')
plt.show()

In [142]:
# Re-shoot the fitted model from its initial manifold to record the controls of
# each implicit module at every integration step.
# NOTE(review): copy.copy(model.modules) only copies the list -- the module objects
# themselves are shared with `model`, so fill()/shoot() presumably also update the
# model's modules. Confirm this is acceptable.
modules = dm.DeformationModules.CompoundModule(copy.copy(model.modules))
modules.manifold.fill(model.init_manifold.clone(), copy=True)

intermediate_states, intermediate_controls = dm.HamiltonianDynamic.shoot(dm.HamiltonianDynamic.Hamiltonian(modules), shoot_it, shoot_method, intermediates=True)

# Per-step controls of ImplicitModule0 (module index 2) and ImplicitModule1 (index 3).
implicit0_controls = [control[2] for control in intermediate_controls]
implicit1_controls = [control[3] for control in intermediate_controls]

implicit1_controls

[tensor([-0.0180]), tensor([0.0365]), tensor([0.0865]), tensor([0.1399]), tensor([0.2030]), tensor([0.2786]), tensor([0.3508]), tensor([0.3607]), tensor([0.3003]), tensor([0.2561])]


In [149]:
# Replay only the recorded ImplicitModule1 controls on the source points and a grid,
# to visualise the deformation generated by that module alone.
modules = dm.DeformationModules.CompoundModule(copy.copy(model.modules))
modules.manifold.fill(model.init_manifold.clone(), copy=True)

silent = copy.copy(modules[0])  # source points, carried along with an empty control
deformation_grid = dm.DeformationModules.DeformationGrid(dm.Utilities.AABB.build_from_points(pos_source).scale(1.2), [32, 32])
# NOTE(review): this rebinds the global `implicit1` (from the module set-up cell)
# to a copy of the fitted module.
implicit1 = copy.copy(modules[3])

# One control list per step: empty controls for the points and the grid,
# the recorded control for ImplicitModule1.
controls = [[torch.tensor([]), torch.tensor([]), control] for control in implicit1_controls]

# shoot updates the modules' manifolds in place; the deformed state is read back below.
dm.HamiltonianDynamic.shoot(dm.HamiltonianDynamic.Hamiltonian([silent, deformation_grid, implicit1]), shoot_it, shoot_method, controls=controls)

deformed_source = silent.manifold.gd.detach()
deformed_grid = deformation_grid.togrid()
deformed_implicit1 = implicit1.manifold.gd[0].detach()

# New figure so the plot does not land on a figure left open by an earlier cell.
fig, ax = plt.subplots()
ax.set_title("Source deformed by the ImplicitModule1 controls alone")
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '--', color='black')
ax.plot(deformed_source[:, 0].numpy(), deformed_source[:, 1].numpy(), '-', color='black')
dm.Utilities.plot_grid(ax, deformed_grid[0], deformed_grid[1], color='xkcd:light blue', lw=0.5)
ax.plot(deformed_implicit1[:, 0].numpy(), deformed_implicit1[:, 1].numpy(), 'x')
ax.axis('equal')
plt.show()

In [155]:
# Replay only the recorded ImplicitModule0 controls on the source points and a grid,
# to visualise the deformation generated by that module alone.
modules = dm.DeformationModules.CompoundModule(copy.copy(model.modules))
modules.manifold.fill(model.init_manifold.clone(), copy=True)

silent = copy.copy(modules[0])  # source points, carried along with an empty control
deformation_grid = dm.DeformationModules.DeformationGrid(dm.Utilities.AABB.build_from_points(pos_source).scale(1.2), [32, 32])
# NOTE(review): this rebinds the global `implicit0` (from the module set-up cell)
# to a copy of the fitted module.
implicit0 = copy.copy(modules[2])

# One control list per step: empty controls for the points and the grid,
# the recorded control for ImplicitModule0.
controls = [[torch.tensor([]), torch.tensor([]), control] for control in implicit0_controls]

# shoot updates the modules' manifolds in place; the deformed state is read back below.
dm.HamiltonianDynamic.shoot(dm.HamiltonianDynamic.Hamiltonian([silent, deformation_grid, implicit0]), shoot_it, shoot_method, controls=controls)

deformed_source = silent.manifold.gd.detach()
deformed_grid = deformation_grid.togrid()
deformed_implicit0 = implicit0.manifold.gd.detach()

# New figure so the plot does not land on a figure left open by an earlier cell.
fig, ax = plt.subplots()
ax.set_title("Source deformed by the ImplicitModule0 controls alone")
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '--', color='black')
ax.plot(deformed_source[:, 0].numpy(), deformed_source[:, 1].numpy(), '-', color='black')
dm.Utilities.plot_grid(ax, deformed_grid[0], deformed_grid[1], color='xkcd:light blue', lw=0.5)
ax.plot(deformed_implicit0[:, 0].numpy(), deformed_implicit0[:, 1].numpy(), 'x')
ax.axis('equal')
plt.show()