In [None]:
%load_ext autoreload
%autoreload 2

import math

# The deformation module library is not installed automatically yet, so we add its path manually.
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt
import torch

import implicitmodules.torch as dm

In [None]:
# Fix the RNG seed so the noisy point clouds (and hence the fit below) are
# identical on every "Restart Kernel -> Run All".
torch.manual_seed(0)

N = 10                # number of points sampled on each arm of the cross
extend = 1.           # each arm spans [-extend/2, extend/2]
sigma_noise = 0.001   # standard deviation of the Gaussian jitter added to the points

# Source shape: 2*N points forming a cross. The first N rows lie on the x-axis
# (y = 0), the last N rows lie on the y-axis (x = 0).
arm_samples = torch.linspace(-extend/2., extend/2., N)
source = torch.zeros(2*N, 2)
source[:N, 0] = arm_samples
source[N:, 1] = arm_samples
source = source + sigma_noise*torch.randn_like(source)

# Ground-truth linear map: dm.Utilities.rot2d(pi/16) (presumably a 2x2 rotation
# matrix -- confirm against the library) multiplied by 0.5 times the identity.
lineardef = torch.mm(dm.Utilities.rot2d(math.pi/16), 0.5*torch.eye(2))

# Each point is treated as a row vector and right-multiplied by lineardef;
# a single (2N, 2) x (2, 2) product replaces the per-point batched product.
target = torch.mm(source, lineardef)
target = target + sigma_noise*torch.randn_like(target)

In [None]:
# Scatter the source points (blue) and the target points (red) on one figure.
for cloud, colour in ((source, 'blue'), (target, 'red')):
    xs = cloud[:, 0].numpy()
    ys = cloud[:, 1].numpy()
    plt.plot(xs, ys, '.', color=colour)

# Same scale on both axes so the shapes are not visually distorted.
plt.axis('equal')
plt.show()

In [None]:
# Linear deformation module, built from a 2x2 identity matrix and a single
# geometrical descriptor at the origin; both tensors track gradients.
initial_matrix = torch.eye(2, requires_grad=True)
initial_gd = torch.zeros(1, 2, requires_grad=True)
lineardefmodule = dm.DeformationModules.LinearDeformation.build(
    initial_matrix,
    gd=initial_gd)

# Registration model: one source point cloud (cloned so the model does not
# share storage with `source`), one deformation module and one attachment term.
# The module's `A` attribute is handed over through `other_parameters`
# (presumably so that it is optimised together with the rest -- confirm).
attachment = dm.Attachment.EuclideanPointwiseDistanceAttachment()
model = dm.Models.ModelPointsRegistration(
    [source.clone()],
    [lineardefmodule],
    [attachment],
    other_parameters=[lineardefmodule.A],
    lam=1000.,
    fit_moments=True)

In [None]:
# Wrap the model in the SciPy-based fitter.
# NOTE(review): the meaning of the second argument (1.) is not visible here -- confirm.
modelfitter = dm.Models.ModelFittingScipy(model, 1.)

# Fit against a copy of the target so the fitter never touches `target` itself.
# NOTE(review): 55 is presumably the iteration budget -- confirm.
fit_options = {'shoot_method': 'torch_euler'}
fit_targets = [target.clone()]
costs = modelfitter.fit(fit_targets, 55, options=fit_options)

In [None]:
deformed_source = model.modules[0].manifold.gd.detach()
control = model.modules[1].controls.detach()
cotan = model.modules[1].manifold.cotan
lindef_opti = model.modules[1].A

print(control)
print(cotan)
print(lindef_opti)

%matplotlib qt5
plt.plot(deformed_source[:, 0].numpy(), deformed_source[:, 1].numpy(), '.', color='green')
plt.plot(source[:, 0].numpy(), source[:, 1].numpy(), '.', color='blue')
plt.plot(target[:, 0].numpy(), target[:, 1].numpy(), '.', color='red')

for i in range(source.shape[0]):
    plt.plot([source[i, 0].numpy(), deformed_source[i, 0].numpy()],
             [source[i, 1].numpy(), deformed_source[i, 1].numpy()], color='black')

plt.axis('equal')
plt.show()
