# <center>Geodésicas e interpolación de McCann</center>

In [1]:
import numpy as np
import ot
import scipy.stats as stats
import plotly.graph_objects as go
from scipy import stats

# Seed NumPy's legacy global RNG so any random draws made in this notebook are reproducible.
np.random.seed(seed=42)

## Interpolación entre variables aleatorias gaussianas

### Interpolación de McCann

In [2]:
def gaussian_interpolation(mu, sigma, n_interpol=10, n_points=1000,
                           save_path='images/ot/gaussian_interpolation.pdf'):
    """Plot the McCann (displacement) interpolation between two 1-D Gaussians.

    For 1-D Gaussians N(mu[0], sigma[0]**2) and N(mu[1], sigma[1]**2) the
    interpolant at time t is again Gaussian, with mean
    (1 - t) * mu[0] + t * mu[1] and standard deviation
    (1 - t) * sigma[0] + t * sigma[1]. Intermediate curves are drawn in gray
    (opacity t); the source is drawn in blue and the target in red.

    Parameters
    ----------
    mu : sequence of two floats
        Means of the source (``mu[0]``) and target (``mu[1]``) Gaussians.
    sigma : sequence of two floats
        Standard deviations (not variances) of the source and target; must
        be > 0.
    n_interpol : int, default 10
        Number of interpolation times, evenly spaced in [0, 1] with both
        endpoints included.
    n_points : int, default 1000
        Number of grid points on which the densities are evaluated.
    save_path : str or None, default 'images/ot/gaussian_interpolation.pdf'
        Where the figure is exported with ``fig.write_image``. Missing parent
        directories are created. ``None`` skips the export.

    Returns
    -------
    None
        The figure is shown and optionally written to disk.
    """
    from pathlib import Path

    # Symmetric window holding mean ± 3 std of EACH endpoint. Computing it per
    # endpoint (with |mean|) keeps the window valid when the means are negative,
    # e.g. mu=[-30, -50], where max(mu) + 3*max(sigma) would collapse to 0.
    support_limit = max(abs(m) + 3 * s for m, s in zip(mu, sigma))
    support = np.linspace(-support_limit, support_limit, n_points)

    fig = go.Figure()

    for t in np.linspace(0, 1, n_interpol):
        # Interpolated statistics: both mean and std move linearly in t.
        mu_t = (1 - t) * mu[0] + t * mu[1]
        sigma_t = (1 - t) * sigma[0] + t * sigma[1]

        density_t = stats.norm(mu_t, sigma_t).pdf(support)
        # Opacity grows with t so the path visibly runs from source to target.
        fig.add_trace(go.Scatter(x=support, y=density_t, mode='lines',
                                 line=dict(color='gray', width=1), opacity=t))

    # Source and target densities drawn on top of the interpolation path.
    source = stats.norm(mu[0], sigma[0])
    target = stats.norm(mu[1], sigma[1])
    fig.add_trace(go.Scatter(x=support, y=source.pdf(support), mode='lines',
                             line=dict(color='blue'), name='$p_x$'))
    fig.add_trace(go.Scatter(x=support, y=target.pdf(support), mode='lines',
                             line=dict(color='red'), name='$p_y$'))

    # Plot layout.
    fig.update_layout(
        title='Interpolación entre gaussianas',
        xaxis_title='Soporte', yaxis_title='Densidad',
        height=500, width=1200,
        plot_bgcolor='white', paper_bgcolor='white',
        showlegend=False)

    fig.show()
    if save_path is not None:
        # write_image fails if the target directory does not exist.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(save_path)

In [3]:
# Endpoints of the interpolation: source N(mu_x, sigma_x**2), target N(mu_y, sigma_y**2).
mu_x, mu_y = -30, 30
sigma_x, sigma_y = 10, 5

# Draw the Gaussian-to-Gaussian interpolation path.
gaussian_interpolation(mu=[mu_x, mu_y], sigma=[sigma_x, sigma_y])

## Interpolación entre mixturas gaussianas

### Generación de datos

In [4]:
def generate_mixture(alpha, mu, var, interval=(0, 1), n=200):
    """Evaluate a 1-D Gaussian mixture on a regular grid, normalized to sum to 1.

    Parameters
    ----------
    alpha : sequence of float
        Mixture weights, one per component.
    mu : sequence of float
        Component means, same length as ``alpha``.
    var : sequence of float
        Component variances (not standard deviations), same length as
        ``alpha``; must be > 0.
    interval : pair of float, default (0, 1)
        Endpoints of the evaluation grid (both included).
    n : int, default 200
        Number of grid points.

    Returns
    -------
    numpy.ndarray of shape (n,)
        Mixture density evaluated on the grid and rescaled so its entries sum
        to 1 (a discrete probability vector), not so that it integrates to 1.

    Raises
    ------
    ValueError
        If there are no components, if ``alpha``, ``mu`` and ``var`` differ in
        shape, or if the mixture has no positive mass on the grid.
    """
    alpha = np.asarray(alpha, dtype=float)
    mu = np.asarray(mu, dtype=float)
    var = np.asarray(var, dtype=float)

    if alpha.size == 0:
        raise ValueError('the mixture needs at least one component')
    if not (alpha.shape == mu.shape == var.shape):
        raise ValueError('alpha, mu and var must have the same length')

    grid = np.linspace(*interval, num=n)

    # (n, k) matrix of component pdfs, combined with the weights in one product.
    component_pdfs = stats.norm.pdf(grid[:, np.newaxis], loc=mu, scale=np.sqrt(var))
    density = component_pdfs @ alpha

    total = density.sum()
    # Guard against 0/0 (all mass far outside the grid) silently producing NaNs.
    if not total > 0:
        raise ValueError('the mixture has no positive mass on the given interval')
    return density / total

# Source (gmm_x) and target (gmm_y) two-component mixtures, evaluated on the
# default interval and grid size of generate_mixture.
gmm_x = generate_mixture(alpha=[0.4, 0.6], mu=[0.6, 0.8], var=[0.003, 0.005])
gmm_y = generate_mixture(alpha=[0.2, 0.8], mu=[0.2, 0.4], var=[0.001, 0.002])

### Interpolación de McCann (aproximación mediante baricentros de Wasserstein entrópicos)

In [5]:
def numeric_interpolation(x, y, n_interpol=10, reg=5e-4, support=None,
                          save_path='images/ot/gmm_interpolation.pdf'):
    """Plot an approximate McCann interpolation between two 1-D histograms.

    Each gray curve is ``ot.barycenter(A, M, reg, [1 - s, s])``, the
    entropic-regularized Wasserstein barycenter of ``x`` and ``y`` with
    weight ``s`` on the target, for ``s`` evenly spaced in [0.1, 1]. For a
    small ``reg`` this approximates McCann's displacement interpolation at
    time ``s``. The source ``x`` is drawn in blue and the target ``y`` in red.

    Parameters
    ----------
    x, y : array-like of shape (n,)
        Source and target histograms on a common grid. They are presumably
        non-negative and summing to 1, as produced by ``generate_mixture``
        (TODO confirm for other callers).
    n_interpol : int, default 10
        Number of barycenters drawn.
    reg : float, default 5e-4
        Entropic regularization passed to ``ot.barycenter``. Its effect is
        relative to the scale of the cost matrix (squared grid distances).
    support : array-like of shape (n,) or None, default None
        Grid on which ``x`` and ``y`` are defined. ``None`` uses a uniform
        grid on [0, 1].
    save_path : str or None, default 'images/ot/gmm_interpolation.pdf'
        Where the figure is exported with ``fig.write_image``. Missing parent
        directories are created. ``None`` skips the export.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` are not 1-D arrays of equal length, or if
        ``support`` does not match their length.
    """
    from pathlib import Path

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError('x and y must be 1-D histograms of the same length')

    if support is None:
        grid = np.linspace(0, 1, num=len(x))
    else:
        grid = np.asarray(support, dtype=float)
    if grid.shape != x.shape:
        raise ValueError('support must have the same length as x and y')

    # Ground cost: squared Euclidean distances between grid points.
    cost = ot.dist(grid[:, np.newaxis], grid[:, np.newaxis])
    # ot.barycenter expects histograms as columns: shape (n, 2).
    histograms = np.column_stack((x, y))

    fig = go.Figure()

    fig.add_trace(go.Scatter(x=grid, y=x, mode='lines', line=dict(color='blue'), name='$p_x$'))
    fig.add_trace(go.Scatter(x=grid, y=y, mode='lines', line=dict(color='red'), name='$p_y$'))

    # s = 0 is not computed; the source itself is already drawn in blue.
    for s in np.linspace(0.1, 1, n_interpol):
        barycenter = ot.barycenter(histograms, cost, reg, [1 - s, s])
        fig.add_trace(go.Scatter(x=grid, y=barycenter, mode='lines',
                                 line=dict(color='gray', width=1), opacity=s))

    fig.update_layout(
        title='Interpolación numérica',
        xaxis_title='Soporte', yaxis_title='Densidad',
        height=500, width=1200,
        plot_bgcolor='white', paper_bgcolor='white',
        showlegend=False)
    fig.show()
    if save_path is not None:
        # write_image fails if the target directory does not exist.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(save_path)

# Draw the barycentric interpolation path from gmm_x (source) to gmm_y (target).
numeric_interpolation(x=gmm_x, y=gmm_y)