# <center>Formulación de Benamou-Brenier</center>

In [1]:
import numpy as np
from scipy import sparse
from scipy.sparse.linalg import spsolve
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import math

# Código adaptado desde https://github.com/Benoit-Muller/Computational-Optimal-Transport.

## Ecuación de Poisson 3D

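Ecuación de Poisson 3D con condición de borde de Neumann en la primera dimensión (el tiempo, en el problema de transporte) y condiciones de borde periódicas en las otras dos dimensiones (el espacio).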

In [2]:
def laplacian_matrix(n):
    """Assemble the sparse discrete 3D Laplacian stencil on an ``n``-point grid.

    The operator is the sum of three Kronecker products, one per axis, built
    from the 1D second-difference matrices returned by
    ``derivative_matrices(n)``. The ``n x n`` matrix sits innermost (it acts
    on the fastest-varying index of the flattened unknown) and the
    ``(n-1) x (n-1)`` matrix acts on each of the two remaining indices, so
    the result is square with ``n * (n-1)**2`` rows.

    Parameters
    ----------
    n : int
        Number of grid points along each axis.

    Returns
    -------
    scipy sparse matrix
        The (unscaled) Laplacian stencil matrix.
    """
    second_diff_n, second_diff_p = derivative_matrices(n)
    eye_n = sparse.eye(n, format="csr")
    eye_p = sparse.eye(n - 1, format="csr")

    # One Kronecker term per axis, listed from the innermost index outwards.
    axis_terms = (
        sparse.kron(eye_p, sparse.kron(eye_p, second_diff_n)),
        sparse.kron(eye_p, sparse.kron(second_diff_p, eye_n)),
        sparse.kron(second_diff_p, sparse.kron(eye_p, eye_n)),
    )
    return axis_terms[0] + axis_terms[1] + axis_terms[2]

def extend(g):
    """Embed two boundary slices into an otherwise-zero ``(n, n, n)`` array.

    Parameters
    ----------
    g : array_like, indexable as ``g[0]`` and ``g[1]``
        Pair of 2D slices; ``n`` is read from ``np.shape(g)[1]``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, n, n)`` whose first slab is ``g[1]``,
        whose last slab is ``g[0]``, and which is zero everywhere else.
    """
    n = np.shape(g)[1]
    volume = np.zeros((n, n, n))
    # Same assignment order as a tuple assignment: first slab, then last slab.
    volume[0] = g[1]
    volume[-1] = g[0]
    return volume

def normalize_lagrange(A, b):
    """Border a linear system with a row and column of ones.

    The returned system has one extra unknown and one extra equation: the
    matrix becomes ``[[A, 1], [1^T, 0]]`` and the right-hand side gets a
    trailing zero, so the additional equation asks the original unknowns to
    sum to zero.

    Parameters
    ----------
    A : sparse matrix with ``len(b)`` rows and columns
    b : 1D array_like

    Returns
    -------
    (sparse matrix, numpy.ndarray)
        The bordered matrix and the extended right-hand side.
    """
    size = len(b)
    ones_column = sparse.bsr_matrix(np.ones((size, 1)))
    zero_corner = sparse.bsr_matrix([0])

    with_extra_row = sparse.vstack((A, ones_column.T))
    extra_column = sparse.vstack((ones_column, zero_corner))
    bordered = sparse.hstack((with_extra_row, extra_column))

    rhs = np.hstack((b, [0]))
    return bordered, rhs

def poisson(f, g, A=None):
    """Solve the discrete 3D Poisson problem with a direct sparse solve.

    The first axis carries the boundary data ``g``; the two other axes are
    treated as periodic (their last index is a copy of the first one and is
    not an unknown).

    Parameters
    ----------
    f : numpy.ndarray, shape ``(n, n, n)``
        Source term.
    g : array_like, shape ``(2, n, n)``
        Boundary data passed to ``extend``.
    A : sparse matrix, optional
        Laplacian stencil as returned by ``laplacian_matrix(n)``. It is
        rebuilt when omitted, so pass it in when solving repeatedly.

    Returns
    -------
    numpy.ndarray, shape ``(n, n, n)``
        The solution, normalised so that its unknown entries sum to zero
        (constraint added by ``normalize_lagrange``).
    """
    n = np.shape(f)[0]
    order = "F"
    if A is None:
        A = laplacian_matrix(n)

    # Boundary contribution restricted to the unknowns; the last slab enters
    # with the opposite sign.
    g_contribution = extend(g)[:, 0:-1, 0:-1]
    g_contribution[-1] = -g_contribution[-1]

    b = (f[:, 0:-1, 0:-1] + 2*n*g_contribution).flatten(order) / n**2
    # NOTE(review): this subtracts the total of b (not its mean) from every
    # entry; kept as in the original formulation.
    b = b - np.sum(b)

    A, b = normalize_lagrange(A, b)
    # The bordered matrix comes out of sparse.hstack in COO format. spsolve
    # only works on CSC/CSR and would otherwise convert on its own while
    # emitting a SparseEfficiencyWarning on every call.
    u_vect = spsolve(sparse.csc_matrix(A), b)
    u_vect = u_vect[:-1]  # drop the Lagrange multiplier

    u = np.zeros((n, n, n))
    u[:, 0:-1, 0:-1] = u_vect.reshape((n, n-1, n-1), order=order)
    # Periodic wrap-around: the last index duplicates the first one.
    u[:, -1, 0:-1] = u[:, 0, 0:-1]
    u[:, :, -1] = u[:, :, 0]
    return u

def build_Asolve(n):
    """Pre-factorise the bordered Laplacian used by ``poisson2``.

    Builds ``laplacian_matrix(n)``, borders it with a row and a column of
    ones (same layout as ``normalize_lagrange``) and returns a solver for
    that fixed matrix.

    Parameters
    ----------
    n : int
        Number of grid points along each axis.

    Returns
    -------
    callable
        ``Asolve(b)`` solving the bordered system for a right-hand side of
        length ``n * (n-1)**2 + 1``.
    """
    A = laplacian_matrix(n)
    one = sparse.bsr_matrix(np.ones((n*(n-1)**2, 1)))
    A = sparse.vstack((A, one.T))
    one_zero = sparse.vstack((one, sparse.bsr_matrix([0])))
    # Assemble directly in CSC: the factorisation requires CSC input and
    # otherwise converts it itself with a SparseEfficiencyWarning
    # ("splu converted its input to CSC format").
    A = sparse.hstack((A, one_zero), format="csc")
    Asolve = sparse.linalg.factorized(A)
    return Asolve

def poisson2(f, g, Asolve=None):
    """Solve the discrete 3D Poisson problem with a pre-factorised matrix.

    Same right-hand side and post-processing as ``poisson``, but the linear
    system is solved through the callable ``Asolve`` instead of a fresh
    sparse solve.

    Parameters
    ----------
    f : numpy.ndarray, shape ``(n, n, n)``
        Source term.
    g : array_like, shape ``(2, n, n)``
        Boundary data passed to ``extend``.
    Asolve : callable, optional
        Solver returned by ``build_Asolve(n)``; built on the fly if omitted.

    Returns
    -------
    numpy.ndarray, shape ``(n, n, n)``
    """
    n = np.shape(f)[0]
    solver = build_Asolve(n) if Asolve is None else Asolve

    # Index of the unknowns: every slab, without the duplicated periodic
    # last row/column.
    unknowns = (slice(None), slice(0, -1), slice(0, -1))

    boundary = extend(g)[unknowns]
    boundary[-1] = -boundary[-1]

    rhs = (f[unknowns] + 2*n*boundary).flatten("F") / n**2
    rhs = rhs - np.sum(rhs)
    rhs = np.hstack((rhs, [0]))  # trailing zero for the bordering equation

    solution = solver(rhs)[:-1]  # discard the Lagrange multiplier

    u = np.zeros((n, n, n))
    u[unknowns] = solution.reshape((n, n-1, n-1), order="F")
    # Periodic wrap-around along the two last axes.
    u[:, -1, :-1] = u[:, 0, :-1]
    u[:, :, -1] = u[:, :, 0]
    return u

def derivative_matrices(n):
    """Build the two 1D second-difference matrices used by the solver.

    Both start from the tridiagonal ``[1, -2, 1]`` stencil.

    Returns
    -------
    (An, Ap) : pair of CSC matrices
        ``An`` is ``n x n``; its first and last rows have the single
        off-diagonal entry set to 2.
        ``Ap`` is ``(n-1) x (n-1)``; its two opposite corners are set to 1,
        which couples the first and the last unknown.
    """
    stencil = [1, -2, 1]
    offsets = [-1, 0, 1]

    wrapped = sparse.diags(stencil, offsets, shape=(n-1, n-1)).toarray()
    wrapped[0, -1] = 1
    wrapped[-1, 0] = 1

    reflected = sparse.diags(stencil, offsets, shape=(n, n)).toarray()
    reflected[0, 1] = 2
    reflected[-1, -2] = 2

    return sparse.csc_matrix(reflected), sparse.csc_matrix(wrapped)

def divergence(field, An, Ap):
    """Apply ``An`` / ``Ap`` along the three axes of a vector field and sum.

    Component ``field[0]`` is contracted with ``An`` along axis 0,
    ``field[1]`` with ``Ap`` along axis 1 and ``field[2]`` with ``Ap`` along
    axis 2. Only the entries ``[:, :-1, :-1]`` are computed; the last index
    of the two trailing axes is then filled by copying index 0.

    Parameters
    ----------
    field : numpy.ndarray, shape ``(3, n0, n1, n2)``
    An, Ap : sparse matrices
        Operators as returned by ``derivative_matrices``.

    Returns
    -------
    numpy.ndarray with the shape and dtype of ``field[0]``.
    """
    dense_n = An.toarray()
    dense_p = Ap.toarray()

    out = np.empty_like(field[0])
    out[:, :-1, :-1] = np.einsum("ij,jkl->ikl", dense_n, field[0, :, :-1, :-1])
    out[:, :-1, :-1] = out[:, :-1, :-1] + np.einsum("ij,kjl->kil", dense_p, field[1, :, :-1, :-1])
    out[:, :-1, :-1] = out[:, :-1, :-1] + np.einsum("ij,klj->kli", dense_p, field[2, :, :-1, :-1])

    # Fill the duplicated last row, then the duplicated last column.
    out[:, -1, :-1] = out[:, 0, :-1]
    out[:, :, -1] = out[:, :, 0]
    return out

def gradient(f, An, Ap):
    """Apply ``An`` / ``Ap`` to a scalar field along each of its three axes.

    Component 0 of the result contracts ``An`` with axis 0 of ``f``;
    components 1 and 2 contract ``Ap`` with axes 1 and 2. Only the entries
    ``[:, :-1, :-1]`` are computed; the last index of the two trailing axes
    is then filled by copying index 0.

    Parameters
    ----------
    f : numpy.ndarray, shape ``(n0, n1, n2)``
    An, Ap : sparse matrices
        Operators as returned by ``derivative_matrices``.

    Returns
    -------
    numpy.ndarray, shape ``(3, n0, n1, n2)``, float.
    """
    dense_n = An.toarray()
    dense_p = Ap.toarray()
    core = f[:, :-1, :-1]

    grad = np.zeros((3,) + np.shape(f))
    grad[0, :, :-1, :-1] = np.einsum("ij,jkl->ikl", dense_n, core)
    grad[1, :, :-1, :-1] = np.einsum("ij,kjl->kil", dense_p, core)
    grad[2, :, :-1, :-1] = np.einsum("ij,klj->kli", dense_p, core)

    # Fill the duplicated last row, then the duplicated last column.
    grad[:, :, -1, :-1] = grad[:, :, 0, :-1]
    grad[:, :, :, -1] = grad[:, :, :, 0]
    return grad

## Problema de Benamou-Brenier

Solución numérica para el problema de Benamou-Brenier.

In [3]:
class TransportProblem:
    """Augmented-Lagrangian solver state for the Benamou-Brenier problem.

    One iteration (see ``solve``) is: a Poisson solve for the potential
    ``phi``, a pointwise projection of ``nabla_phi + M/tau`` onto the set
    ``{(a, b) : a + |b|^2 / 2 <= 0}``, and an update of the multiplier
    ``M = (rho, m)``.

    The Poisson helpers work on an ``(n, n, n)`` space-time grid, so the
    problem must be two-dimensional in space with ``T`` grid points along
    every axis.

    Main attributes
    ---------------
    rho, m, M : density, momentum and their stacking ``M = (rho, m)``.
    phi, nabla_phi : potential and its space-time gradient.
    a, b, c : projected variable, with ``c = (a, b)``.
    tau : penalty parameter (modified by ``solve``).
    """

    def __init__(self, mesh, mu, nu, T, tau=1, Afactorized=True):
        """Set up the grids, the initial density and the linear operators.

        Parameters
        ----------
        mesh : numpy.ndarray, shape ``(2, T, T)``
            Space mesh; only its shape is used here.
        mu, nu : numpy.ndarray, shape ``(T, T)``
            Initial and final densities.
        T : int
            Number of time steps (must equal the space resolution).
        tau : float, optional
            Initial penalty parameter.
        Afactorized : bool, optional
            If true, factorise the Poisson matrix once and reuse it.

        Raises
        ------
        ValueError
            If the space-time grid is not of shape ``(T, T, T)``.
        """
        (d, *space_grid_shape) = mesh.shape
        space_grid_shape = tuple(space_grid_shape)
        # The Poisson/gradient/divergence helpers assume a cubic 3D grid;
        # fail early with a clear message instead of a broadcasting error.
        if d != 2 or any(size != T for size in space_grid_shape):
            raise ValueError(
                "TransportProblem needs a mesh of shape (2, T, T); "
                f"got mesh.shape={mesh.shape} with T={T}"
            )
        spacetime_grid_shape = (T,) + space_grid_shape
        self.spacetime_grid_shape = spacetime_grid_shape
        self.d = d
        self.mesh = mesh
        self.mu = mu
        self.nu = nu
        self.T = T
        self.times = np.linspace(0, 1, T)

        # Initial density: linear interpolation in time between mu and nu...
        weights = self.times.reshape((T,) + d*(1,))
        self.rho = (1 - weights)*mu + weights*nu
        # ...blended with a constant so that it is strictly positive.
        eps = 0.1
        self.rho = (1 - eps)*self.rho + eps/np.prod(spacetime_grid_shape)

        self.m = np.zeros((self.d,) + spacetime_grid_shape)
        self.M = np.concatenate((self.rho[np.newaxis, ...], self.m))
        self.phi = np.zeros(spacetime_grid_shape)
        self.nabla_phi = np.zeros((d+1,) + spacetime_grid_shape)

        self.laplacian_matrix = laplacian_matrix(space_grid_shape[0])
        self.An, self.Ap = derivative_matrices(space_grid_shape[0])
        self.Afactorized = Afactorized
        if Afactorized:
            self.Asolve = build_Asolve(self.T)

        self.a = np.zeros(spacetime_grid_shape)
        self.b = np.zeros((d,) + spacetime_grid_shape)
        self.c = np.concatenate((self.a[np.newaxis, ...], self.b))
        self.tau = tau

    def update_c(self):
        """Refresh ``c`` as the stacking of ``a`` and ``b``."""
        self.c = np.concatenate((self.a[np.newaxis, ...], self.b))

    def update_rho_m(self):
        """Refresh ``rho`` and ``m`` from the stacked multiplier ``M``."""
        self.rho = self.M[0]
        self.m = self.M[1:]

    def _poisson_rhs(self):
        """Return the source term ``f`` and boundary data ``g`` of the Poisson step."""
        g = np.stack((self.mu, self.nu))/self.tau - self.rho[[0, -1]]/self.tau + self.a[[0, -1]]
        f = divergence(self.c - self.M/self.tau, self.An, self.Ap)
        return f, g

    def poisson_step1(self):
        """Update ``phi`` and ``nabla_phi`` using a fresh sparse solve."""
        f, g = self._poisson_rhs()
        self.phi = poisson(f, g, self.laplacian_matrix)
        self.nabla_phi = gradient(self.phi, self.An, self.Ap)

    def poisson_step2(self):
        """Update ``phi`` and ``nabla_phi`` using the pre-factorised matrix."""
        f, g = self._poisson_rhs()
        self.phi = poisson2(f, g, self.Asolve)
        self.nabla_phi = gradient(self.phi, self.An, self.Ap)

    def poisson_step(self):
        """Run the Poisson step with the solver selected at construction."""
        if self.Afactorized:
            self.poisson_step2()
        else:
            self.poisson_step1()

    def projection_step_bis(self):
        """Project ``nabla_phi + M/tau`` pointwise onto ``{a + |b|^2/2 <= 0}``.

        Points already inside the set are left unchanged by the projection.
        For a point ``(alpha, beta)`` outside, the projection is
        ``a = alpha - t/2``, ``b = beta / (1 + t/2)`` where ``t`` is the
        largest real root of

            t^3 + (4 - 2 alpha) t^2 + (4 - 8 alpha) t - (8 alpha + 4 |beta|^2) = 0,

        which is the condition ``a + |b|^2/2 = 0`` multiplied out.
        """
        alpha_beta = self.nabla_phi + self.M / self.tau
        alpha = alpha_beta[0]
        beta_vec = alpha_beta[1:]
        beta_sq = np.sum(beta_vec**2, axis=0)

        # Identity projection everywhere first; points outside the set are
        # overwritten below. Without this, feasible points would keep stale
        # values from the previous iteration.
        self.a[...] = alpha
        self.b[...] = beta_vec

        outside = alpha + beta_sq/2 > 0
        for index in zip(*np.nonzero(outside)):
            al = alpha[index]
            t = self.last_root(4 - 2*al, 4 - 8*al, -8*al - 4*beta_sq[index])
            self.a[index] = al - t/2
            self.b[(..., *index)] = beta_vec[(..., *index)] / (1 + t/2)
        self.update_c()

    def dual_step(self):
        """Update the multiplier ``M`` with the constraint violation ``c - nabla_phi``."""
        self.M = self.M - self.tau*(self.c - self.nabla_phi)
        self.update_rho_m()

    def residual(self):
        """Return ``d_t phi + |grad phi|^2 / 2`` evaluated on the grid."""
        return self.nabla_phi[0] + 0.5 * np.sum(self.nabla_phi[1:]**2, axis=0)

    def criterium(self):
        """Return the rho-weighted residual relative to the rho-weighted ``|grad phi|^2``.

        Returns ``np.inf`` when the denominator is zero. (numpy does not
        raise ``ZeroDivisionError`` on float division, so this has to be
        tested explicitly.)
        """
        denominator = np.sum(self.rho * np.sum(self.nabla_phi[1:]**2, axis=0))
        if denominator == 0:
            return np.inf
        return np.sum(self.rho * np.abs(self.residual())) / denominator

    def lagrangian(self):
        """Return the Lagrangian ``L`` and the augmented Lagrangian ``L_tau``."""
        G = np.sum(self.rho[0] * self.phi[0] - self.rho[-1] * self.phi[-1])
        constraint = self.nabla_phi - self.c
        L = G + np.sum(self.M * constraint)
        L_tau = L + self.tau/2 * np.sum(constraint**2)
        return L, L_tau

    def solve(self, maxiter=100):
        """Run ``maxiter`` iterations and return the convergence history.

        Returns
        -------
        (criteria, LL)
            ``criteria`` holds ``criterium()`` after every iteration; ``LL``
            holds the augmented Lagrangian recorded before each of the three
            sub-steps (so ``3 * maxiter`` values).

        Notes
        -----
        ``self.tau`` is doubled after every iteration and is therefore
        modified by this call.
        """
        criteria = []
        LL = []
        for _ in range(maxiter):
            for step in (self.poisson_step, self.projection_step_bis, self.dual_step):
                _, L_tau = self.lagrangian()
                LL.append(L_tau)
                step()
            self.tau = 2*self.tau
            criteria.append(self.criterium())
        return criteria, LL

    def plot(self, t=None, save_path='images/ot/benamou_brenier.pdf'):
        """Show contour plots of the density at the requested times.

        Parameters
        ----------
        t : float or sequence of float, optional
            Times in ``[0, 1]``; each one is mapped to the nearest time step.
            Defaults to six evenly spaced times.
        save_path : str or None, optional
            Where the figure is written after being shown. The parent
            directory is created if needed. Pass ``None`` to skip saving.
        """
        if t is None:
            t = np.linspace(0, 1, 6)
        tt = np.atleast_1d(t)
        n = len(tt)
        cols = math.ceil(math.sqrt(n))
        rows = math.ceil(n / cols)
        fig = make_subplots(rows=rows, cols=cols, subplot_titles=[f"t = {t:.2f}" for t in tt])
        for i, time in enumerate(tt):
            row = i // cols + 1
            col = i % cols + 1
            # Round instead of truncating: float round-off such as
            # 0.29 * 100 = 28.999... must not select the previous frame.
            time_index = int(round(float(time) * (self.T - 1)))
            fig.add_trace(
                go.Contour(
                    z=self.rho[time_index],
                    x=np.arange(self.T) / (self.T - 1),
                    y=np.arange(self.T) / (self.T - 1),
                    colorbar=dict(len=1/rows, y=(1 - (2*row-1)/(2*rows))),
                    showscale=col == cols
                ),
                row=row, col=col
            )
        fig.update_layout(height=300*rows, width=300*cols, plot_bgcolor='white')
        fig.update_xaxes(scaleanchor="y", scaleratio=1)
        fig.update_yaxes(scaleanchor="x", scaleratio=1)
        fig.show()
        if save_path is not None:
            import os
            folder = os.path.dirname(save_path)
            if folder:
                os.makedirs(folder, exist_ok=True)
            fig.write_image(save_path)

    def last_root(self, a, b, c):
        """Return the largest real root of ``x^3 + a x^2 + b x + c`` (Cardano).

        The cubic is reduced to ``y^3 + p y + q = 0`` with ``x = y - a/3``.
        """
        p = b - a**2/3
        q = a / 27 * (2*a**2 - 9*b) + c
        delta = (p/3)**3 + (q/2)**2
        if delta > 0:
            # One real root.
            u = np.cbrt(-q/2 + np.sqrt(delta))
            v = np.cbrt(-q/2 - np.sqrt(delta))
            y = u + v
        elif delta == 0:
            # Roots are 2u (simple) and -u (double); which one is larger
            # depends on the sign of u.
            u = np.cbrt(-q/2)
            y = max(2*u, -u)
        else:
            # Three real roots; the principal complex cube root gives the
            # largest one.
            u = (-q/2 + 1j*np.sqrt(-delta))**(1/3)
            y = 2*np.real(u)
        return y - a/3


### Ejemplo

In [4]:
def torus_dist2(x, y):
    """Squared distance between ``x`` and ``y`` on the unit torus.

    For every coordinate the shorter of the direct gap and the wrap-around
    gap (one minus the direct gap) is squared, and the squares are summed
    over all entries.
    """
    gap = np.abs(x - y)
    return np.sum(np.minimum(gap**2, (1 - gap)**2))

def create_gaussian_distribution(mesh, mean, sigma):
    """Evaluate an unnormalised periodic Gaussian bump on a mesh.

    Computes ``exp(-0.5 * dist2 / sigma)`` at every mesh point, where
    ``dist2`` is the squared unit-torus distance to ``mean`` (the same
    metric as ``torus_dist2``). Note that ``sigma`` divides the *squared*
    distance.

    The whole mesh is handled with one broadcast expression rather than a
    per-point Python callback.

    Parameters
    ----------
    mesh : array_like, shape ``(d, ...)``
        Coordinates stacked along the first axis.
    mean : array_like, shape ``(d,)``, or scalar
        Centre of the bump.
    sigma : float
        Width parameter.

    Returns
    -------
    numpy.ndarray of shape ``mesh.shape[1:]``.
    """
    mesh = np.asarray(mesh)
    # Align the centre with the coordinate axis so it broadcasts over the grid.
    center = np.asarray(mean).reshape((-1,) + (1,) * (mesh.ndim - 1))
    gap = np.abs(mesh - center)
    dist2 = np.sum(np.minimum(gap**2, (1 - gap)**2), axis=0)
    return np.exp(-0.5 * dist2 / sigma)

# Parameters:
# N is the number of grid points per space axis; the number of time steps T
# is set equal to it. sigma is the width parameter passed to
# create_gaussian_distribution.
N = 30
T = N
sigma = 0.2

# Mesh setup:
# mesh stacks the two coordinate grids and has shape (2, N, N).
x = y = np.linspace(0, 1, N)
mesh = np.array(np.meshgrid(x, y))

# Definition of the distributions:
# mu is centred at (0.5, 0.5) and nu at (1, 1).
mean1, mean2 = 0.5 * np.ones(2), np.ones(2)
mu = create_gaussian_distribution(mesh, mean1, sigma)
nu = create_gaussian_distribution(mesh, mean2, sigma)

# Normalization:
# Rescale both densities in place so that each sums to one.
mu /= np.sum(mu)
nu /= np.sum(nu)

In [5]:
# Solve the transport problem:
# 30 iterations starting from tau=100; solve() returns the convergence
# criterion per iteration and the augmented-Lagrangian history.
problem = TransportProblem(mesh, mu, nu, T, tau=100)
criteria, LL = problem.solve(maxiter=30)
# Plot the density at six evenly spaced times in [0, 1].
t = np.linspace(0,1,6)
problem.plot(t)

  return splu(A).solve
