# <center>Algoritmo de Sinkhorn</center>

In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats as sps
import plotly

## Algoritmo de Sinkhorn

In [None]:
def sinkhorn(x, y, C, epsilon=0.01, barycentric_projection=False, n_iter=1000, plot=True):
    """Compute an entropic transport plan between two histograms with Sinkhorn scaling.

    The kernel ``K = exp(-C / epsilon)`` is rescaled alternately on its rows
    and columns; the returned plan is ``P[i, j] = u[i] * K[i, j] * v[j]``.
    After every half-step the L1 distance between the current marginal and
    its target is recorded for the convergence plot.

    Args:
        x: Target row marginal, one weight per row of ``C`` (array-like).
        y: Target column marginal, one weight per column of ``C`` (array-like).
        C: Cost matrix of shape ``(len(x), len(y))`` (array-like).
        epsilon: Regularisation strength used in the kernel.
        barycentric_projection: Only used when ``plot`` is true. Shows the
            log of the plan and overlays the barycentric projection curve;
            requires ``len(x) == len(y)``.
        n_iter: Number of scaling iterations. With ``0`` both scalings stay
            at one and the plain kernel ``K`` is returned.
        plot: When true, display a two-panel Plotly figure (log-errors on
            the left, the plan on the right).

    Returns:
        numpy.ndarray: The transport plan, shape ``(len(x), len(y))``.

    Raises:
        AssertionError: If ``plot`` and ``barycentric_projection`` are both
            true and the two histograms differ in length.
    """
    # Accept plain lists as well as arrays (unary minus on a list would fail).
    x, y, C = np.asarray(x), np.asarray(y), np.asarray(C)

    K = np.exp(-C / epsilon)
    n_x, n_y = len(x), len(y)
    Err_p, Err_q = [], []

    # Both scalings start at one so that n_iter == 0 is well defined.
    u = np.ones(n_x)
    v = np.ones(n_y)
    for _ in range(n_iter):

        # Row update, then measure the column-marginal violation.
        u = x / np.dot(K, v)
        Kt_u = np.dot(K.T, u)  # reused by the error term and by the v update
        Err_q.append(np.linalg.norm(v * Kt_u - y, 1))

        # Column update, then measure the row-marginal violation.
        v = y / Kt_u
        Err_p.append(np.linalg.norm(u * np.dot(K, v) - x, 1))

    # Same values as diag(u) @ K @ diag(v), without the dense diagonal
    # matrices and the two full matrix products.
    P = u[:, np.newaxis] * K * v[np.newaxis, :]

    if plot:

        # Error curves (log scale, offset to avoid log(0)):
        fig = make_subplots(rows=1, cols=2, subplot_titles=('', 'Plan de Kantorovich entrópico'))
        fig.add_trace(go.Scatter(y=np.log(np.asarray(Err_p) + 1e-10), mode='lines', name='||P 1 - a||', line=dict(width=2, color='blue')), row=1, col=1)
        fig.add_trace(go.Scatter(y=np.log(np.asarray(Err_q) + 1e-10), mode='lines', name='||P^T 1 - b||', line=dict(width=2, color='red')), row=1, col=1)

        # Barycentric projection:
        if barycentric_projection:
            # Explicit raise keeps the check alive under ``python -O``.
            if n_x != n_y:
                raise AssertionError('Las muestras deben ser del mismo tamaño para realizar proyección baricéntrica.')
            fig.add_trace(go.Heatmap(z=np.log(P + 1e-5), colorscale='Viridis'), row=1, col=2)
            t = np.arange(0, n_x) / n_x
            # Row-wise weighted average of t under the plan, i.e. (P @ t) / x.
            projection = np.dot(K, v * t) * u / x
            fig.add_trace(go.Scatter(x=projection * n_x, y=t * n_x, mode='lines', name='Proyección baricéntrica', line=dict(color='red', width=4)), row=1, col=2)
        else:
            # Clip the plan so a few large entries do not wash out the colour scale.
            fig.add_trace(go.Heatmap(z=np.clip(P, 0, min(1 / n_x, 1 / n_y) * .3), colorscale='Viridis'), row=1, col=2)

        # Figure styling:
        fig.update_layout(
            title_text=f'Algoritmo de Sinkhorn para ε = {epsilon}',
            height=500, width=1200, showlegend=True,
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
            font=dict(family="Arial", size=12),
            plot_bgcolor='white', paper_bgcolor='white'
        )
        fig.update_xaxes(title_text='Iteración', row=1, col=1, showgrid=True, gridwidth=1, gridcolor='lightgray')
        fig.update_yaxes(title_text='Log-error', row=1, col=1, showgrid=True, gridwidth=1, gridcolor='lightgray')
        fig.update_xaxes(showgrid=False, showticklabels=False, row=1, col=2)
        fig.update_yaxes(showgrid=False, showticklabels=False, row=1, col=2)
        fig.show()

    return P

## Caso discreto

### Generación de datos

In [3]:
# Isotropic Gaussian samples, stored as 2 x n arrays (row 0 and row 1 are
# the two plotted coordinates, one column per point):
n_points_x, n_points_y = 30, 40
x = np.random.randn(2, n_points_x)
y = np.random.randn(2, n_points_y)

# Show both point clouds on a single figure.
fig = go.Figure(data=[
    go.Scatter(
        x=x[0], y=x[1], mode='markers', name='x',
        marker=dict(size=10, color='blue', line=dict(color='black', width=1)),
    ),
    go.Scatter(
        x=y[0], y=y[1], mode='markers', name='y',
        marker=dict(size=10, color='red', line=dict(color='black', width=1)),
    ),
])

fig.update_layout(
    title='Ubicación para el transporte óptimo',
    xaxis_title='X',
    yaxis_title='Y',
    width=500,
    height=500,
    plot_bgcolor='white',
)
# Tie the y scale to the x scale (ratio 1) so distances are not distorted.
fig.update_yaxes(scaleanchor="x", scaleratio=1)

fig.show()

### Transporte óptimo entrópico

In [4]:
# Cost matrix, expanded as |x_i|^2 + |y_j|^2 - 2 <x_i, y_j> for every
# (source i, target j) pair:
C = (x ** 2).sum(axis=0)[:, None] + (y ** 2).sum(axis=0)[None, :] - (2 * x.T) @ y

# Target histograms (uniform weights on each point cloud):
a, b = (np.ones(size) / size for size in (n_points_x, n_points_y))

# Run Sinkhorn for decreasing regularisation strengths:
for epsilon in (1, 0.1, 0.01, 0.005):
    sinkhorn(x=a, y=b, C=C, epsilon=epsilon, n_iter=100)

### Visualización del plan de transporte

In [None]:
def plot_optimal_transport(x, y, P, epsilon):
    """Draw a transport plan as weighted segments between two point clouds.

    One line is drawn from source point ``i`` to target point ``j`` for every
    entry of ``P``; its width is ``P[i, j]`` rescaled so that the largest
    entry gets ``max_width``. The two point clouds are drawn on top.

    Args:
        x: Source points, indexed as ``x[0, i]`` / ``x[1, i]`` for the two
            plotted coordinates of point ``i``.
        y: Target points, indexed the same way as ``x``.
        P: 2-D array of plan weights with one row per source point and one
            column per target point.
        epsilon: Value shown in the figure title.
    """
    fig = go.Figure()

    # Segments:
    max_width = 5
    # The peak is computed once instead of once per (i, j) pair.
    peak = np.max(P) if P.size else 0
    if peak > 0:
        widths = P * max_width / peak
    else:
        # An all-zero plan would otherwise give 0/0 = NaN widths.
        widths = np.zeros(P.shape)

    for i in range(P.shape[0]):
        for j in range(P.shape[1]):
            fig.add_trace(go.Scatter(x=[x[0, i], y[0, j]], y=[x[1, i], y[1, j]],
                                     mode='lines', showlegend=False,
                                     line=dict(width=widths[i, j],
                                               color='rgba(100,100,100,0.3)')))

    # Source and target points:
    fig.add_trace(go.Scatter(x=x[0], y=x[1], mode='markers', name='x', marker=dict(size=10, color='blue', line=dict(color='black', width=1))))
    fig.add_trace(go.Scatter(x=y[0], y=y[1], mode='markers', name='y',  marker=dict(size=10, color='red', line=dict(color='black', width=1))))

    # NOTE(review): xaxis_title/yaxis_title are passed together with axis
    # dicts that set title='' — confirm which of the two is intended.
    fig.update_layout(title=f'Plan de transporte óptimo para ε = {epsilon}',
                      xaxis_title='X', yaxis_title='Y',
                      width=700, height=700, plot_bgcolor='white',
                      xaxis=dict(showticklabels=False, showgrid=False, zeroline=False, showline=False, title=''),
                      yaxis=dict(showticklabels=False, showgrid=False, zeroline=False, showline=False, title=''),)
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    fig.show()

In [6]:
# For each regularisation strength, recompute the plan without the
# diagnostic figure and draw it over the two point clouds.
for epsilon in (1, 0.1, 0.01, 0.005):
    P = sinkhorn(a, b, C, plot=False, n_iter=100, epsilon=epsilon)
    plot_optimal_transport(x=x, y=y, P=P, epsilon=epsilon)

## Caso continuo

### Generación de datos

In [None]:
def generate_mixture(alpha, mu, var, interval=(0, 1), n=200):
    """Evaluate a weighted sum of normal densities on a grid and normalise it.

    Args:
        alpha: Weight of each component.
        mu: Mean of each component.
        var: Variance of each component (its square root is used as the
            scale of the normal density).
        interval: ``(start, stop)`` of the evenly spaced evaluation grid.
        n: Number of grid points.

    Returns:
        numpy.ndarray: Length-``n`` array of mixture values divided by their
        sum, so the entries add up to one.

    Raises:
        ValueError: If no component is given or if ``alpha``, ``mu`` and
            ``var`` have different lengths.
    """
    alpha, mu, var = list(alpha), list(mu), list(var)
    if not alpha:
        raise ValueError('La mezcla debe tener al menos una componente.')
    if not (len(alpha) == len(mu) == len(var)):
        # zip() would silently drop the unmatched components.
        raise ValueError('alpha, mu y var deben tener la misma longitud.')

    t = np.linspace(*interval, num=n)
    density = np.zeros_like(t)
    for weight, mean, variance in zip(alpha, mu, var):
        density = density + weight * sps.norm.pdf(t, loc=mean, scale=np.sqrt(variance))
    return density / density.sum()

N = 200

# Alternative two-component setup (currently disabled):
# gmm_x = generate_mixture([0.3, 0.7], [0.2, 0.4], [0.005, 0.01], n=N)
# gmm_y = generate_mixture([0.4, 0.6], [0.6, 0.8], [0.008, 0.005], n=N)

# Source histogram: four-component mixture (weights, means, variances).
gmm_x = generate_mixture(
    [0.3, 0.2, 0.3, 0.2],
    [0.1, 0.4, 0.7, 0.9],
    [0.002, 0.005, 0.003, 0.001],
    n=N,
)
# Target histogram: six-component mixture (weights, means, variances).
gmm_y = generate_mixture(
    [0.1, 0.2, 0.15, 0.25, 0.2, 0.1],
    [0.1, 0.3, 0.5, 0.6, 0.8, 0.95],
    [0.002, 0.008, 0.005, 0.01, 0.006, 0.001],
    n=N,
)

# Lift every bin by a fraction vmin of the histogram's maximum, then
# renormalise so each histogram sums to one again.
vmin = 0.02
gmm_x = gmm_x + gmm_x.max() * vmin
gmm_x = gmm_x / gmm_x.sum()
gmm_y = gmm_y + gmm_y.max() * vmin
gmm_y = gmm_y / gmm_y.sum()

# Common support grid on [0, 1], one node per histogram bin.
t = np.linspace(0, 1, num=gmm_x.size)

fig = go.Figure(data=[
    go.Scatter(x=t, y=gmm_x, fill='tozeroy', fillcolor='rgba(0,0,255,0.2)', line=dict(color='blue'), name='x'),
    go.Scatter(x=t, y=gmm_y, fill='tozeroy', fillcolor='rgba(255,0,0,0.2)', line=dict(color='red'), name='y'),
])

fig.update_layout(
    title='Distribuciones de origen y destino',
    xaxis_title='Soporte',
    yaxis_title='Densidad',
    height=500,
    width=1200,
    plot_bgcolor='white',
    paper_bgcolor='white',
)

fig.show()

### Transporte óptimo entrópico

In [8]:
# Cost matrix on the shared grid: C[i, j] = (t[i] - t[j]) ** 2.
# With 'ij' indexing X varies along rows and Y along columns.
X, Y = np.meshgrid(t, t, indexing='ij')
C = (X - Y) ** 2

# Run Sinkhorn for decreasing regularisation strengths, overlaying the
# barycentric projection on each plan:
for epsilon in (1, 0.1, 0.01, 0.001, 0.0005):
    sinkhorn(gmm_x, gmm_y, C, epsilon=epsilon, n_iter=1000, barycentric_projection=True)