# Libraries

In [None]:
seed = 2008
import os
os.environ['PYTHONHASHSEED']=str(seed)
import time

import torch
torch.manual_seed(seed)
from torch.autograd import grad

import plotly.graph_objs as go

from pykeops.torch import Vi, Vj

# torch type and device
use_cuda = torch.cuda.is_available()
# Always build a torch.device: with the plain string "cpu", `.index` below
# would resolve to the bound `str.index` method instead of a device ordinal.
torchdeviceId = torch.device("cuda:0") if use_cuda else torch.device("cpu")
torchdtype = torch.float32

# PyKeOps counterpart
KeOpsdeviceId = torchdeviceId.index  # id of Gpu device (None when running on CPU)
KeOpsdtype = str(torchdtype).split(".")[1]  # 'float32'

import numpy as np
np.random.seed(seed)

import matplotlib.pyplot as plt
import torch.distributions as tdist
from torch.autograd import Variable

from geomloss import SamplesLoss
from imageio import imread
import random
random.seed(seed)
from random import choices

In [None]:
# Legacy tensor type handed to draw_samples: CUDA float tensors when a GPU is
# available, CPU float tensors otherwise.
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Utils

## Data importation and visualization

Adapted from https://www.kernel-operations.io/geomloss/_auto_examples/comparisons/plot_gradient_flows_2D.html#sphx-glr-auto-examples-comparisons-plot-gradient-flows-2d-py

In [None]:
def load_image(fname):
    """Read an image file as a grayscale density map.

    The image is read in grayscale, flipped along its first axis, divided by
    255 and inverted (``1 - value``); presumably the input holds 8-bit
    intensities, so that dark pixels map to values close to 1.

    Args:
        fname: path (or any source accepted by ``imread``) of the image.

    Returns:
        The inverted, rescaled image array.
    """
    gray = imread(fname, as_gray=True)  # Grayscale
    flipped = gray[::-1, :]
    return 1 - flipped / 255.0

def draw_samples(fname, n, dtype=torch.FloatTensor):
    """Draw `n` random 2-D points distributed according to an image density.

    Pixel centres of the image returned by ``load_image`` are placed on a
    regular grid of the unit square, sampled with replacement with
    probability proportional to the pixel value, then jittered with Gaussian
    noise of standard deviation ``0.5 / A.shape[0]``.

    Args:
        fname: image file defining the density.
        n: number of points to draw.
        dtype: tensor type of the returned samples.

    Returns:
        A tensor of shape (n, 2) holding the sampled (x, y) coordinates.
    """
    A = load_image(fname)
    n_rows, n_cols = A.shape[0], A.shape[1]
    # With the default "xy" indexing, meshgrid(x, y) returns arrays of shape
    # (len(y), len(x)); x must therefore span the columns and y the rows so
    # that the raveled grid lines up with A.ravel() for non-square images too.
    xg, yg = np.meshgrid(np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))

    grid = list(zip(xg.ravel(), yg.ravel()))
    dens = A.ravel() / A.sum()
    dots = np.array(choices(grid, dens, k=n))
    dots += (0.5 / n_rows) * np.random.standard_normal(dots.shape)

    return torch.from_numpy(dots).type(dtype)

def display_samples(ax, x, color):
    """Scatter the 2-D points of tensor `x` on the axes `ax`.

    The marker area is ``25 * 500 / len(x)`` so that the total ink stays
    constant whatever the number of points. Nothing is returned.
    """
    points = x.detach().cpu().numpy()
    marker_area = 25 * 500 / len(points)
    ax.scatter(points[:, 0], points[:, 1], marker_area, color, edgecolors="none")
    
def plot_particles(ax, x, y, colors):
    """Draw the cloud `y` in a fixed blue and the cloud `x` coloured by `colors`.

    NOTE(review): a later definition named ``plot_particles`` in this file
    rebinds the name, so this four-argument (ax, x, y, colors) version is not
    the one reached by subsequent calls.

    Returns:
        Whatever ``display_samples`` returns for the `x` cloud.
    """
    plt.set_cmap("hsv")
    # shameless hack to prevent a slight change of axis...
    plt.scatter([10], [10])

    target_color = [(0.55, 0.55, 0.95)]
    display_samples(ax, y, target_color)
    source_handle = display_samples(ax, x, colors)

    # Clamp the view to the unit square and strip the ticks.
    plt.axis([0, 1, 0, 1])
    current_axes = plt.gca()
    current_axes.set_aspect("equal", adjustable="box")
    plt.xticks([], [])
    plt.yticks([], [])
    return source_handle

def plot_particles(x,y, colors,title):
    """Plot two point clouds in a new borderless figure, save it and show it.

    Args:
        x: tensor of 2-D points drawn with the per-point `colors`.
        y: tensor of 2-D points drawn underneath in a fixed blue.
        colors: per-point colour values for `x` (mapped through the "hsv" colormap).
        title: file name passed to ``plt.savefig``.
    """
    source = x.detach().cpu().numpy()
    target = y.detach().cpu().numpy()

    plt.figure()
    plt.set_cmap("hsv")
    target_color = [(0.55, 0.55, 0.95)]
    plt.scatter(target[:, 0], target[:, 1], 25 * 500 / len(target), target_color, edgecolors="none")
    plt.scatter(source[:, 0], source[:, 1], 25 * 500 / len(source), colors, edgecolors="none")

    # Restrict the view to the unit square with a 1:1 aspect ratio.
    plt.axis([0, 1, 0, 1])
    axes = plt.gca()
    axes.set_aspect("equal", adjustable="box")

    # Strip ticks, frame, margins and locators so the saved image has no border.
    plt.xticks([], [])
    plt.yticks([], [])
    axes.set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_major_locator(plt.NullLocator())
    plt.savefig(title, bbox_inches="tight", pad_inches=0)

    plt.show()

## Geodesic shooting

Adapted from https://www.kernel-operations.io/keops/_auto_tutorials/surface_registration/plot_LDDMM_Surface.html#sphx-glr-auto-tutorials-surface-registration-plot-lddmm-surface-py

### Kernel

In [None]:
def GaussKernel(sigma):
    """Build a KeOps reduction for a Gaussian kernel acting on 2-D points.

    The returned callable takes ``(x, y, b)`` and evaluates, for every row of
    ``x``, the sum over the rows of ``y`` of
    ``(0.3989 / sigma) * exp(-|x_i - y_j|^2 / (2 * sigma^2)) * b_j``
    (0.3989 is approximately 1 / sqrt(2 * pi)).

    Args:
        sigma: kernel width.
    """
    points_i = Vi(0, 2)   # first argument, indexed by i, dimension 2
    points_j = Vj(1, 2)   # second argument, indexed by j, dimension 2
    weights_j = Vj(2, 2)  # third argument, indexed by j, dimension 2
    inv_two_var = 1 / (2*sigma * sigma)
    sq_dist = points_i.sqdist(points_j)
    gauss = (-sq_dist * inv_two_var).exp()
    weighted = (0.3989/sigma)*gauss * weights_j
    return weighted.sum_reduction(axis=1)

### Dynamic

In [None]:
def RalstonIntegrator():
    """Return a two-stage explicit Runge-Kutta integrator (Ralston's scheme).

    The returned function ``integrate(ODESystem, x0, nt, deltat=1.0)``:
      * clones every component of the tuple ``x0``;
      * performs ``nt`` steps of size ``deltat / nt``, each evaluating
        ``ODESystem`` at the current state and at the state advanced by two
        thirds of a step, then combining both slopes with weights 1/4 and 3/4;
      * returns the list of the ``nt + 1`` visited states (tuples), starting
        with the cloned initial state.
    """

    def integrate(ODESystem, x0, nt, deltat=1.0):
        state = tuple(component.clone() for component in x0)
        dt = deltat / nt
        trajectory = [state]
        for _ in range(nt):
            slope = ODESystem(*state)
            # Intermediate state at two thirds of the step.
            intermediate = tuple(
                s + (2 * dt / 3) * ds for s, ds in zip(state, slope)
            )
            slope_intermediate = ODESystem(*intermediate)
            state = tuple(
                s + (0.25 * dt) * (ds + 3 * dsi)
                for s, ds, dsi in zip(state, slope, slope_intermediate)
            )
            trajectory.append(state)
        return trajectory

    return integrate

In [None]:
def Hamiltonian(K):
    """Return the function ``H(p, q) = 0.5 * sum(p * K(q, q, p))``.

    Args:
        K: callable ``K(x, y, b)`` applying a kernel between `x` and `y` to `b`.
    """

    def H(p, q):
        kernel_applied = K(q, q, p)
        return 0.5*(p * kernel_applied).sum()

    return H


def HamiltonianSystem(K):
    """Return the right-hand side of Hamilton's equations for the kernel `K`.

    The returned function ``HS(p, q)`` differentiates the Hamiltonian with
    autograd (keeping the graph so the result can itself be differentiated)
    and returns ``(-dH/dq, dH/dp)``, i.e. the time derivatives of (p, q).
    """
    energy = Hamiltonian(K)

    def HS(p, q):
        dH_dp, dH_dq = grad(energy(p, q), (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return HS

In [None]:
def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate the Hamiltonian system from the initial state (p0, q0).

    Args:
        p0: initial momentum.
        q0: initial positions.
        K: kernel callable defining the Hamiltonian.
        nt: number of integration steps.
        Integrator: function ``(ODESystem, x0, nt)`` returning the trajectory.

    Returns:
        The list of (p, q) states produced by `Integrator`.
    """
    dynamics = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(dynamics, initial_state, nt)


def Flow(x0, p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Transport the points `x0` along the flow generated by (p0, q0).

    The points `x` move with velocity ``K(x, q, p)`` while (p, q) follow the
    Hamiltonian system of `K`.

    Returns:
        The full list of (x, p, q) states produced by `Integrator` (not only
        the `x` component).
    """
    hamiltonian_rhs = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        velocity = K(x, q, p)
        return (velocity, *hamiltonian_rhs(p, q))

    initial_state = (x0, p0, q0)
    return Integrator(FlowEq, initial_state, nt)

### Optimization problem

In [None]:
def LDDMMloss(K, dataloss, gamma=0):
    """Build the geodesic-shooting objective.

    The returned ``loss(p0, q0)`` shoots (p0, q0) with the default settings of
    ``Shooting``, then adds ``gamma`` times the Hamiltonian at the initial
    state to ``dataloss`` evaluated on the final positions.
    """

    def loss(p0, q0):
        trajectory = Shooting(p0, q0, K)
        _, q_final = trajectory[-1]
        regularization = gamma * Hamiltonian(K)(p0, q0)
        return regularization + dataloss(q_final)

    return loss

def Optimize(loss,p0,q0,lr=0.5,max_it=20):
    """Minimise ``loss(p0, q0)`` over `p0` with L-BFGS.

    Args:
        loss: callable returning a scalar tensor.
        p0: tensor optimised in place (must require gradients).
        q0: second argument forwarded unchanged to `loss`.
        lr: L-BFGS learning rate.
        max_it: number of calls to ``optimizer.step``.

    Returns:
        The list of loss values (NumPy arrays) recorded at every closure call.
    """
    optimizer = torch.optim.LBFGS([p0], max_eval=10, max_iter=10, lr=lr)
    history = []
    print("performing optimization...")
    start = time.time()

    def closure():
        # L-BFGS may call this several times per step; every call is logged.
        optimizer.zero_grad()
        value = loss(p0, q0)
        value_np = value.detach().cpu().numpy()
        print("loss", value_np)
        history.append(value_np)
        value.backward()
        return value

    for step in range(max_it):
        print("it ", step, ": ", end="")
        optimizer.step(closure)

    elapsed = round(time.time() - start, 2)
    print("Optimization (L-BFGS) time: ", elapsed, " seconds")
    return history

### Plot

In [None]:
def ShowCase(listx,y,title):
    """Plot the initial cloud and every fourth step of a trajectory.

    Args:
        listx: list of point clouds, one per time step (index 0 is the start).
        y: cloud drawn underneath every plot.
        title: suffix of the saved file names; each is prefixed by its step
            index and an underscore.
    """
    initial = listx[0]
    last_step = len(listx) - 1

    # Use colors to identify the particles
    color_field = (10 * initial[:, 0]).cos() * (10 * initial[:, 1]).cos()
    colors = color_field.detach().cpu().numpy()

    plot_particles(initial, y, colors, '0_'+title)

    # Steps 4, 8, ... up to the last available one.
    for step in range(4, last_step + 1, 4):
        plot_particles(listx[step], y, colors, "{}_".format(step)+title)

## Gradient descent on the time-dependent momentum

### Optimize

In [None]:
def learn_diffeo(loss,x,y,lamb,sigma,nt,lr,max_iter,tol=1e-8,init=None):
    """Learn a time-dependent momentum `a` by gradient descent.

    Each iteration
      1. integrates the particles with explicit Euler steps
         ``z[t+1] = z[t] + K(z[t], z[t], a[t]) / nt``,
      2. evaluates ``loss(z[nt], y)`` plus ``lamb`` times the accumulated
         kernel energy ``sum_t (a[t] * K(z[t], z[t], a[t])).sum() / nt``,
      3. sweeps a costate `p` backwards in time, starting from the gradient
         of the fidelity term with respect to the final positions,
      4. updates ``a <- a - lr * (2 * lamb * a + p)``.

    Args:
        loss: callable ``loss(points, y)`` returning a scalar tensor.
        x: (n, d) tensor of initial positions.
        y: target samples forwarded to `loss`.
        lamb: weight of the kernel-energy regulariser.
        sigma: width of the Gaussian kernel.
        nt: number of time steps.
        lr: gradient-descent step size.
        max_iter: maximum number of iterations.
        tol: the loop stops once the objective is not above this value.
        init: optional (nt+1, n, d) tensor used as the initial momentum;
            zeros when omitted.

    Returns:
        Tuple ``(a, z, history)``: the momentum, the particle trajectory of
        the last iteration, and the list of objective values.
    """
    K = GaussKernel(sigma=sigma)

    n, d = x.shape

    # Identity test: `==` on a tensor is a comparison operator, not a None check.
    if init is None:
        a = torch.zeros([nt+1,n,d], dtype=torchdtype, device=torchdeviceId).float()
    else:
        a = init
    p = torch.zeros([nt+1,n,d], dtype=torchdtype, device=torchdeviceId).float()
    z = torch.zeros([nt+1,n,d], dtype=torchdtype, device=torchdeviceId).float()

    history = []

    z[0] = x

    # Loop-invariant kernel constant used by the backward sweep.
    gamma = 1 / (2*sigma * sigma)

    t0 = time.time()

    objective = loss(x,y).item()
    it = 0

    while (it<max_iter) and (objective>tol):
        # Transform the inputs (forward Euler), accumulating the kernel energy.
        s = 0
        for t in range(nt): # t+1 in (1,2,...,tau)
            velocity = K(z[t], z[t], a[t])  # evaluated once, used twice
            z[t+1] = z[t] + (1/nt)*velocity
            s = s + (1/nt)*(a[t] * velocity).sum()

        # Print the values
        z1 = z[nt]
        z1.requires_grad = True

        # The fidelity term is evaluated once and reused for the objective.
        loss_value = loss(z1,y)
        objective = (loss_value + lamb*s).item()
        history.append(objective)

        print("Iteration: "+str(it)+"  Fidelity loss: "+str(loss_value.item())+"  Objective: "+str(objective))

        # Initialize p with the gradient of the fidelity term.
        [g] = torch.autograd.grad(loss_value, [z1])
        p[nt-1] = g

        # Solve the equation on p, backwards in time. The KeOps variables are
        # rebuilt at every step so that each update reads the state at its
        # own time t (building them once outside the loop would silently
        # reuse the index left over from the forward loop).
        # NOTE(review): `a_i * p_j` is an element-wise product of vectors, and
        # p[nt] is never written; confirm both against the intended equations.
        for t in range(nt-1,0,-1): # t-1 in (tau-2,...,1,0)
            a_i = Vi(a[t])
            a_j = Vj(a[t])
            z_i = Vi(z[t])
            z_j = Vj(z[t])
            p_i = Vi(p[t])
            p_j = Vj(p[t])

            D2 = z_i.sqdist(z_j)
            Ker = (-D2 * gamma).exp()
            O = a_i * p_j + a_j * p_i - 2 * lamb * a_i * p_j

            p[t-1] = p[t] + (1/nt)*(1/sigma**2)*((0.3989/sigma) * Ker * O * (z_i-z_j)).sum_reduction(axis=1)

        # Update the gradient descent
        a = a - lr*(2*lamb*a+p)
        it += 1

    tf = time.time()
    print("Elapsed time: "+str(tf-t0))
    # max(it, 1): no division by zero when the loop never ran.
    print("Averaged elapsed time per iteration: "+str((tf-t0)/max(it, 1)))

    return a,z,history

### Apply and plot

In [None]:
def apply_diffeo(x_new,y,a,z,sigma,title):
    """Transport new points with a learned momentum and plot the motion.

    The points follow explicit Euler steps
    ``z_new[t+1] = z_new[t] + K(z_new[t], z[t], a[t]) / Nt`` where
    ``Nt = z.shape[0] - 1`` is the number of time steps of the learned
    trajectory. The initial cloud and every fourth step are plotted and saved.

    Args:
        x_new: (m, d) tensor of points to transport.
        y: cloud drawn underneath every plot.
        a: learned momentum, indexed by time step.
        z: learned particle trajectory of shape (Nt + 1, n, d).
        sigma: width of the Gaussian kernel.
        title: suffix of the saved file names.

    Returns:
        The transported points at the final time step.
    """
    # Initialization
    m = x_new.shape[0]
    n_slices, _, d = z.shape
    Nt = n_slices - 1
    z_new = torch.zeros([Nt+1,m,d], dtype=torchdtype, device=torchdeviceId)
    z_new[0] = x_new
    K = GaussKernel(sigma)

    # Use colors to identify the particles
    colors = (10 * x_new[:, 0]).cos() * (10 * x_new[:, 1]).cos()
    colors = colors.detach().cpu().numpy()
    plot_particles(z_new[0], y, colors,"0_"+title)

    # Iterations. The step is 1/Nt: z holds Nt+1 slices, i.e. Nt Euler steps,
    # which is the step size the trajectory was integrated with in
    # learn_diffeo (1/(Nt-1) would overshoot, and divide by zero for Nt == 1).
    for t in range(Nt): # t+1 in (1,2,...,tau+1)
        z_new[t+1] = z_new[t] + (1/Nt)*K(z_new[t],z[t],a[t])

        if (t+1)%4==0:
            plot_particles(z_new[t+1], y, colors,"{}_".format(t+1)+title)

    return z_new[Nt]

# Demo

## Setup

### Training set

In [None]:
# Draw n samples from each of the two image-defined densities.
n = 1000
x = draw_samples("density_a.png", n, dtype)
y = draw_samples("density_b.png", n, dtype)
# q0: detached copy of the source samples that tracks gradients;
# y: detached copy of the target samples on the working dtype/device.
q0 = x.clone().detach().to(dtype=torchdtype, device=torchdeviceId).requires_grad_(True)
y = y.clone().detach().to(dtype=torchdtype, device=torchdeviceId)

### Testing set

In [None]:
# Draw m fresh samples from the same two images.
m = 2000
x0 = draw_samples("density_a.png", m, dtype)
x1 = draw_samples("density_b.png", m, dtype)
# Gradients are tracked on the source test samples.
x0.requires_grad = True

### Hilbert space

In [None]:
# Width of the Gaussian kernel, stored as a one-element tensor.
sigma = torch.tensor([0.175], dtype=torchdtype, device=torchdeviceId)
# Kernel reduction built from that width.
Kv = GaussKernel(sigma=sigma)   

### Regularization and discretization

In [None]:
# Number of time steps passed to Flow and learn_diffeo.
nt = 16
# Regularization weight.
lamb = 1e-8 

## Run

### Loss

In [None]:
# Base name of the saved figures; each method appends its own suffix to it.
title = 'SD_1e-3'

# Data-fidelity term from GeomLoss: Sinkhorn loss with p=2 and blur=1e-2.
S = SamplesLoss("sinkhorn", p=2, blur=1e-2)
# Alternative losses kept for experimentation:
# S = SamplesLoss("sinkhorn", p=2, blur=1e-2, debias=False)
# S = SamplesLoss("gaussian", blur=1e-1)

### Geodesic shooting

In [None]:
# Settings of the geodesic-shooting experiment.
title1 = title + "_gs"  # suffix for the figures of this method
lr_gs = 0.7  # L-BFGS learning rate
max_it_gs = 20  # number of optimizer steps
# Fidelity between the shot positions and the training target y.
dataloss = lambda q : S(q,y)
# The Hamiltonian is weighted by 2*lamb in the shooting objective.
loss = LDDMMloss(Kv, dataloss,gamma=2*lamb)

In [None]:
# Start from a zero initial momentum and optimise it in place.
p0 = torch.zeros(q0.shape, dtype=torchdtype, device=torchdeviceId, requires_grad=True)
history = Optimize(loss,p0,q0,lr_gs,max_it_gs)

In [None]:
# Flow the test samples x0 along the trajectory generated by (p0, q0);
# every list entry is an (x, p, q) state.
listxqp = Flow(x0, p0, q0, Kv, nt=nt)
# Keep only the transported test points at each of the nt+1 time steps.
listx = [listxqp[t][0] for t in range(nt+1)]

# Plot the transported points against the target test samples.
ShowCase(listx,x1,title1)

### Gradient descent

In [None]:
# Settings of the gradient-descent experiment.
title2 = title + "_gd"  # suffix for the figures of this method
lr_gd=0.6  # gradient-descent step size
max_it_gd=200  # maximum number of iterations

In [None]:
# Learn the time-dependent momentum on the training samples.
# NOTE: this rebinds `history`, discarding the geodesic-shooting history.
a, z, history = learn_diffeo(S,x,y,lamb,sigma,nt,lr_gd,max_it_gd)

In [None]:
# Transport the test samples with the learned momentum; the intermediate
# figures are saved with the title2 suffix.
z_new = apply_diffeo(x0,x1,a,z,sigma,title2)