# <center>Light Schrödinger bridge (Korotin et al., 2023)</center>

In [1]:
import math
import torch
from torch import nn
from torch.distributions import Categorical, Independent, Normal, MixtureSameFamily
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
from sklearn import datasets

## Light Schrödinger bridge

### Módulo

In [2]:
class LightSB(nn.Module):
    """Light Schrödinger bridge solver with a diagonal-Gaussian-mixture potential.

    The (unnormalised) potential is the mixture
    ``sum_k alpha_k * N(y | r_k, epsilon * diag(S_k))`` with
    ``log alpha = log_alpha_raw / epsilon`` and ``S = exp(S_log_diagonal_matrix)``.

    Args:
        dim: Dimensionality of the data.
        n_potentials: Number of mixture components.
        epsilon: Entropic regularisation strength; also the diffusion
            coefficient used by ``sample_euler_maruyama``.
        sampling_batch_size: Chunk size used by ``forward`` so the
            ``(chunk, n_potentials, dim)`` intermediates stay bounded.
        S_diagonal_init: Initial value of every diagonal entry of ``S``.
    """

    def __init__(self, dim=2, n_potentials=5, epsilon=1, sampling_batch_size=1, S_diagonal_init=0.1):
        super().__init__()
        self.dim = dim
        self.n_potentials = n_potentials
        self.sampling_batch_size = sampling_batch_size
        # Non-persistent buffer: follows .to(device)/.double() together with the
        # parameters, without adding a key to the state_dict.
        self.register_buffer("epsilon", torch.tensor(float(epsilon)), persistent=False)

        # Mixture parameters.
        self.log_alpha_raw = nn.Parameter(self.epsilon * torch.log(torch.ones(n_potentials) / n_potentials))
        self.r = nn.Parameter(torch.randn(n_potentials, dim))
        self.S_log_diagonal_matrix = nn.Parameter(torch.log(S_diagonal_init * torch.ones(n_potentials, dim)))

    def _get_S(self):
        """Return the positive diagonal entries of S, shape (n_potentials, dim)."""
        return torch.exp(self.S_log_diagonal_matrix)

    def _get_log_alpha(self):
        """Return the mixture log-weights log_alpha_raw / epsilon, shape (n_potentials,)."""
        return self.log_alpha_raw / self.epsilon

    def _component_logits(self, x, S, log_alpha):
        """Return (x^T S x + 2 x^T r) / (2 epsilon) + log_alpha, shape (batch, n_potentials)."""
        x_S_x = (x[:, None, :] * S[None, :, :] * x[:, None, :]).sum(dim=-1)
        x_r = (x[:, None, :] * self.r[None, :, :]).sum(dim=-1)
        return (x_S_x + 2 * x_r) / (2 * self.epsilon) + log_alpha[None, :]

    def init_r_by_samples(self, samples):
        """Overwrite the component means ``r`` with ``samples`` of shape (n_potentials, dim).

        Raises:
            ValueError: if ``samples`` does not have shape (n_potentials, dim).
        """
        expected = (self.n_potentials, self.dim)
        if tuple(samples.shape) != expected:
            raise ValueError(f"expected samples of shape {expected}, got {tuple(samples.shape)}")
        # In-place copy keeps the same Parameter object (optimizers hold a
        # reference to it) and converts dtype/device as needed.
        with torch.no_grad():
            self.r.copy_(samples)

    @torch.no_grad()
    def forward(self, x):
        """Draw one sample per row of ``x`` from the learned conditional mixture.

        For each row, the mixture has logits
        ``(x^T S x + 2 x^T r) / (2 epsilon) + log_alpha``, component means
        ``r_k + S_k * x`` and component variances ``epsilon * S_k``.

        Args:
            x: Source points, shape (batch, dim).

        Returns:
            Tensor of shape (batch, dim).
        """
        if x.shape[0] == 0:
            return x.new_empty((0, self.dim))

        S = self._get_S()
        log_alpha = self._get_log_alpha()
        scale = torch.sqrt(self.epsilon * S)[None, :, :]

        samples = []
        # Chunked to bound memory of the (chunk, n_potentials, dim) tensors.
        for sub_batch_x in torch.split(x, self.sampling_batch_size):
            logits = self._component_logits(sub_batch_x, S, log_alpha)
            loc = self.r[None, :, :] + S[None, :, :] * sub_batch_x[:, None, :]
            mix = Categorical(logits=logits)
            comp = Independent(Normal(loc=loc, scale=scale), 1)
            samples.append(MixtureSameFamily(mix, comp).sample())

        return torch.cat(samples, dim=0)

    def get_drift(self, x, t):
        """Return the drift ``-x / (1 - t) + epsilon * grad_x logsumexp(...)`` used by the SDE sampler.

        Works inside ``torch.no_grad()`` blocks because the gradient w.r.t. ``x``
        is computed under ``torch.enable_grad()``.

        Args:
            x: Current states, shape (batch, dim).
            t: Times in [0, 1): a tensor of shape (batch,) or a scalar applied
                to every row.

        Returns:
            Detached drift tensor of shape (batch, dim).
        """
        epsilon = self.epsilon
        # Detached leaf copy so autograd can differentiate w.r.t. x even if the
        # caller's tensor belongs to another graph.
        x = x.detach().clone().requires_grad_(True)
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() == 0:
            t = t.expand(x.shape[0])

        with torch.enable_grad():
            S_diagonal = self._get_S()
            A_diagonal = (t / (epsilon * (1 - t)))[:, None, None] + 1 / (epsilon * S_diagonal)[None, :, :]
            S_log_det = torch.sum(self.S_log_diagonal_matrix, dim=-1)
            A_log_det = torch.sum(torch.log(A_diagonal), dim=-1)
            log_alpha = self._get_log_alpha()

            S_inv = 1 / S_diagonal
            A_inv = 1 / A_diagonal
            c = ((1 / (epsilon * (1 - t)))[:, None] * x)[:, None, :] + (self.r / (epsilon * S_diagonal))[None, :, :]

            exp_arg = (
                log_alpha[None, :]
                - 0.5 * S_log_det[None, :]
                - 0.5 * A_log_det
                - 0.5 * ((self.r * S_inv * self.r) / epsilon).sum(dim=-1)[None, :]
                + 0.5 * (c * A_inv * c).sum(dim=-1)
            )
            lse = torch.logsumexp(exp_arg, dim=-1)
            grad_lse = torch.autograd.grad(lse, x, grad_outputs=torch.ones_like(lse))[0]

        return (-x / (1 - t[:, None]) + epsilon * grad_lse).detach()

    def sample_euler_maruyama(self, x, n_steps):
        """Simulate ``dX = drift(X, t) dt + sqrt(epsilon) dW`` from t=0 with Euler-Maruyama.

        Args:
            x: Initial states, shape (batch, dim).
            n_steps: Number of steps (step size 1 / n_steps); must be >= 1.

        Returns:
            Tensor of shape (batch, n_steps + 1, dim) including the initial state.

        Raises:
            ValueError: if ``n_steps < 1``.
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        dt = 1 / n_steps
        noise_scale = math.sqrt(dt) * torch.sqrt(self.epsilon)
        trajectory = [x]

        for step in range(n_steps):
            # t derived from the step index (no accumulated rounding), on x's device.
            t = torch.full((x.shape[0],), step * dt, dtype=x.dtype, device=x.device)
            x = x + self.get_drift(x, t) * dt + noise_scale * torch.randn_like(x)
            trajectory.append(x)

        return torch.stack(trajectory, dim=1)

    def get_log_potential(self, x):
        """Return ``log sum_k alpha_k N(x | r_k, epsilon * S_k)`` for each row of ``x``, shape (batch,)."""
        log_alpha = self._get_log_alpha()
        mix = Categorical(logits=log_alpha)
        comp = Independent(Normal(loc=self.r, scale=torch.sqrt(self.epsilon * self._get_S())), 1)
        # Categorical normalises the weights; adding logsumexp restores sum_k alpha_k.
        return MixtureSameFamily(mix, comp).log_prob(x) + torch.logsumexp(log_alpha, dim=-1)

    def get_log_C(self, x):
        """Return ``logsumexp_k[(x^T S_k x + 2 x^T r_k) / (2 epsilon) + log alpha_k]``, shape (batch,)."""
        return torch.logsumexp(self._component_logits(x, self._get_S(), self._get_log_alpha()), dim=-1)

### Entrenamiento

In [3]:
def train_lsb(model, optimizer, x_sampler, y_sampler, batch_size=128, n_iters=20000, *, init_r=True):
    """Train a LightSB model by minimising ``mean(log C(X0) - log potential(X1))``.

    Each iteration draws ``X0`` from ``x_sampler`` and ``X1`` from
    ``y_sampler`` (``batch_size`` points each) and takes one optimizer step.

    Args:
        model: LightSB-like model exposing ``n_potentials``,
            ``init_r_by_samples``, ``get_log_potential`` and ``get_log_C``.
        optimizer: Optimizer over the model's parameters.
        x_sampler: Source sampler with a ``sample(batch_size)`` method.
        y_sampler: Target sampler with a ``sample(batch_size)`` method.
        batch_size: Number of points drawn from each sampler per iteration.
        n_iters: Number of optimisation steps.
        init_r: If True (the original behaviour), reset the component means
            ``r`` to fresh target samples first; pass False to resume training.

    Training can be stopped with Ctrl+C; the model keeps the parameters
    reached so far.
    """
    if init_r:
        model.init_r_by_samples(y_sampler.sample(model.n_potentials))

    try:
        for _ in tqdm(range(n_iters)):
            X0 = x_sampler.sample(batch_size)
            X1 = y_sampler.sample(batch_size)

            log_potential = model.get_log_potential(X1)
            log_C = model.get_log_C(X0)
            D_loss = (-log_potential + log_C).mean()

            optimizer.zero_grad()
            D_loss.backward()
            optimizer.step()

    except KeyboardInterrupt:
        # Deliberate best-effort stop from the notebook: report it instead of
        # swallowing it silently, and keep the current parameters.
        print("Training interrupted by user; keeping current parameters.")

## Ejemplos

### Samplers

In [4]:
class StandardNormalSampler:
    """Sampler of i.i.d. points from the standard normal distribution N(0, I_dim)."""

    def __init__(self, dim):
        self.dim = dim

    def sample(self, batch_size):
        """Return a tensor of shape (batch_size, dim) filled with standard normal draws."""
        shape = (batch_size, self.dim)
        return torch.randn(shape)
    
class SwissRollSampler:
    """Sampler of 2-D points from sklearn's Swiss roll, rescaled by 1/7.5."""

    def sample(self, batch_size=10):
        """Return a float32 tensor of shape (batch_size, 2)."""
        points_3d, _ = datasets.make_swiss_roll(n_samples=batch_size, noise=0.8)
        # Keep coordinates 0 and 2 of the 3-D roll, then rescale.
        planar = points_3d.astype('float32')[:, [0, 2]] / 7.5
        return torch.tensor(planar)

### Swiss roll

#### Datos

In [5]:
# Target: Swiss roll; source: 2-D standard Gaussian.
toy_y_sampler = SwissRollSampler()
toy_x_sampler = StandardNormalSampler(dim=2)

#### Modelo y entrenamiento

In [6]:
toy_model = LightSB(
    dim=2,
    n_potentials=500,
    epsilon=0.01,
    sampling_batch_size=128,
)
toy_optimizer = torch.optim.Adam(params=toy_model.parameters(), lr=3e-4)

# 1000 Adam steps with train_lsb's default batch size.
train_lsb(
    model=toy_model,
    optimizer=toy_optimizer,
    x_sampler=toy_x_sampler,
    y_sampler=toy_y_sampler,
    n_iters=1000,
)

100%|██████████| 1000/1000 [00:03<00:00, 281.65it/s]


#### Evaluación

In [None]:
def create_scatter(x, y, mode='markers', **kwargs):
    """Build a plotly ``go.Scatter`` trace; any extra keyword arguments are forwarded as-is."""
    trace_kwargs = dict(x=x, y=y, mode=mode)
    trace_kwargs.update(kwargs)
    return go.Scatter(**trace_kwargs)

def _to_numpy(values):
    """Return ``values`` as a CPU NumPy array detached from any autograd graph.

    plotly reads array-likes through NumPy, which cannot consume CUDA tensors
    or tensors that require grad.
    """
    return torch.as_tensor(values).detach().cpu().numpy()


def plot_samples_and_bridge(p0_samples, p1_samples, y_pred, p0_bridge, p1_bridge, trajectory, axis_range=(-2.5, 2.5)):
    """Plot source/target samples (left) and the learned bridge (right).

    Only columns 0 and 1 of each point set are plotted.

    Args:
        p0_samples: Source samples, shape (N, >=2).
        p1_samples: Target samples, shape (M, >=2).
        y_pred: Model samples, shape (K, >=2).
        p0_bridge: Start points of the simulated paths, shape (B, >=2).
        p1_bridge: End points of the simulated paths, shape (B, >=2).
        trajectory: Simulated paths, shape (B, T, >=2); one gray line per path.
        axis_range: (min, max) applied to the x and y axes of both subplots.

    Returns:
        The plotly figure.
    """
    p0_samples, p1_samples, y_pred = _to_numpy(p0_samples), _to_numpy(p1_samples), _to_numpy(y_pred)
    p0_bridge, p1_bridge, trajectory = _to_numpy(p0_bridge), _to_numpy(p1_bridge), _to_numpy(trajectory)

    fig = make_subplots(rows=1, cols=2, subplot_titles=('Distribuciones de origen y destino', 'Puente de Schrödinger'))

    # Left panel: source and target samples.
    fig.add_trace(create_scatter(p0_samples[:, 0], p0_samples[:, 1]), row=1, col=1)
    fig.add_trace(create_scatter(p1_samples[:, 0], p1_samples[:, 1]), row=1, col=1)

    # Right panel: model samples, bridge paths, then their start (blue) and end (red) points.
    fig.add_trace(create_scatter(y_pred[:, 0], y_pred[:, 1], marker=dict(size=5)), row=1, col=2)
    for traj in trajectory:
        fig.add_trace(create_scatter(traj[:, 0], traj[:, 1], mode='lines', line=dict(width=0.8, color='gray')), row=1, col=2)
    fig.add_trace(create_scatter(p0_bridge[:, 0], p0_bridge[:, 1], marker=dict(size=10, color='blue')), row=1, col=2)
    fig.add_trace(create_scatter(p1_bridge[:, 0], p1_bridge[:, 1], marker=dict(size=10, color='red')), row=1, col=2)

    fig.update_layout(height=600, width=1200, showlegend=False, template='plotly_white')
    fig.update_xaxes(range=list(axis_range))
    fig.update_yaxes(range=list(axis_range))

    return fig

from pathlib import Path

# Draw source/target samples and push the source through the model.
eot_samples, bridge_samples = 2048, 20
p0_samples = toy_x_sampler.sample(eot_samples)
p1_samples = toy_y_sampler.sample(eot_samples)
y_pred = toy_model(p0_samples)

# Schrödinger bridge: start points uniform in [-2, 2]^2, 1000 Euler-Maruyama steps.
p0_bridge = 4 * torch.rand([bridge_samples, 2]) - 2
trajectory = toy_model.sample_euler_maruyama(p0_bridge, 1000)
p1_bridge = trajectory[:, -1, :]

# Build, show and save the figure.
fig = plot_samples_and_bridge(p0_samples, p1_samples, y_pred, p0_bridge, p1_bridge, trajectory)
fig.show()
output_path = Path('images/eot_sbp/light_sbp.pdf')
# Create the output directory first; writing into a missing directory would fail.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(output_path))