# Gradient flows of Potential Energies in the Geometry of Sinkhorn divergences: Numerics

Because some of the computations are slow to run, we saved their results in `.pt` files so that they can be displayed directly. The code that produced these results is left in the cells as comments; uncomment it to recompute them.

In [None]:
# Imports
import spf # Sinkhorn Potential Flows
import torch
import numpy as np
from plotly import graph_objects as go
from plotly import colors
from scipy.stats import norm
# Sequential blue colorscale, reused below for the sphere surface and the potential heatmaps.
blues = colors.sequential.Blues
# Magic to make plotly render TeX in notebook
from plotly import offline
# Put plotly in offline notebook mode so figures render inline.
offline.init_notebook_mode()
from IPython.display import display, HTML
# Load MathJax 2.7.1 from a CDN so TeX strings used in figure titles/labels are typeset.
display(HTML(
    '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
))

# 3 point space

This section illustrates the RKHS sphere in 3D when the underlying space has 3 points (Figure 2 of the article).

In [None]:
# Ground space: three points on the real line, stored as a (3, 1) column.
X = torch.tensor([-.5, 0, 1]).unsqueeze(1)
# Initial weights (they sum to 1) and the potential's value at each point.
µ0 = torch.tensor([.4, .4, .2])
V = torch.tensor([.5, .3, 0.])
# Time discretisation: N steps of size tau.
tau, N = 8e-3, 500
t = torch.arange(N) * tau
# Eulerian solver; the third argument .5 is the entropic parameter eps
# (it is read back as espf.eps below and recorded as eps5e-01 in the cache name).
espf = spf.EulerianSPF(X, V, .5)
espf.set_optimizer(spf.optimizers.APGD(.15))

In [None]:
# Slow step, kept for reproducibility: integrate the flow (also returning ft) and cache ft.
# µt, ft = espf.integrate(µ0, t, print_progress=True, max_sinkhorn_steps=30, max_optim_steps=45, sinkhorn_tol=1e-5, optim_tol=1e-10, return_ft=True)
# torch.save(ft, "./saved_flows/3points_ft_eps5e-01_tau8e-03_N500.pt")
# Load the cached ft; the file name records eps=5e-01, tau=8e-03 and N=500 as set in the setup cell.
ft = torch.load("./saved_flows/3points_ft_eps5e-01_tau8e-03_N500.pt")

In [None]:
# Animation
# Build the embedded-sphere picture from every 5th saved time step of ft.
# `sphere` is a list of static traces (it is concatenated with lists below);
# `traj` is a sequence of trajectory traces, one per animation frame.
sphere, traj = spf.visualize.b_flow_sphere(espf.cost_matrix, espf.eps, ft[::5,:], potential_array=V,
                                    B_kwargs=dict(name='Boundary of \u212c', line=dict(width=10, color='orange', dash='dash')),
                                    rotation_lines_kwargs=dict(name='Theoretical rotation', showlegend=True),
                                    sphere_kwargs=dict(colorscale=blues, showlegend=True, name='Potential Energy'),
                                    flow_kwargs=dict(color='red', line=dict(width=15), name='Embedded flow trajectory'))
# Marker at the points of the first trajectory trace (presumably the single starting point b_0).
scatter_b0 = go.Scatter3d(x=traj[0].x, y=traj[0].y, z=traj[0].z, mode="markers", marker=dict(symbol='x', color='darkred', size=4), name='b<sub>0</sub>')
# Hidden axes on the cube [-1, 1]^3, a legend on the right, and Play/Pause buttons driving the frames.
fig = go.Figure(data=sphere+[traj[0], scatter_b0], layout=go.Layout(
                         height=700,
                         width=1000,
                         scene=dict(
                         xaxis=dict(range=[-1, 1], autorange=False, visible=False),
                         yaxis=dict(range=[-1, 1], autorange=False, visible=False),
                         zaxis=dict(range=[-1, 1], autorange=False, visible=False),
                         aspectmode='cube'),
                         title=dict(text='Embedded Sinkhorn Potential flow animation', font=dict(size=30)), title_x=.5,
                         legend=dict(yanchor="top",
                                     y=0.9,
                                     xanchor="left",
                                     x=0.99,
                                     font=dict(size=20),
                                     itemsizing='constant'),
                         margin=dict(l=0, r=0, b=0),
                         updatemenus=[dict(type="buttons", showactive=False,
                                           buttons=[dict(label='Play',
                                                         method='animate',
                                                         args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                          fromcurrent=True)]),  
                                                    dict(label='Pause',
                                                         method='animate',
                                                         args=[[None], dict(frame=dict(duration=0, redraw=True),
                                                                            mode='immediate',
                                                                            transition=dict(duration=0))])])]),
                         )
# One frame per trajectory trace; the sphere traces and the b_0 marker are repeated in every frame.
fig.frames = [go.Frame(data=sphere+[line, scatter_b0]) for line in traj]
fig.show()
# fig.write_html('./animations/sphere.html')

In [None]:
# Static figure
# Last point of the last trajectory trace, labelled b_4 (the flow runs up to t = (N-1)*tau ≈ 4).
scatter_b4 = go.Scatter3d(x=traj[-1].x[-1:], y=traj[-1].y[-1:], z=traj[-1].z[-1:], mode="markers", marker=dict(symbol='circle', color='red', size=8), name='b<sub>4</sub>')
# Same scene as the animation, showing only the last trajectory trace (presumably the complete path)
# together with its start (b_0) and end (b_4) markers.
go.Figure(data=sphere+[traj[-1], scatter_b0, scatter_b4], layout=go.Layout(
                         height=700,
                         width=900,
                         scene=dict(
                         xaxis=dict(range=[-1, 1], autorange=False, visible=False),
                         yaxis=dict(range=[-1, 1], autorange=False, visible=False),
                         zaxis=dict(range=[-1, 1], autorange=False, visible=False),
                         aspectmode='cube'),
                         title=dict(text='Embedded Sinkhorn Potential flow', font=dict(size=30)), title_x=.5,
                         legend=dict(yanchor="top",
                                     y=0.9,
                                     xanchor="left",
                                     x=0.99,
                                     font=dict(size=20),
                                     itemsizing='constant'),
                         margin=dict(l=0, r=0, b=0),)).show()

# Convex potential

This section explores convex potentials (Figures 4 and 5 in the article).

In [None]:
# Setup
n = 100
X = torch.linspace(0, 1, n).unsqueeze(1)  # (n, 1) grid on [0, 1]
k = int(.8*n)


class Gaussian1D:
    """Gaussian bump evaluated at given points and normalised to total mass 1."""

    def __init__(self, mean, sigma):
        self.mean, self.sigma = mean, sigma

    def __call__(self, t) -> torch.Tensor:
        z = (t - self.mean) / self.sigma
        weights = torch.exp(-.5 * z**2)
        return weights / weights.sum()


# Initial Gaussian centred on the grid point X[k].
gauss = Gaussian1D(X[k,0], .07)


def V(x):
    """Convex potential: square of the first coordinate of each row of x."""
    return x[:, 0]**2


tau, N = 5e-2, 21
t = torch.arange(N) * tau
eps = .2**2

## Eulerian

In [None]:
# Initial measure: the discretised Gaussian evaluated on the grid points
# (X is an (n, 1) column, so its single column holds the grid values).
µ0 = gauss(X[:, 0])
# Eulerian solver with the setup's potential V and entropic parameter eps.
espf = spf.EulerianSPF(X, V, eps)
espf.set_optimizer(spf.optimizers.APGD(1e-2))

In [None]:
# Slow step (commented out): integrate the Eulerian flow and cache the measures µt.
# µt = espf.integrate(µ0, t, print_progress=True, max_sinkhorn_steps=1000, max_optim_steps=10000, sinkhorn_tol=1e-6, optim_tol=1e-6)
# torch.save(µt, f"saved_flows/eul_cvx_eps{espf.eps:.0e}_tau{t[1]-t[0]:.0e}_N{N}.pt")
# Reload the cached flow. The path is built from eps, the time step t[1]-t[0] and N,
# so it matches the file written by the commented-out torch.save above.
µt = torch.load(f"saved_flows/eul_cvx_eps{espf.eps:.0e}_tau{t[1]-t[0]:.0e}_N{N}.pt")

In [None]:
# Animation
# Bar traces of the mass on the grid; one element per animation frame (presumably one per time step of µt).
data = spf.visualize.mass_flow(X, µt, marker=dict(color='red',line=dict(width=0)), name='Sinkhorn flow', width=[1/n]*n)
# The potential V sampled on 100 points, drawn on a secondary y-axis (y2) overlaying the mass axis.
potential_plot = go.Scatter(x=torch.linspace(0, 1, 100), y=V(torch.linspace(0, 1, 100)[:,None]), yaxis='y2', name='V', line=dict(color='blue'))
# Mass on the left axis (red), potential on the right axis (blue), Play/Pause buttons driving the frames.
fig = go.Figure(data=(data[0], potential_plot),
                    layout=go.Layout(
                        height=800,
                        width=900,
                        title=dict(text="Sinkhorn potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[0, 1], autorange=False, tickfont=dict(size=20)),
                        yaxis=dict(range=[0, 1], autorange=False, title=dict(text='Mass', font=dict(color="red", size=20)), tickfont=dict(size=20)),
                        yaxis2=dict(side='right', overlaying='y', title=dict(text='Potential', font=dict(color="blue", size=20)), tickfont=dict(size=20)),
                        legend=dict(font=dict(size=20)),
                        updatemenus=[dict(type="buttons",
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                           fromcurrent=True)]),
                                                   dict(label='Pause',
                                                        method='animate',
                                                        args=[[None], dict(frame=dict(duration=0, redraw=False),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
# Each frame pairs the mass bars at one time with the (static) potential curve.
fig.frames = [go.Frame(data=(d, potential_plot)) for d in data]
fig.show()
# NOTE(review): unlike most export lines in this notebook this one is active; it writes into
# animations/, which presumably has to exist already — confirm before running.
fig.write_html(f"animations/EulerianSinkhorn1Dcvx.html")

In [None]:
# Static images, modify index k to show the different frames
k = 20
# Unlabelled x-ticks every sqrt(eps). The (empty) label list is sized from the tick
# positions, so both lists always have the same length. A separate int(1/sqrt(eps))
# count can differ from len(np.arange(...)) by one under floating-point round-off
# (e.g. when eps is stored in float32).
_tickvals = np.arange(0, 1, np.sqrt(espf.eps))
fig = go.Figure(data=(data[k], potential_plot),
                    layout=go.Layout(
                        height=700,
                        width=700,
                        margin=dict(l=3, r=1, b=2, t=0),
                        plot_bgcolor='white',
                        xaxis=dict(range=[0, 1], autorange=False, tickvals=_tickvals, ticktext=[""]*len(_tickvals), showline=True, linewidth=2, linecolor='black', ticks="inside", ticklen=10, tickwidth=3),
                        yaxis=dict(range=[0, 1], autorange=False,showticklabels=False),
                        yaxis2=dict(side='right', overlaying='y', showticklabels=False),
                        legend=dict(visible=False)
                    ),
            )
# Framed plot area: mirror the y-axis lines on both sides.
fig.update_yaxes(showline=True, linewidth=3, linecolor='black', mirror=True)
fig.show()

In [None]:
# Wasserstein animation
# Closed-form Wasserstein gradient flow of V(x) = x**2: particles follow
# x' = -V'(x) = -2x, i.e. x(s) = exp(-2s) * x(0). This linear contraction pushes the
# initial Gaussian to a Gaussian whose mean AND standard deviation both shrink by
# exp(-2s). The standard deviation must therefore use the same exp(-2s) factor as the
# mean, as in the 2D closed form exp(-2t) * X0.
µt_wasserstein = torch.cat([Gaussian1D(torch.exp(-2*s)*gauss.mean, torch.exp(-2*s)*gauss.sigma)(X.flatten())[None,:] for s in t], dim=0)
data_wasserstein = spf.visualize.mass_flow(X, µt_wasserstein, marker=dict(color="#000000",line=dict(width=0)), name='Wasserstein flow', width=[1/n]*n)
# Animated bar chart of the closed-form Wasserstein flow (black), with V on a secondary axis
# and Play/Pause buttons driving the frames.
fig = go.Figure(data=(data_wasserstein[0], potential_plot),
                    layout=go.Layout(
                        height=800,
                        width=900,
                        title=dict(text="Wasserstein potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[0, 1], autorange=False, tickfont=dict(size=20)),
                        yaxis=dict(range=[0, 1], autorange=False, title=dict(text='Mass', font=dict(color="black", size=20)), tickfont=dict(size=20)),
                        yaxis2=dict(side='right', overlaying='y', title=dict(text='Potential', font=dict(color="blue", size=20)), tickfont=dict(size=20)),
                        legend=dict(font=dict(size=20)),
                        updatemenus=[dict(type="buttons",
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                           fromcurrent=True)]),
                                                   dict(label='Pause',
                                                        method='animate',
                                                        args=[[None], dict(frame=dict(duration=0, redraw=False),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
fig.frames = [go.Frame(data=(d, potential_plot)) for d in data_wasserstein]
fig.show()
# NOTE(review): active export into animations/ (presumably must already exist) — confirm before running.
fig.write_html(f"animations/Wasserstein1Dcvx.html")

In [None]:
# Static images, modify index k to show the different frames
k = 20
# Framed, label-free snapshot of the Wasserstein flow at frame k, with V on the secondary axis.
fig = go.Figure(data=(data_wasserstein[k], potential_plot),
                    layout=go.Layout(
                        height=700,
                        width=700,
                        margin=dict(l=2, r=1, b=2, t=0),
                        plot_bgcolor='white',
                        xaxis=dict(range=[0, 1], autorange=False, showticklabels=False, showline=True, linewidth=2, linecolor='black'),
                        yaxis=dict(range=[0, 1], autorange=False,showticklabels=False),
                        yaxis2=dict(side='right', overlaying='y', showticklabels=False),
                        legend=dict(visible=False)
                    ),
            )
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=True)
fig.show()

## Lagrangian

### 1D

In [None]:
# Setup
# 60 particles at evenly spaced quantiles (0.3% .. 99.7%) of the initial Gaussian,
# giving a particle discretisation of the same initial measure; shape (60, 1).
x = torch.tensor(
    norm.ppf(np.linspace(.003, .997, 60), loc=gauss.mean, scale=gauss.sigma)
).unsqueeze(1)


def _lr_schedule(k):
    """Slowly increasing schedule 1e-2 * (1 + log(log(k + 1)))."""
    # NOTE(review): log(log(1)) = -inf, so this assumes the optimizer's iteration
    # counter k starts at 1 — TODO confirm in spf.optimizers.
    return 1e-2*(1+np.log(np.log(k+1)))


lspf = spf.LagrangianSPF(V, eps)
lspf.set_optimizer(spf.optimizers.NesterovGD(lr=_lr_schedule))

In [None]:
# Slow step (commented out): integrate the particle flow and cache the trajectories xt.
# xt = lspf.integrate(x, t, max_sinkhorn_steps=500, max_optim_steps=1500, sinkhorn_tol=1e-6, print_progress=True)
# torch.save(xt, f"./saved_flows/lag_cvx_eps{lspf.eps:.0e}_tau{t[1]-t[0]:.0e}_N{N}.pt")
# Reload the cached trajectories; the path matches the commented-out torch.save above.
xt = torch.load(f"./saved_flows/lag_cvx_eps{lspf.eps:.0e}_tau{t[1]-t[0]:.0e}_N{N}.pt")

In [None]:
# Animation
# Bar traces built from the particle positions xt[:, :, 0] against the grid X
# (presumably a histogram of particles per grid cell); one element per frame.
data = spf.visualize.particles_to_bars(xt[:, :,0], X, width=1/n, marker=dict(color="#FF7A00"), name='Particles')
# The potential V on 100 points, on a secondary y-axis.
potential_plot = go.Scatter(x=torch.linspace(0, 1, 100), y=V(torch.linspace(0, 1, 100)[:,None]), name='V', line=dict(color='blue'), yaxis='y2')
fig = go.Figure(data=(data[0], potential_plot),
                    layout=go.Layout(
                        height=800,
                        width=900,
                        title=dict(text="Sinkhorn potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[0, 1], autorange=False, tickfont=dict(size=20)),
                        yaxis=dict(range=[0, 1], title=dict(text='Mass', font=dict(color="#FF7A00", size=20)), tickfont=dict(size=20)),
                        yaxis2=dict(side='right', overlaying='y', title=dict(text='Potential', font=dict(color="blue", size=20)), tickfont=dict(size=20)),
                        legend=dict(font=dict(size=20)),
                        updatemenus=[dict(type="buttons",
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                           fromcurrent=True)]),
                                                   dict(label='Pause',
                                                        method='animate',
                                                        args=[[None], dict(frame=dict(duration=0, redraw=False),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
fig.frames = [go.Frame(data=(d, potential_plot)) for d in data]
fig.show()
# fig.write_html(f"animations/LagrangianSinkhorn1Dcvx.html")

In [None]:
# Static images, modify index k to show the different frames
k = 20
# Unlabelled x-ticks every sqrt(eps). The (empty) label list is sized from the tick
# positions, so both lists always have the same length. A separate int(1/sqrt(eps))
# count can differ from len(np.arange(...)) by one under floating-point round-off.
_tickvals = np.arange(0, 1, np.sqrt(lspf.eps))
fig = go.Figure(data=(data[k], potential_plot),
                    layout=go.Layout(
                        height=700,
                        width=700,
                        margin=dict(l=2, r=1, b=2, t=0),
                        plot_bgcolor='white',
                        xaxis=dict(range=[0, 1], autorange=False, tickvals=_tickvals, ticktext=[""]*len(_tickvals), showline=True, linewidth=2, linecolor='black', ticks="inside", ticklen=10, tickwidth=3),
                        yaxis=dict(range=[0, 1], autorange=False,showticklabels=False),
                        yaxis2=dict(side='right', overlaying='y', showticklabels=False),
                        legend=dict(visible=False)
                    ),
            )
# Framed plot area: mirror the y-axis lines on both sides.
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=True)
fig.show()

### 2D

In [None]:
# Setup
# Convex 2D potential taken from spf.utils (presumably the squared Euclidean norm).
V = spf.utils.sqnorm
# Initial particles: grid points from spf.utils.grid flattened to (num_points, 2).
X0 = spf.utils.grid[.6:1:.3, .6:1:.3].reshape(-1, 2)
# 60 time steps of size 1e-2.
t = torch.arange(60) * 1e-2


def _lr_schedule_2d(k):
    """Decreasing schedule 1 / log(k + 1)."""
    # NOTE(review): infinite at k = 0; presumably the optimizer counts from k = 1 — TODO confirm.
    return 1/np.log(k+1)


lspf = spf.LagrangianSPF(V, .2**2)
lspf.set_optimizer(spf.optimizers.NesterovGD(lr=_lr_schedule_2d))

In [None]:
# Slow step (commented out): integrate the 2D particle flow and cache the trajectories.
# xt = lspf.integrate(X0, t, max_sinkhorn_steps=500, max_optim_steps=500, sinkhorn_tol=5e-6, print_progress=True)
# torch.save(xt, f"./saved_flows/lag_cvx2D_eps4e-02_tau1e-02_N60.pt")
# Reload the cached trajectories (eps=4e-02, tau=1e-02, 60 steps, as in the setup cell).
xt = torch.load(f"./saved_flows/lag_cvx2D_eps4e-02_tau1e-02_N60.pt")

In [None]:
# Animation
# Scatter traces of the particles; one element per frame.
data = spf.visualize.particle_flow(xt, marker=dict(color='#FF7A00', size=25), name='Particles')
# Common range [m, M] for both (hidden) axes.
m, M = -.1, 1
fig = go.Figure(data=[data[0]],
                    layout=go.Layout(
                        height=800,
                        width=800,
                        title=dict(text="Sinkhorn potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[m, M], autorange=False, tickfont=dict(size=20), visible=False),
                        yaxis=dict(range=[m, M], autorange=False, tickfont=dict(size=20), visible=False),
                        updatemenus=[dict(type="buttons", showactive=False,
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                         fromcurrent=True)]),  
                                                    dict(label='Pause',
                                                         method='animate',
                                                         args=[[None], dict(frame=dict(duration=0, redraw=True),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
# Overlay a square grid with spacing sqrt(eps): vertical lines first, then horizontal lines.
rooteps = np.sqrt(lspf.eps)
for xv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=xv, y0=m,
        x1=xv, y1=M,
        line=dict(color="black", width=1)
    )

for yv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=m, y0=yv,
        x1=M,  y1=yv,
        line=dict(color="black", width=1)
    )
fig.frames = [go.Frame(data=[d]) for d in data]

fig.show()
# fig.write_html(f"animations/LagrangianSinkhorn2Dcvx.html")

In [None]:
# Static images, modify index k to show the different frames
k = 45
# Particles at time index k over a heatmap of V on [m, M]^2 (m, M and rooteps come from the animation cell).
fig = go.Figure(data=[go.Scatter(x=xt[k,:, 0], y=xt[k,:, 1], mode='markers',
                                          marker=dict(color='#FF7A00', size=25), name='Particles', showlegend=False),
                              spf.visualize.potential_heatmap(V, [m,M, m, M], grid_size=200, name='V', showlegend=False, showscale=False, colorscale=blues)],
                        layout=dict(height=700, width=700, xaxis=dict(range=[m, M], visible=False), yaxis=dict(range=[m, M], visible=False),
                                    margin=dict(l=0, r=0, b=0,t=0)))
# Same sqrt(eps)-spaced grid overlay as in the animation: vertical lines, then horizontal lines.
for xv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=xv, y0=m,
        x1=xv, y1=M,
        line=dict(color="black", width=1)
    )

for yv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=m, y0=yv,
        x1=M,  y1=yv,
        line=dict(color="black", width=1)
    )
fig.show()

In [None]:
# Wasserstein animation
# Closed-form Wasserstein flow for the 2D potential (assuming V = spf.utils.sqnorm is |x|^2):
# each particle follows x' = -2x, so x(t) = exp(-2t) * x(0). Xt_ has shape (len(t), num_particles, 2).
Xt_ = torch.einsum('t,nd->tnd',torch.exp(-2*t),X0)
data_ = spf.visualize.particle_flow(Xt_, marker=dict(color='black', size=20))
m, M = -.1, 1
fig = go.Figure(data=[data_[0]],
                    layout=go.Layout(
                        height=800,
                        width=800,
                        title=dict(text="Wasserstein potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[m, M], autorange=False, tickfont=dict(size=20), visible=False),
                        yaxis=dict(range=[m, M], autorange=False, tickfont=dict(size=20), visible=False),
                        updatemenus=[dict(type="buttons", showactive=False,
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                         fromcurrent=True)]),  
                                                    dict(label='Pause',
                                                         method='animate',
                                                         args=[[None], dict(frame=dict(duration=0, redraw=True),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
fig.frames = [go.Frame(data=[d]) for d in data_]
fig.show()
# NOTE(review): active export into animations/ (presumably must already exist) — confirm before running.
fig.write_html(f"animations/Wasserstein2Dcvx.html")

In [None]:
# Static images, modify index k to show the different frames
k = 45
# Closed-form particles at time index k over a heatmap of V (extent presumably [-1, 1]^2),
# with the view cropped to [-.2, 1]^2.
fig = go.Figure(data=[go.Scatter(x=Xt_[k,:, 0], y=Xt_[k,:, 1], mode='markers',
                                    marker=dict(color="black", size=20), name='Particles', showlegend=False),
                        spf.visualize.potential_heatmap(V, [-1,1, -1, 1], grid_size=200, name='V', showlegend=False, showscale=False, colorscale=blues)],
                  layout=dict(height=700, width=700, xaxis=dict(range=[-.2, 1], visible=False), yaxis=dict(range=[-.2, 1], visible=False),
                              margin=dict(l=0, r=0, b=0,t=0)))
fig.show()

# Non-convex potential

This section explores non-convex potentials (Figure 6 in the article).

In [None]:
# Setup
P, L = .4, .2


def noncvx(P, L):
    """Build a C^1, non-convex 1D potential: two parabolas joined by a sine plateau.

    The potential is
        t**2                                   for t < P,
        P**2 + (P*L/pi) * sin(2*pi*(t-P)/L)    for P <= t <= P+L,
        (t - L)**2                             for t > P+L,
    i.e. the parabola t**2 is cut at P, an oscillating plateau of width L is
    inserted, and the parabola resumes shifted right by L. Values (P**2) and
    first derivatives (2P) match at both junctions.

    Returns a callable that flattens its tensor argument (e.g. an (n, 1) grid)
    and evaluates the potential element-wise, returning a 1D tensor. Integer
    inputs are promoted to the default floating dtype.
    """
    def f(t):
        # The pieces take non-integer values; writing them into an integer
        # buffer would fail (dtype mismatch on masked assignment), so promote.
        if not torch.is_floating_point(t):
            t = t.to(torch.get_default_dtype())
        res = torch.empty_like(t)
        i1 = (t<P)
        i3 = (P+L<t)
        i2 = (i1 | i3).logical_not()
        res[i1] = t[i1]**2
        res[i2] = P**2 + (P*L/np.pi)*torch.sin(2*np.pi*(t[i2]-P)/L)
        res[i3] = t[i3]**2 -2*L*t[i3] + L*L
        return res
    return lambda t: f(t.flatten())


V = noncvx(P, L)
tau = 5e-2
N = 148
n = 100
X = torch.linspace(0, 1, n)[:,None]

## Eulerian

In [None]:
# Initialization
# Time grid: N steps of size tau.
t = torch.arange(N) * tau
# Initial measure on the grid X built by spf.utils.unif (presumably uniform on [.8, .9]).
µ0 = spf.utils.unif(.8, .9, X)
# Eulerian solver with eps = .2**2 and a constant optimizer schedule (1e-2 for every k).
sf = spf.EulerianSPF(X, V, .2**2)
sf.set_optimizer(spf.optimizers.APGD(lambda k: 1e-2))

In [None]:
# Slow step (commented out): integrate the Eulerian flow for the non-convex potential.
# µt = sf.integrate(µ0, t, print_progress=True, max_sinkhorn_steps=500, max_optim_steps=5000, sinkhorn_tol=1e-6, optim_tol=1e-6)
# torch.save(µt, "saved_flows/eul_ncvx_eps4e-02_tau5e-02_N148.pt")
# Reload the cached flow (eps=4e-02, tau=5e-02, N=148, as in the cells above).
µt = torch.load("saved_flows/eul_ncvx_eps4e-02_tau5e-02_N148.pt")

In [None]:
# Bar traces of the mass on the grid, one element per frame.
data = spf.visualize.mass_flow(X, µt, marker=dict(color='red', line=dict(width=0)), width=[1/n]*n, name='Sinkhorn flow')
# The non-convex potential sampled on 200 points (finer than the grid, to show the sine plateau), on axis y2.
potential_plot = go.Scatter(x=torch.linspace(0, 1, 200), y=V(torch.linspace(0, 1, 200)[:,None]), yaxis='y2', name='V', line=dict(color="blue"))
fig = go.Figure(data=(data[0], potential_plot),
                    layout=go.Layout(
                        height=800,
                        width=900,
                        title=dict(text="Sinkhorn potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[0, 1], autorange=False, tickfont=dict(size=20)),
                        yaxis=dict(range=[0, 1], autorange=False, title=dict(text='Mass', font=dict(color="red", size=20)), tickfont=dict(size=20)),
                        yaxis2=dict(side='right', overlaying='y', title=dict(text='Potential', font=dict(color="blue", size=20)), tickfont=dict(size=20)),
                        legend=dict(font=dict(size=20)),
                        updatemenus=[dict(type="buttons",
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=1, redraw=True),
                                                                           fromcurrent=True)]),
                                                   dict(label='Pause',
                                                        method='animate',
                                                        args=[[None], dict(frame=dict(duration=0, redraw=False),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
fig.frames = [go.Frame(data=(d, potential_plot)) for d in data]
fig.show()
# fig.write_html(f"animations/EulerianSinkhorn1Dnoncvx.html")

In [None]:
# Frame index of the static snapshot; change k to show other frames.
k = 10
# Framed snapshot with unlabelled inside ticks every .2 (np.arange(0, 1, .2) has 5 entries, matching int(1/.2)).
fig = go.Figure(data=(data[k], potential_plot),
                    layout=go.Layout(
                        height=700,
                        width=700,
                        margin=dict(l=2, r=1, b=2, t=0),
                        plot_bgcolor='white',
                        xaxis=dict(range=[0, 1], autorange=False, tickvals=np.arange(0, 1, .2), ticktext=[""]*int(1/.2), showline=True, linewidth=2, linecolor='black', ticks="inside", ticklen=10, tickwidth=3),
                        yaxis=dict(range=[0, 1], autorange=False,showticklabels=False),
                        yaxis2=dict(side='right', overlaying='y', showticklabels=False),
                        legend=dict(visible=False)
                    ),
            )
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=True)
fig.show()

## Lagrangian

### 1D

In [None]:
# Setup
# Same non-convex potential as in the Eulerian experiment.
V = noncvx(P, L)
# Initial particles from .8 in steps of .01 up to (excluding) .9, as a (num_particles, 1) column.
X0 = torch.arange(.8, .9, .01).unsqueeze(1)
# 150 time steps of size 5e-2.
t = torch.arange(150) * 5e-2
lspf = spf.LagrangianSPF(V, .2**2)
lspf.set_optimizer(spf.optimizers.NesterovGD(lr=1e-2))

In [None]:
# Slow step (commented out): integrate the particle flow for the non-convex potential.
# xt = lspf.integrate(X0, t, max_sinkhorn_steps=100, max_optim_steps=100, sinkhorn_tol=1e-5, print_progress=True)
# torch.save(xt, "./saved_flows/lag_ncvx_eps4e-02_tau5e-02_N150.pt")
# Reload the cached trajectories (eps=4e-02, tau=5e-02, 150 steps).
xt = torch.load("./saved_flows/lag_ncvx_eps4e-02_tau5e-02_N150.pt")

In [None]:
# Animation
# Bar traces built from particle positions xt[:, :, 0] against the grid X (one element per frame).
data = spf.visualize.particles_to_bars(xt[:, :,0], X, width=1/n, marker=dict(color='#FF7A00'), name='Particles')
potential_plot = go.Scatter(x=torch.linspace(0, 1, 100), y=V(torch.linspace(0, 1, 100)[:,None]), name='V', line=dict(color='blue'), yaxis='y2')
fig = go.Figure(data=(data[0], potential_plot),
                    layout=go.Layout(
                        height=800,
                        width=900,
                        title=dict(text="Sinkhorn potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[0, 1], autorange=False, tickfont=dict(size=20)),
                        yaxis=dict(range=[0, 1], title=dict(text='Mass', font=dict(color="#FF7A00", size=20)), tickfont=dict(size=20)),
                        yaxis2=dict(side='right', overlaying='y', title=dict(text='Potential', font=dict(color="blue", size=20)), tickfont=dict(size=20)),
                        legend=dict(font=dict(size=20)),
                        updatemenus=[dict(type="buttons",
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                           fromcurrent=True)]),
                                                   dict(label='Pause',
                                                        method='animate',
                                                        args=[[None], dict(frame=dict(duration=0, redraw=False),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
fig.frames = [go.Frame(data=(d, potential_plot)) for d in data]
fig.show()
# fig.write_html(f"animations/LagrangianSinkhorn1Dnoncvx.html")

In [None]:
# Static images, modify index k to show the different frames
k = 80
# Framed snapshot with unlabelled inside ticks every .2 (5 tick positions and 5 empty labels).
fig = go.Figure(data=(data[k], potential_plot),
                    layout=go.Layout(
                        height=700,
                        width=700,
                        margin=dict(l=2, r=1, b=2, t=0),
                        plot_bgcolor='white',
                        xaxis=dict(range=[0, 1], autorange=False, tickvals=np.arange(0, 1, .2), ticktext=[""]*int(1/.2), showline=True, linewidth=2, linecolor='black', ticks="inside", ticklen=10, tickwidth=3),
                        yaxis=dict(range=[0, 1], autorange=False,showticklabels=False),
                        yaxis2=dict(side='right', overlaying='y', showticklabels=False),
                        legend=dict(visible=False)
                    ),
            )
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=True)
fig.show()

### 2D

In [None]:
# Setup
def V(x):
    """Non-convex radial potential 1 - cos(3*pi*r2) + r2**2, with r2 = spf.utils.sqnorm(x)."""
    r2 = spf.utils.sqnorm(x)
    return 1 - torch.cos(3*torch.pi*r2) + r2*r2


# Initial particles: grid points from spf.utils.grid flattened to (num_points, 2).
X0 = spf.utils.grid[.5:.9:.2, .5:.9:.2].reshape(-1, 2)
# 50 time steps of size 1e-3.
t = torch.arange(50) * 1e-3
# Lagrangian solver with eps = 1 and a constant optimizer schedule.
lspf = spf.LagrangianSPF(V, 1)
lspf.set_optimizer(spf.optimizers.NesterovGD(lr=lambda k: 1e-2))

In [None]:
# Slow step (commented out): integrate the 2D non-convex particle flow.
# xt = lspf.integrate(X0, t, max_sinkhorn_steps=500, max_optim_steps=400, sinkhorn_tol=1e-7, print_progress=True)
# torch.save(xt, "./saved_flows/lag_ncvx2D_eps1_tau1e-3_N50.pt")
# Reload the cached trajectories (eps=1, tau=1e-3, 50 steps).
xt = torch.load("./saved_flows/lag_ncvx2D_eps1_tau1e-3_N50.pt")

In [None]:
# Animation
# Scatter traces of the particles; one element per frame.
data = spf.visualize.particle_flow(xt, marker=dict(color='red', size=20))
m, M = -.1, 1
fig = go.Figure(data=[data[0]],
                    layout=go.Layout(
                        height=800,
                        width=800,
                        title=dict(text="Sinkhorn potential flow animation", font=dict(size=40)),
                        title_x=.5,
                        xaxis=dict(range=[m, M], autorange=False, tickfont=dict(size=20), visible=False),
                        yaxis=dict(range=[m, M], autorange=False, tickfont=dict(size=20), visible=False),
                        updatemenus=[dict(type="buttons", showactive=False,
                                          buttons=[dict(label='Play',
                                                        method='animate',
                                                        args=[None, dict(frame=dict(duration=0, redraw=True),
                                                                         fromcurrent=True)]),  
                                                    dict(label='Pause',
                                                         method='animate',
                                                         args=[[None], dict(frame=dict(duration=0, redraw=True),
                                                                           mode='immediate',
                                                                           transition=dict(duration=0))])])]
                    ),
            )
# Grid overlay with spacing sqrt(eps). Here eps = 1, so lines sit at -0.1, 0.9 and 1.9
# (the last one lies outside the visible range).
rooteps = np.sqrt(lspf.eps)
for xv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=xv, y0=m,
        x1=xv, y1=M,
        line=dict(color="black", width=1)
    )

for yv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=m, y0=yv,
        x1=M,  y1=yv,
        line=dict(color="black", width=1)
    )
fig.frames = [go.Frame(data=[d]) for d in data]

fig.show()
# fig.write_html(f"animations/LagrangianSinkhorn2Dnoncvx.html")

In [None]:
# Static images, modify index k to show the different frames
k = 45
# Particles at time index k over a heatmap of V on [0, 1]^2 (default heatmap resolution).
fig = go.Figure(data=[go.Scatter(x=xt[k,:, 0], y=xt[k,:, 1], mode='markers', marker=dict(color='red', size=20),
                                    name='Particles'),
                        spf.visualize.potential_heatmap(V, [0,1, 0, 1], name='V', showlegend=False, showscale=False,
                                                        colorscale=blues)],
                    layout=dict(height=700, width=700, xaxis=dict(range=[0, 1], visible=False),
                                yaxis=dict(range=[0, 1], visible=False), margin=dict(l=0, r=0, b=0,t=0)))
# Grid overlay reusing m, M and rooteps from the animation cell (vertical lines, then horizontal lines).
for xv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=xv, y0=m,
        x1=xv, y1=M,
        line=dict(color="black", width=1)
    )

for yv in np.arange(m, M+rooteps, rooteps):
    fig.add_shape(
        type="line",
        x0=m, y0=yv,
        x1=M,  y1=yv,
        line=dict(color="black", width=1))

fig.show()


# Illustration of the cost of a vertical perturbation on the two-point space

Here we compute, as a function of the mass m, the cost under the Sinkhorn metric tensor of a vertical perturbation on the two-point space (Figure 7 in the article).

In [None]:
# Masses m on the grid 1e-3, 2e-3, ..., up to (excluding) 1.
m = torch.arange(1e-3, 1, 1e-3)
# Two points (0, 0) and (0, 1), at unit distance from each other.
x = torch.tensor([[0, 0.], [0, 1]])
# NOTE(review): `c` and `vals` are not used anywhere in the visible notebook.
c = spf.utils.euclidean_cost(x, x)
eps = .2**2
vals = []
# b = exp(-2 * sqnorm(x0 - x1) / eps); presumably sqnorm is the squared Euclidean norm,
# giving exp(-2/eps) here — confirm.
b = torch.exp(-2*spf.utils.sqnorm(x[0] - x[1])/eps)
M = torch.sqrt(m*(1-m))
# p is the positive root of (1-b) p^2 + b p - b M^2 = 0.
p = (-b + torch.sqrt(b*b +4*M*M*b*(1-b)))/(2*(1-b))
# Cost g(m) of the vertical perturbation, plotted in the next cell.
g = .5*eps*(M*M-p)/(p*(2*M*M - p))

In [None]:
# Plot the cost g against the mass m; axis titles are TeX strings typeset by MathJax in the notebook.
fig = go.Figure(data=go.Scatter(x=m, y=g, line=dict(color="green", width=3)), layout=dict(width=700, height=700, margin=dict(t=0, b=80, r=60, l=60), plot_bgcolor='white', xaxis=dict(title=r'$\huge m$', range=[-.05, 1.05], tickfont=dict(size=20), tickvals=torch.linspace(0, 1, 6), showline=True, linewidth=2, linecolor='black', ticklen=10, tickwidth=3, ticks='inside'), yaxis=dict(title=r'$\huge{g_\mu(\dot{\mu}, \dot{\mu})}$', showticklabels=False, showline=True, linewidth=2, linecolor='black')))
# Static PNG export. NOTE(review): plotly's write_image needs a static-image engine (e.g. kaleido)
# and an existing ./figures/ directory — confirm both before running.
fig.write_image('./figures/vertical_perturbation_cost.png')
fig.show()