In [None]:
import os, sys
sys.path.append("..")

import torch
import numpy as np
from matplotlib import pyplot as plt

from src.light_sb import LightSB
from src.distributions import StandardNormalSampler, SwissRollSampler

In [None]:
# Load one pretrained LightSB checkpoint per epsilon value in EPSILONS.
# Each checkpoint is read from ../checkpoints/LightSB_Swiss_Roll_EPSILON_{eps}/D.pt.
EPSILONS = [0.002]

models = []

for eps in EPSILONS:
    EXP_NAME = f'LightSB_Swiss_Roll_EPSILON_{eps}'
    OUTPUT_PATH = f'../checkpoints/{EXP_NAME}'

    D = LightSB(dim=2, n_potentials=500, epsilon=eps,
                sampling_batch_size=128, is_diagonal=True)

    # The rest of the notebook runs on CPU (samplers use device="cpu"), so remap
    # tensors to CPU even if the checkpoint was saved from a GPU run.
    state_dict = torch.load(os.path.join(OUTPUT_PATH, 'D.pt'), map_location="cpu")
    D.load_state_dict(state_dict)

    models.append(D)

In [None]:
# Samplers for the input distribution p_0 (standard normal) and the
# target distribution p_1 (Swiss roll), both 2-D and on CPU.
X_sampler, Y_sampler = (
    StandardNormalSampler(dim=2, device="cpu"),
    SwissRollSampler(dim=2, device="cpu"),
)

In [None]:
# Panel 0: samples of the input (p_0) and target (p_1) distributions.
# One further panel per loaded model: its samples for x ~ p_0, plus
# stochastic trajectories from a few fixed start points.
SEED = 12
N_PLOT_SAMPLES = 2048        # points drawn from each distribution for the scatter plots
N_TRAJECTORY_REPEATS = 3     # trajectories drawn per start point
N_EULER_STEPS = 1000         # 2nd arg of sample_euler_maruyama; presumably the step count — TODO confirm
AXIS_LIMITS = [-2.5, 2.5]

torch.manual_seed(SEED)
np.random.seed(SEED)

# One panel for the distributions plus one per model, 3.75in wide each.
# squeeze=False keeps `axes` indexable even when there are no models.
n_panels = 1 + len(models)
fig, axes = plt.subplots(1, n_panels, figsize=(3.75 * n_panels, 3.75), dpi=200, squeeze=False)
axes = axes[0]

for ax in axes:
    ax.grid(zorder=-20)

x_samples = X_sampler.sample(N_PLOT_SAMPLES)
y_samples = Y_sampler.sample(N_PLOT_SAMPLES)

# Hand-picked start points, each repeated so several trajectories are
# drawn from the same origin.
tr_starts = torch.tensor([[0.0, 0.0], [1.75, -1.75], [-1.5, 1.5], [2, 2]])
tr_samples = tr_starts[None].repeat(N_TRAJECTORY_REPEATS, 1, 1).reshape(-1, 2)

axes[0].scatter(x_samples[:, 0], x_samples[:, 1], alpha=0.3,
                c="g", s=32, edgecolors="black", label=r"Input distribution $p_0$")
axes[0].scatter(y_samples[:, 0], y_samples[:, 1],
                c="orange", s=32, edgecolors="black", label=r"Target distribution $p_1$")
axes[0].set_title("Input and target distributions")

for ax, model in zip(axes[1:], models):
    # Detach before plotting so matplotlib's numpy conversion never sees a tensor
    # carrying autograd history (mirrors the trajectory handling below).
    # Draw order (forward pass, then trajectories) is kept so the seeded RNG
    # stream is consumed identically.
    y_pred = model(x_samples).detach().cpu()
    trajectory = model.sample_euler_maruyama(tr_samples, N_EULER_STEPS).detach().cpu()

    ax.scatter(y_pred[:, 0], y_pred[:, 1],
               c="yellow", s=32, edgecolors="black", label="Fitted distribution", zorder=1)
    ax.scatter(tr_samples[:, 0], tr_samples[:, 1],
               c="g", s=128, edgecolors="black", label=r"Trajectory start ($x \sim p_0$)", zorder=3)
    # trajectory[:, -1] is the last recorded state of each path (indexing assumes
    # (n_paths, n_points, 2) — TODO confirm against sample_euler_maruyama).
    ax.scatter(trajectory[:, -1, 0], trajectory[:, -1, 1],
               c="red", s=64, edgecolors="black", label=r"Trajectory end (fitted)", zorder=3)

    for i, path in enumerate(trajectory):
        # A thick black underlay plus a thin grey overlay gives each path an outline.
        # Only the first path gets a legend entry; "_nolegend_" hides the rest.
        ax.plot(path[:, 0], path[:, 1], "black", linewidth=1.5, zorder=2)
        ax.plot(path[:, 0], path[:, 1], "grey", linewidth=0.5, zorder=2,
                label=r"Trajectory of $T_{\theta}$" if i == 0 else "_nolegend_")
    ax.set_title("LightSB: fitted samples and trajectories")

for ax in axes:
    ax.set_xlim(AXIS_LIMITS)
    ax.set_ylim(AXIS_LIMITS)
    ax.legend(loc="lower left")

fig.tight_layout(pad=0.1)
