In [52]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tqdm import tqdm
tqdm.pandas()
import matplotlib.pyplot as plt
import torch
import plotly.graph_objs as go
%matplotlib notebook
torchdtype = torch.float32

In [2]:
# List the GPUs TensorFlow can see (only used as an environment sanity check).
tf.config.list_physical_devices(device_type="GPU")

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

# Load

In [3]:
# One row per subject; columns Symm1..Symm90 hold flattened 3-D landmarks (see get_point_matrix).
raw_data = pd.read_excel(io="../data/113.xlsx")

In [4]:
def get_point_matrix(row, n_points=30, prefix="Symm"):
    """Build an ``(n_points, 3)`` array of 3-D coordinates from one table row.

    Columns ``{prefix}1 .. {prefix}{3 * n_points}`` are read in order and laid out
    row-major: column ``{prefix}{k}`` (1-based) goes to ``cloud[(k - 1) // 3, (k - 1) % 3]``.

    Parameters
    ----------
    row : mapping-like
        Anything indexable by column name, e.g. the ``pd.Series`` that
        ``DataFrame.apply(..., axis=1)`` passes in.
    n_points : int, optional
        Number of 3-D points to read (default 30, i.e. 90 columns).
    prefix : str, optional
        Column-name prefix (default ``"Symm"``).

    Returns
    -------
    np.ndarray of shape ``(n_points, 3)``. Entries whose column is missing, ``None``
    or not convertible to ``float`` stay at 0.0. NaN cells are stored as NaN.
    """
    cloud = np.zeros((n_points, 3))
    for idx in range(3 * n_points):
        try:
            cloud[idx // 3, idx % 3] = float(row[f"{prefix}{idx + 1}"])
        except (KeyError, TypeError, ValueError):
            # Missing column, None, or non-numeric cell: keep the 0.0 placeholder.
            # Only these expected failures are tolerated; anything else (including
            # KeyboardInterrupt) propagates instead of being silently swallowed.
            pass
    return cloud

In [5]:
# Convert every row to its point cloud (displayed only, not stored).
raw_data.apply(get_point_matrix, axis="columns")

0      [[0.1935843291266556, 0.01532673792711466, 0.0...
1      [[0.1158010028844426, 0.08418297525813459, 0.0...
2      [[0.17385296856600804, 0.0014354645441368033, ...
3      [[0.15122614860708916, 0.0753608027243323, 0.0...
4      [[0.1634659508842222, 0.06426250896325487, 0.0...
                             ...                        
108    [[0.17022082705033725, 0.04975286153546156, 0....
109    [[0.2023842397456483, 0.011433544221053329, 0....
110    [[0.18087756330537128, 0.015763604236610294, 0...
111    [[0.14718340305881095, 0.04079020661380661, 0....
112    [[0.16518532528224664, 0.022433257267478756, 0...
Length: 113, dtype: object

In [6]:
# 3-D scatter of a single example subject (positional row 3) to eyeball the landmarks.
example_cloud = get_point_matrix(raw_data.iloc[3])
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(*example_cloud.T)  # .T -> (3, n_points): unpacks into xs, ys, zs
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.set_title("Point cloud of raw_data.iloc[3]");

<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7f86050c29d0>

In [49]:
# Source vertices/faces and target vertices/faces (FS/FT are used as triangle
# indices by `plots`). NOTE(review): torch.load unpickles the file, so only load
# trusted data.
VS, FS, VT, FT = torch.load(f="../data/hippos.pt")

In [9]:
from pykeops.torch import Vi, Vj

In [22]:
def plots(src, tar, FS, FT, save_folder="../doc/_build/html/_images/", filename="data.html"):
    """Save an interactive Plotly overlay of two triangle meshes as a standalone HTML file.

    The target mesh is drawn in blue and the source mesh in red, both at 50 % opacity.

    Parameters
    ----------
    src, tar : torch.Tensor
        Vertex coordinates of the source / target mesh; columns 0, 1, 2 are used as x, y, z.
    FS, FT : torch.Tensor
        Triangle vertex indices for ``src`` / ``tar``; columns 0, 1, 2 become Mesh3d ``i, j, k``.
    save_folder : str, optional
        Output directory, created if missing (default: the previously hard-coded path).
    filename : str, optional
        Name of the HTML file written inside ``save_folder`` (default ``"data.html"``).
        Calling ``plots`` twice with the same name overwrites the earlier figure.

    Returns
    -------
    None. The only effect is writing ``save_folder/filename``.
    """
    import os

    def _columns(tensor):
        # Detach from autograd, move to CPU, and split out the first three columns.
        arr = tensor.detach().cpu().numpy()
        return arr[:, 0], arr[:, 1], arr[:, 2]

    x, y, z = _columns(src)
    i, j, k = _columns(FS)
    xt, yt, zt = _columns(tar)
    it, jt, kt = _columns(FT)

    os.makedirs(save_folder, exist_ok=True)

    fig = go.Figure(
        data=[
            go.Mesh3d(x=xt, y=yt, z=zt, i=it, j=jt, k=kt, color="blue", opacity=0.50),
            go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color="red", opacity=0.50),
        ]
    )

    # os.path.join also works when save_folder lacks a trailing separator.
    fig.write_html(os.path.join(save_folder, filename), auto_open=False)

In [42]:
# Source shifted by +20 along every axis so the two meshes are visually separated.
plots(src=VS + 20, tar=VT, FS=FS, FT=FT)

In [11]:
import torch
from torch.autograd import grad

from pykeops.torch import Vi, Vj

In [107]:
# Gaussian kernel width: the kernels below use exp(-|x - y|^2 / sigma^2).
sigma = torch.tensor([20.0], dtype=torchdtype)

In [267]:
def GaussKernelMomentum(sigma):
    """KeOps reduction ``(x, y, b) -> [sum_j exp(-|x_i - y_j|^2 / sigma^2) * b_j]_i``.

    ``x`` (indexed by i) and ``y``, ``b`` (indexed by j) are sets of 3-vectors; the
    returned callable takes them positionally in that order.
    """
    inv_sq_width = 1 / (sigma * sigma)
    pts_i = Vi(0, 3)
    pts_j = Vj(1, 3)
    vec_j = Vj(2, 3)
    weights = (-pts_i.sqdist(pts_j) * inv_sq_width).exp()
    return (weights * vec_j).sum_reduction(axis=1)

def GaussKernel(sigma):
    """KeOps reduction ``(x, y) -> [sum_j exp(-|x_i - y_j|^2 / sigma^2)]_i`` (unit weights).

    ``x`` and ``y`` are sets of 3-vectors passed positionally in that order.
    """
    inv_sq_width = 1 / (sigma * sigma)
    sq_dists = Vi(0, 3).sqdist(Vj(1, 3))
    return (-sq_dists * inv_sq_width).exp().sum_reduction(axis=1)

In [268]:
def Hamiltonian(K):
    """Return ``H(p, q) = 0.5 * <p, K(q, q, p)>`` as a scalar tensor.

    ``K`` is a reduction called as ``K(x, y, b)``; here it is evaluated with the
    control points ``q`` as both ``x`` and ``y`` and the momenta ``p`` as weights.
    """
    def H(p, q):
        velocity = K(q, q, p)
        return 0.5 * (p * velocity).sum()

    return H


def HamiltonianSystem(K):
    """Return the right-hand side ``(p, q) -> (-dH/dq, dH/dp)`` of Hamilton's equations.

    Gradients are taken with ``create_graph=True`` so the resulting trajectory stays
    differentiable (e.g. w.r.t. the initial momenta during optimisation). ``p`` and
    ``q`` must be part of an autograd graph for ``grad`` to succeed.
    """
    energy = Hamiltonian(K)

    def _rhs(p, q):
        dH_dp, dH_dq = grad(energy(p, q), (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return _rhs

In [269]:
def RalstonIntegrator():
    """Return an explicit second-order Ralston Runge-Kutta integrator.

    The returned callable ``f(ODESystem, x0, nt, deltat=1.0)`` integrates the
    autonomous system ``dx/dt = ODESystem(*x)`` over ``nt`` equal steps spanning
    ``deltat``. ``x0`` is a tuple of tensors; they are cloned for the first state.
    Each step computes
        k1 = F(x),  k2 = F(x + 2/3 dt k1),  x <- x + dt/4 (k1 + 3 k2).
    The result is the list of ``nt + 1`` states, starting with the initial one.
    """
    def _integrate(ODESystem, x0, nt, deltat=1.0):
        state = tuple(component.clone() for component in x0)
        dt = deltat / nt
        history = [state]
        for _ in range(nt):
            k1 = ODESystem(*state)
            probe = tuple(s + (2 * dt / 3) * d for s, d in zip(state, k1))
            k2 = ODESystem(*probe)
            state = tuple(
                s + (0.25 * dt) * (d1 + 3 * d2) for s, d1, d2 in zip(state, k1, k2)
            )
            history.append(state)
        return history

    return _integrate

In [270]:
def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate Hamilton's equations for kernel ``K`` starting from ``(p0, q0)``.

    Uses ``nt`` steps over the integrator's default time span (1.0 for
    ``RalstonIntegrator``) and returns the whole trajectory, a list of ``(p, q)``
    tuples. With the default integrator it has ``nt + 1`` entries, the first being
    clones of the initial state.
    """
    hamiltonian_rhs = HamiltonianSystem(K)
    return Integrator(hamiltonian_rhs, (p0, q0), nt)


def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator(), nt=10):
    """Carry arbitrary points ``x0`` along the deformation generated by ``(p0, q0)``.

    The state ``(x, p, q)`` is integrated jointly: ``(p, q)`` follow the Hamiltonian
    system of ``K``, while ``x`` moves with velocity ``K(x, q, p)``.

    Parameters
    ----------
    x0 : torch.Tensor
        Points to transport.
    p0, q0 : torch.Tensor
        Initial momenta and control points (as passed to ``Shooting``).
    K : callable
        Kernel reduction called as ``K(x, y, b)``.
    deltat : float, optional
        Total integration time (default 1.0).
    Integrator : callable, optional
        ODE solver called as ``Integrator(system, state0, nt, deltat)``; it returns
        the list of states.
    nt : int, optional
        Number of integration steps (default 10, same as ``Shooting``).

    Returns
    -------
    The positions of ``x0`` at time ``deltat``.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        return (K(x, q, p),) + HS(p, q)

    # The integrator's third positional argument is the step count, not the time
    # span, so pass nt and deltat explicitly. Then take the x-component of the
    # *final* state; element [0] of the trajectory is just the initial state.
    trajectory = Integrator(FlowEq, (x0, p0, q0), nt, deltat)
    return trajectory[-1][0]


def LDDMMloss(K, dataloss, gamma=50, nt=10):
    """Build the LDDMM objective ``(p0, q0) -> gamma * H(p0, q0) + dataloss(q_end)``.

    ``q_end`` is the final control-point position returned by ``Shooting`` from
    ``(p0, q0)``. ``H`` is the Hamiltonian of ``K`` evaluated at the initial state.

    Parameters
    ----------
    K : callable
        Kernel reduction (see ``Hamiltonian``).
    dataloss : callable
        Maps the final control points to a scalar tensor.
    gamma : float, optional
        Weight of the Hamiltonian (deformation) term (default 50).
    nt : int, optional
        Number of shooting steps. Defaults to 10, which is ``Shooting``'s own
        default, so existing behaviour is unchanged. Use the same value when later
        re-shooting for visualisation, so the plotted endpoint is the optimised one.
    """
    def loss(p0, q0):
        _, q = Shooting(p0, q0, K, nt=nt)[-1]  # final momenta are not needed
        return gamma * Hamiltonian(K)(p0, q0) + dataloss(q)

    return loss

In [271]:
def data_loss(target, kernel):
    """Return a closure measuring the kernel discrepancy between ``source`` and ``target``.

    With unit weight on every point, the closure computes
        sum K(t, t) + sum K(s, s) - 2 * sum K(t, s),
    where each "sum" adds every entry of the tensor ``kernel`` returns. The target
    self-term does not depend on ``source``, so it is computed once, up front.
    """
    def _total(values):
        # Sum of all entries, written as a dot product with a vector of ones.
        flat = values.view(-1)
        return torch.dot(flat, torch.ones_like(values).view(-1))

    target_term = _total(kernel(target, target))

    def loss(source):
        self_term = _total(kernel(source, source))
        cross_term = _total(kernel(target, source))
        return target_term + self_term - 2 * cross_term

    return loss

In [272]:
# NOTE(review): `dl` is not used afterwards; `dataloss` (built the same way a few
# cells below) is what LDDMMloss receives. Candidate for deletion.
dl = data_loss(target=VT, kernel=GaussKernel(sigma=sigma))

In [273]:
# NOTE(review): redundant — torchdtype is already set to torch.float32 in the import cell.
torchdtype = torch.float32

In [274]:
# Initial control points = source vertices, cut from any previous graph. Gradients are
# required because HamiltonianSystem differentiates H with respect to q.
q0 = VS.detach().to(torchdtype).requires_grad_()

In [275]:
# Data term: kernel discrepancy to the target vertices; deformation: Gaussian momentum kernel.
dataloss = data_loss(target=VT, kernel=GaussKernel(sigma=sigma))
momkern = GaussKernelMomentum(sigma=sigma)
loss = LDDMMloss(K=momkern, dataloss=dataloss)

In [276]:
# Initial momenta: all zero, so the starting deformation does not move q0 (one
# momentum vector per control point). `zeros_like` inherits q0's dtype *and device*;
# torch.zeros(q0.shape, ...) always allocated on the default device, which
# mismatches q0 when the meshes live on a GPU.
p0 = torch.zeros_like(q0, requires_grad=True)

# Only the momenta are optimised; q0 stays fixed.
optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10)

In [277]:
# NOTE(review): scratch tensors left over from experimentation; nothing below uses
# `a` or `b`. Candidate for deletion.
a = torch.tensor(
    [[1, 2, 3],
     [4, 5, 6],
     [4, 5, 6]],
    dtype=torchdtype,
)
b = torch.tensor(
    [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9],
     [7, 8, 9]],
    dtype=torchdtype,
)

In [278]:
def closure():
    """LBFGS closure: re-evaluate the LDDMM objective at the current p0 and backpropagate."""
    optimizer.zero_grad()
    objective = loss(p0, q0)
    print("loss", objective.detach().cpu().numpy())
    objective.backward()
    return objective


# A single LBFGS step may call `closure` several times (up to max_eval per step).
for i in range(10):
    print("it ", i, ": ", end="")
    optimizer.step(closure)

it  0 : torch.Size([6611, 3]) torch.Size([6611, 3])
loss 1899548.0
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 1651650.2
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 335992.16
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 252576.19
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 84134.63
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 42523.375
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 39783.42
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 34878.008
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 33865.957
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 32023.672
it  1 : torch.Size([6611, 3]) torch.Size([6611, 3])
loss 32023.672
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 30208.371
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 21619.375
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 19577.943
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 11050.457
torch.Size([6611, 3]) torch.Size([6611, 3])
loss 8453.696
torch.Size([6611, 3]) torch.Size([6611, 3])

KeyboardInterrupt: 

In [175]:
# NOTE(review): the optimisation shot with LDDMMloss's default of 10 steps; using 15
# here gives a smoother animation but a slightly different endpoint than the one
# that was optimised.
nt = 15
final_state = Shooting(p0, q0, momkern, nt=nt)[-1]
p, q = final_state

In [176]:
# Overlay of the deformed source (final q) on the target.
# NOTE(review): writes the same data.html as the earlier `plots` call, replacing it.
plots(src=q, tar=VT, FS=FS, FT=FT)

In [281]:
# Animate the shooting trajectory: one Mesh3d per time point of `listpq`, toggled by a
# slider, with the target mesh (trace 0) always shown for reference.
listpq = Shooting(p0, q0, momkern, nt=nt)
save_folder = "../doc/_build/html/_images/"

VTnp, FTnp = VT.detach().cpu().numpy(), FT.detach().cpu().numpy()
q0np, FSnp = q0.detach().cpu().numpy(), FS.detach().cpu().numpy()

# Create figure; trace 0 is the target mesh (blue, semi-transparent like in `plots`).
fig = go.Figure()
fig.add_trace(
    go.Mesh3d(
        visible=True,
        x=VTnp[:, 0],
        y=VTnp[:, 1],
        z=VTnp[:, 2],
        i=FTnp[:, 0],
        j=FTnp[:, 1],
        k=FTnp[:, 2],
        color="blue",
        opacity=0.50,
    )
)

# Add one hidden trace per time point. `listpq` holds nt + 1 states (the initial
# state plus one per integration step). Iterating over all of them, rather than
# range(nt), keeps the final, optimised shape in the animation.
for t in range(len(listpq)):
    qnp = listpq[t][1].detach().cpu().numpy()
    fig.add_trace(
        go.Mesh3d(
            visible=False,
            x=qnp[:, 0],
            y=qnp[:, 1],
            z=qnp[:, 2],
            i=FSnp[:, 0],
            j=FSnp[:, 1],
            k=FSnp[:, 2],
            color="red",
            opacity=0.50,
        )
    )

# Start on the first time point: trace 1 is the initial (undeformed) source shape.
fig.data[1].visible = True

# Create and add slider: step i shows the target plus time point i.
steps = []
for i in range(len(fig.data) - 1):
    step = dict(
        method="restyle",
        args=["visible", [False] * len(fig.data)],
    )
    step["args"][1][0] = True  # target mesh stays visible on every step
    step["args"][1][i + 1] = True  # time point i
    steps.append(step)

sliders = [
    dict(active=0, currentvalue={"prefix": "time: "}, pad={"t": 20}, steps=steps)
]

fig.update_layout(sliders=sliders)

fig.write_html(save_folder + "results.html", auto_open=False)