In [1]:
%load_ext autoreload
%autoreload 2

import pickle
import math
import copy

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

torch.set_default_tensor_type(torch.FloatTensor)

dm.Utilities.set_compute_backend('torch')
device = 'cpu'

In [2]:
# Loading the datasets
try:
    data = pickle.load(open("../../data/data_acropetal.pkl", 'rb'))
except:
    print("Could not load the file.")

pos_source = torch.tensor(data['source_silent']).type(torch.get_default_dtype())
pos_implicit0 = torch.tensor(data['source_implicit0']).type(torch.get_default_dtype())
pos_implicit1 = torch.tensor(data['source_implicit1']).type(torch.get_default_dtype())
pos_target = torch.tensor(data['target_silent']).type(torch.get_default_dtype())

# Some rescaling for the source
Dx = 0.
Dy = 0.
height_source = 90.
height_target = 495.

smin, smax = torch.min(pos_source[:, 1]), torch.max(pos_source[:, 1])
sscale = height_source / (smax - smin)
pos_source[:, 0] = Dx + sscale * (pos_source[:, 0] - torch.mean(pos_source[:, 0]))
pos_source[:, 1] = Dy - sscale * (pos_source[:, 1] - smax)
pos_implicit0[:, 0] = Dx + sscale * (pos_implicit0[:, 0] - torch.mean(pos_implicit0[:, 0]))
pos_implicit0[:, 1] = Dy - sscale * (pos_implicit0[:, 1] - smax)
pos_implicit1[:, 0] = Dx + sscale * (pos_implicit1[:, 0] - torch.mean(pos_implicit1[:, 0]))
pos_implicit1[:, 1] = Dy - sscale * (pos_implicit1[:, 1] - smax)

# Some rescaling for the target
tmin, tmax = torch.min(pos_target[:, 1]), torch.max(pos_target[:, 1])
tscale = height_target / (tmax - tmin)
pos_target[:, 0] = tscale * (pos_target[:, 0] - torch.mean(pos_target[:, 0]))
pos_target[:, 1] = - tscale * (pos_target[:, 1] - tmax)

# Compute an AABB for plotting
aabb = dm.Utilities.AABB.build_from_points(pos_target)
aabb.squared()

<implicitmodules.torch.Utilities.aabb.AABB at 0x7f881c6bbd30>

In [12]:
# Some plots
%matplotlib qt5

plt.subplot(2, 2, 1)
plt.title("Source")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')
plt.axis('equal')

plt.subplot(2, 2, 2)
plt.axis(aabb.totuple())
plt.title("Target")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.axis('equal')

plt.subplot(2, 2, 3)
plt.imshow(data['source_img'])

plt.subplot(2, 2, 4)
plt.imshow(data['target_img'])

plt.show()

In [5]:
# Setting up the modules

# Global translation module
global_translation = dm.DeformationModules.GlobalTranslation(2)

# Local translation module
sigma0 = 10.
nu0 = 0.001
coeff0 = 100.
implicit0 = dm.DeformationModules.ImplicitModule0(2, pos_implicit0.shape[0], sigma0, nu0, coeff0,  gd=pos_implicit0.clone().requires_grad_())

# Elastic module
sigma1 = 30.
nu1 = 0.0005
coeff1 = 0.01
C = torch.zeros(pos_implicit1.shape[0], 2, 1)
K, L = 10, height_source
a, b = 1./L, 3.
z = a*(pos_implicit1[:, 1] - Dy)
C[:, 1, 0] = K * ((1 - b) * z**2 + b * z)
C[:, 0, 0] = 0.8 * C[:, 1, 0]
th = 0. * math.pi * torch.ones(pos_implicit1.shape[0])
R = torch.stack([dm.Utilities.rot2d(t) for t in th])

implicit1 = dm.DeformationModules.ImplicitModule1(2, pos_implicit1.shape[0], sigma1, C, nu1, coeff1, gd=(pos_implicit1.clone().requires_grad_(), R.clone().requires_grad_()))

#global_translation.to(device)
#implicit0.to(device)
#implicit1.to(device)

In [None]:
# Setting up the model and start the fitting loop
model = dm.Models.ModelPointsRegistration([pos_source], [global_translation, implicit0, implicit1], [dm.Attachment.VarifoldAttachment(2, [10., 30., 80.], backend='torch')], lam=50.)

fitter = dm.Models.ModelFittingScipy(model)

In [7]:
costs = fitter.fit([pos_target], 50, log_interval=5)

Initial energy = 13788988.000


Time: 3.3283469699090347
Iteration: 1 
Total energy = 13438368.0 
Attach cost = 13438300.0 
Deformation cost = 68.0116958618164


Time: 9.822632144903764
Iteration: 5 
Total energy = 12603798.0 
Attach cost = 12603542.0 
Deformation cost = 255.52517700195312


Time: 16.697283782996237
Iteration: 10 
Total energy = 3212767.75 
Attach cost = 3190309.75 
Deformation cost = 22457.98828125


Time: 22.39452893694397
Iteration: 15 
Total energy = 2668691.0 
Attach cost = 2648644.0 
Deformation cost = 20046.9296875


Optimisation process exited with message: b'ABNORMAL_TERMINATION_IN_LNSRCH'
Final energy = 2668614.0
Closure evaluations = 72
Time elapsed = 82.76968592591584


In [11]:
# Results

modules = dm.DeformationModules.CompoundModule(copy.copy(model.modules))
print(len(modules.modules))
modules.manifold.fill(model.init_manifold)
dm.HamiltonianDynamic.shoot(dm.HamiltonianDynamic.Hamiltonian(modules), 10, 'euler')
out = modules.manifold[0].gd.detach().cpu().numpy()
shot_implicit0 = modules.manifold[2].gd.detach().cpu().numpy()
shot_implicit1 = modules.manifold[3].gd[0].detach().cpu().numpy()

%matplotlib qt5
plt.subplot(1, 3, 1)
plt.title("Source")
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')
plt.axis('equal')

plt.subplot(1, 3, 2)
plt.title("Deformed source")
plt.plot(out[:, 0], out[:, 1], '-')
plt.plot(shot_implicit0[:, 0], shot_implicit0[:, 1], 'x')
plt.plot(shot_implicit1[:, 0], shot_implicit1[:, 1], '.')
plt.axis('equal')

plt.subplot(1, 3, 3)
plt.title("Deformed source and target")
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.plot(out[:, 0], out[:, 1], '-')
plt.axis('equal')
plt.show()

4


In [None]:
# Evolution of the cost with iterations
plt.title("Cost")
plt.xlabel("Iteration(s)")
plt.ylabel("Cost")
plt.plot(range(len(costs)), costs)
plt.show()