In [None]:
import os
# Expose only GPU 0 to this process. Presumably set before JAX is imported so
# that the device selection takes effect when JAX initialises its backend.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from jax import config
# Keep JAX in 32-bit mode. NOTE(review): PRECISION_PREFIX is 'float32' below;
# if it is switched to 'float64' this flag presumably needs to be True as well,
# otherwise JAX falls back to float32 for requested float64 arrays.
config.update("jax_enable_x64", False)

In [None]:
import OTF, CNF, RealNVP, DatasetGenerator
import numpy as np
import pandas as pd
import optax
import matplotlib.pyplot as plt

from jax import jit, grad, value_and_grad, random
from flax.training import checkpoints

# plotting imports (TODO: move the plotting code to a separate notebook)
from jax import vmap
import jax.numpy as jnp
from matplotlib import colormaps

**TODO:**
- Beispielplots: Dynamics, VFs, Potentiale, grad(Potentiale)
- Wohin werden die Bereiche im Inputspace gemappt?

**Verteidigung:**
- Animation der Punkte / Transformation

**Arbeit:**
- Einfluss der Precision
- Einfluss der Schrittweite / Anzahl der Schritte des ODE-Solvers

# Load Dataset

In [None]:
# dtype string handed to the dataset generators and to jax.random.normal below
PRECISION_PREFIX = "float32"

In [None]:
# Root PRNG key for datasets and model initialisation.
key = random.PRNGKey(42)

In [None]:
# Number of points drawn for each data split.
training_size, validation_size, test_size = 10_000, 10_000, 20_000

In [None]:
# Full-batch: one batch covers a whole split.
BATCH_SIZE_TR, BATCH_SIZE_VAL_TE = training_size, validation_size

## TUC Letters s=0

In [None]:
# `spacing` argument for make_tuc_letters in this section
spacing = int(0)

In [None]:
train, val, test = DatasetGenerator.make_tuc_letters(
    key, training_size, validation_size, test_size,
    dtype=PRECISION_PREFIX, spacing=spacing,
)
# Standard-normal reference sample (validation-set sized) used as `normal_batch`
# by the validation metrics.
normal_sample = random.normal(
    random.fold_in(key, 42),
    shape=(validation_size, 2),
    dtype=PRECISION_PREFIX,
)

## TUC Letters s=40

In [None]:
# `spacing` argument for make_tuc_letters in this section
spacing = 4 * 10

In [None]:
train, val, test = DatasetGenerator.make_tuc_letters(
    key, training_size, validation_size, test_size,
    dtype=PRECISION_PREFIX, spacing=spacing,
)
# Standard-normal reference sample (validation-set sized) used as `normal_batch`
# by the validation metrics.
normal_sample = random.normal(
    random.fold_in(key, 42),
    shape=(validation_size, 2),
    dtype=PRECISION_PREFIX,
)

## Checkerboards

In [None]:
train, val, test = DatasetGenerator.make_checkerboard(
    key, training_size, validation_size, test_size, dtype=PRECISION_PREFIX,
)
# Standard-normal reference sample (validation-set sized) used as `normal_batch`
# by the validation metrics.
normal_sample = random.normal(
    random.fold_in(key, 42),
    shape=(validation_size, 2),
    dtype=PRECISION_PREFIX,
)

## TUC Logo

In [None]:
train, val, test = DatasetGenerator.make_tuc_logo(
    key, training_size, validation_size, test_size, dtype=PRECISION_PREFIX,
)
# Standard-normal reference sample (validation-set sized) used as `normal_batch`
# by the validation metrics.
normal_sample = random.normal(
    random.fold_in(key, 42),
    shape=(validation_size, 2),
    dtype=PRECISION_PREFIX,
)

# Define Model

## RealNVP

### No L2

In [None]:
# RealNVP with 32 blocks, trained with the plain loss (no L2 term).
realnvp_config = dict(
    NVP_net=RealNVP.NVP_l,
    num_blocks=32,
    key=key,
    input_dim=2,
    hidden_dim=64,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None,
)
MODEL = RealNVP.RealNVP(**realnvp_config)
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
EVAL_SOLVER_STEPS = None  # RealNVP: no ODE solver steps are passed to its metrics

In [None]:
def eval_val(val_data, params, solver_steps):
    """Return VAL_LOSS_FUNC evaluated on `val_data` with `params`.

    `solver_steps` is accepted only so all `eval_val` variants share one
    signature; this variant does not forward it.
    """
    return VAL_LOSS_FUNC(batch=val_data, params=params, **VAL_LOSS_KWARGS)

### L2

In [None]:
# RealNVP with 32 blocks, trained with the L2-regularised loss (alpha=1e-4).
realnvp_config = dict(
    NVP_net=RealNVP.NVP_l,
    num_blocks=32,
    key=key,
    input_dim=2,
    hidden_dim=64,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None,
)
MODEL = RealNVP.RealNVP(**realnvp_config)
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss_l2), jit(MODEL.metrics)
LOSS_KWARGS = {'alpha': 1e-4}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
EVAL_SOLVER_STEPS = None  # RealNVP: no ODE solver steps are passed to its metrics

In [None]:
def eval_val(val_data, params, solver_steps):
    """Return VAL_LOSS_FUNC evaluated on `val_data` with `params`.

    `solver_steps` is accepted only so all `eval_val` variants share one
    signature; this variant does not forward it.
    """
    return VAL_LOSS_FUNC(batch=val_data, params=params, **VAL_LOSS_KWARGS)

## VCNF

### No L2

In [None]:
# CNF with 5 vector-field blocks, plain loss (no L2 term).
cnf_config = dict(
    input_dim=2,
    hidden_dim=64,
    out_dim=2,
    depth=3,
    num_blocks=5,
    key=key,
    f_theta_cls=CNF.f_theta,
    exact_logp=True,
    num_steps=20,
)
MODEL = CNF.CNF(**cnf_config)
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
EVAL_SOLVER_STEPS = 2 * 20  # solver steps used by eval_val and sampling at evaluation time

In [None]:
def eval_val(val_data, params, solver_steps):
    """Return VAL_LOSS_FUNC on `val_data` with `params`, forwarding `solver_steps`."""
    return VAL_LOSS_FUNC(
        solver_steps=solver_steps, batch=val_data, params=params, **VAL_LOSS_KWARGS
    )

### L2

In [None]:
# CNF with 15 vector-field blocks, L2-regularised loss (alpha=1e-5).
cnf_config = dict(
    input_dim=2,
    hidden_dim=64,
    out_dim=2,
    depth=3,
    num_blocks=15,
    key=key,
    f_theta_cls=CNF.f_theta,
    exact_logp=True,
    num_steps=20,
)
MODEL = CNF.CNF(**cnf_config)
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss_l2), jit(MODEL.metrics)
LOSS_KWARGS = {'alpha': 1e-5}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
EVAL_SOLVER_STEPS = 2 * 20  # solver steps used by eval_val and sampling at evaluation time

In [None]:
def eval_val(val_data, params, solver_steps):
    """Return VAL_LOSS_FUNC on `val_data` with `params`, forwarding `solver_steps`."""
    return VAL_LOSS_FUNC(
        solver_steps=solver_steps, batch=val_data, params=params, **VAL_LOSS_KWARGS
    )

## OTF

In [None]:
# OT-Flow model with 10 potential blocks on the time interval [0, 1].
otf_config = dict(
    input_dim=2,
    hidden_dim=64,
    resnet_depth=2,
    rank=20,
    key=key,
    phi=OTF.Phi,
    alpha1=15.,
    alpha2=2.,
    num_blocks=10,
    t0=0.,
    t1=1.,
    num_steps=20,
)
MODEL = OTF.OTF(**otf_config)
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
EVAL_SOLVER_STEPS = 2 * 20  # solver steps used by eval_val and sampling at evaluation time

In [None]:
def eval_val(val_data, params, solver_steps):
    """Return VAL_LOSS_FUNC on `val_data` with `params`, forwarding `solver_steps`."""
    return VAL_LOSS_FUNC(
        solver_steps=solver_steps, batch=val_data, params=params, **VAL_LOSS_KWARGS
    )

# Load Params

In [None]:
# Run whose checkpoints/history are analysed. Other runs used previously:
#   'OFT_float64_tuc-logo_10pot'
#   'CNF_l2_float32_letters-spacing_5VF'
#   'RealNVP_l2_float32_tuc-logo_32Bl'
middle = "CNF_l2_float32_tuc-logo_15VF"

In [None]:
# All three restores read the step-20000 checkpoint of the run in `middle`.
run_dir = f'checkpoints/finalfinal/{middle}'
params = checkpoints.restore_checkpoint(f'{run_dir}_params/checkpoint_20000/checkpoint', None)
best_params_loss = checkpoints.restore_checkpoint(f'{run_dir}_bestparams_loss/checkpoint_20000/checkpoint', None)
# NOTE(review): unlike the two paths above this one contains '.csv' — kept as is,
# presumably matching the directory name written by the training run; confirm.
best_params_mmd = checkpoints.restore_checkpoint(f'{run_dir}.csv_bestparams_mmd/checkpoint_20000/checkpoint', None)

In [None]:
# Training history of the selected run. The filename is derived from `middle`
# (instead of being hard-coded) so the history always matches the checkpoints
# restored for the same run.
train_history = pd.read_csv(middle + '.csv', sep=',', index_col=0)

# Plot History

In [None]:
# Leading epochs omitted from the NLL panel, presumably so the early large
# losses do not dominate the y-axis. `.iloc` makes the positional slice
# explicit (plain `Series[1000:]` is ambiguous for an integer index).
SKIP_EPOCHS = 1000
history_tail = train_history.iloc[SKIP_EPOCHS:]

fig, ax = plt.subplots(1, 2, figsize=(16, 6))
ax[0].plot(history_tail['Epoch'], history_tail['Training Loss'], label='Training')
ax[0].plot(history_tail['Epoch'], history_tail['Validation Loss'], label='Validation')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('NLL')
ax[0].legend()

# MMD is shown for the full run on a log scale.
ax[1].semilogy(train_history['Epoch'], train_history['Validation MMD'])
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Validation MMD');

# Test Metrics

In [None]:
# The "Best ..." columns are read from the final logged row (presumably running
# bests). Using the last row positionally instead of the hard-coded label 19999
# keeps this working for runs with a different number of epochs.
final_row = train_history.iloc[-1]

best_epoch_mmd = int(final_row['Best Epoch MMD'])
best_val_mmd = final_row['Best Validation MMD']

best_epoch_loss = int(final_row['Best Epoch Loss'])
best_val_loss = final_row['Best Validation Loss']

In [None]:
# Evaluate both selected checkpoints on the test split: MMD-selected first, then
# loss-selected (so test_loss / test_inv_error / test_mmd end up holding the latter).
for tag, chosen_params, chosen_epoch, chosen_val in (
    ('MMD', best_params_mmd, best_epoch_mmd, best_val_mmd),
    ('Loss', best_params_loss, best_epoch_loss, best_val_loss),
):
    print(f'\nBest Epoch {tag}: {chosen_epoch} --- Val {tag}: {chosen_val}')
    test_loss, test_inv_error, test_mmd = eval_val(test, chosen_params, EVAL_SOLVER_STEPS)
    print(f'Best Model {tag} --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Debug: display the restored MMD-selected parameters.
best_params_mmd

In [None]:
# Debug: size of the restored loss-selected parameter container.
len(best_params_loss)

# Model-specific Plots

In [None]:
offset = 0.3

def _padded_range(coords, pad):
    """Round the range of `coords` outward to one decimal and widen it by `pad` on both sides."""
    lower = np.floor(coords.min() * 10) / 10 - pad
    upper = np.ceil(coords.max() * 10) / 10 + pad
    return lower, upper

# Plot window derived from the training data.
x_lo, x_hi = _padded_range(train[:, 0], offset)
y_lo, y_hi = _padded_range(train[:, 1], offset)
plt_params = {'X_MIN': x_lo, 'X_MAX': x_hi, 'Y_MIN': y_lo, 'Y_MAX': y_hi}

## RealNVP

In [None]:
# 5000 model samples from each selected checkpoint (loss-selected, MMD-selected).
y_best_loss, y_best_mmd = (
    MODEL.sample(5000, params=chosen) for chosen in (best_params_loss, best_params_mmd)
)

In [None]:
res = 500
grid_x = np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res)
grid_y = np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res)
xx, yy = np.meshgrid(grid_x, grid_y)
# (res*res, 2) array of grid points, row-major over the meshgrid.
xy = np.column_stack([xx.ravel(), yy.ravel()])

probs_best_loss = jnp.exp(MODEL.log_pdf_and_preimage(xy, best_params_loss, False))
probs_best_mmd = jnp.exp(MODEL.log_pdf_and_preimage(xy, best_params_mmd, False))

In [None]:
# Column 0: loss-selected model, column 1: MMD-selected model.
# Row 0: density on the grid, row 1: model samples.
extent = (plt_params['X_MIN'], plt_params['X_MAX'], plt_params['Y_MIN'], plt_params['Y_MAX'])
fig, ax = plt.subplots(2, 2, figsize=(16, 12))
for col, (probs, samples) in enumerate([(probs_best_loss, y_best_loss),
                                        (probs_best_mmd, y_best_mmd)]):
    ax[0, col].imshow(probs.reshape((res, res)), origin='lower', extent=extent, aspect='auto')
    ax[1, col].scatter(*samples.T, s=0.1)
    ax[1, col].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
    ax[1, col].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))

In [None]:
# Stepwise samples of the loss-selected model; index 0 is plotted as the base sample.
sample_steps = MODEL.sample_stepwise(
    5000,
    params=best_params_loss,
)

In [None]:
def _ordinal(n):
    """Return `n` with its English ordinal suffix, e.g. 1st, 2nd, 3rd, 4th, 11th, 21st."""
    if 10 <= n % 100 <= 20:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f'{n}{suffix}'

# 3x3 grid: panel 0 shows the base samples, panel i the output of the i-th flow.
fig, ax = plt.subplots(3, 3, figsize=(12, 12))

for i in range(9):
    row, col = divmod(i, 3)
    panel = ax[row, col]
    if i == 0:
        panel.set_title('Samples from Base Distribution (Gaussian)')
    else:
        panel.set_title(f'Output of the {_ordinal(i)} flow')

    panel.scatter(*sample_steps[i].T, s=0.5)
    panel.xaxis.set_tick_params(labelbottom=False)
    panel.yaxis.set_tick_params(labelleft=False)
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_xlim((-3, 3))
    panel.set_ylim((-3, 3))

plt.tight_layout()
#plt.savefig("realnvp8_ex.png")

## V-CNF

In [None]:
# 5000 samples per selected checkpoint, integrated with EVAL_SOLVER_STEPS steps.
y_best_loss, y_best_mmd = [
    MODEL.sample(5000, EVAL_SOLVER_STEPS, params=chosen)
    for chosen in (best_params_loss, best_params_mmd)
]

In [None]:
res = 500
grid_x = np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res)
grid_y = np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res)
xx, yy = np.meshgrid(grid_x, grid_y)
xy = np.column_stack([xx.ravel(), yy.ravel()])

# Map over grid points only; solver steps, params and the preimage flag are shared.
batched_log_pdf = vmap(MODEL.log_pdf_and_preimage, (0, None, None, None), 0)
probs_best_loss = jnp.exp(batched_log_pdf(xy, EVAL_SOLVER_STEPS, best_params_loss, False))
probs_best_mmd = jnp.exp(batched_log_pdf(xy, EVAL_SOLVER_STEPS, best_params_mmd, False))

In [None]:
# Left: loss-selected density, right: MMD-selected density, each with a colorbar.
extent = (plt_params['X_MIN'], plt_params['X_MAX'], plt_params['Y_MIN'], plt_params['Y_MAX'])
fig, ax = plt.subplots(1, 2, figsize=(20, 6))
density_nll, density_mmd = (
    panel.imshow(probs.reshape((res, res)), origin='lower', extent=extent, aspect='auto')
    for panel, probs in zip(ax, (probs_best_loss, probs_best_mmd))
)
for image in (density_nll, density_mmd):
    fig.colorbar(mappable=image)
plt.tight_layout();

In [None]:
# Model samples on the same window as the density plots (loss-selected, MMD-selected).
fig2, ax2 = plt.subplots(1, 2, figsize=(20, 6))
for panel, samples in zip(ax2, (y_best_loss, y_best_mmd)):
    panel.scatter(*samples.T, s=0.1)
    panel.set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
    panel.set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))

### VFs

In [None]:
best_params = best_params_loss  # checkpoint used for the vector-field plots below

In [None]:
# Alternatives tried: 'hsv', 'nipy_spectral'.
cmap = colormaps['tab20']
max_colors = 20
colors = list(map(lambda idx: cmap(idx / max_colors), range(max_colors)))

In [None]:
# Per-VF snapshots for 2000 samples, plus intermediate solver states for 20 samples.
stepwise_sample = MODEL.sample_with_steps(2000, EVAL_SOLVER_STEPS, params=best_params)
stepwise_sample_interm = MODEL.sample_with_steps(20, EVAL_SOLVER_STEPS, params=best_params, intermed_y=True)

res = 30
axis_pts = np.linspace(-3, 3, res)
xx, yy = np.meshgrid(axis_pts, axis_pts)
xy = np.column_stack([xx.ravel(), yy.ravel()])

n_vf = len(MODEL.funcs)

# VF k is evaluated at the start of its time slice and scaled by the solver step dt0.
vf_width = (MODEL.t1 - MODEL.t0) / n_vf
dxdy = [
    MODEL.dt0 * MODEL.funcs[k].apply(best_params[k], t=MODEL.t0 + vf_width * k, y=xy)
    for k in range(n_vf)
]


In [None]:
# Panel 0: base samples; panel i: samples after VF i, with arrows along the
# intermediate states of 20 tracked samples.
fig3, ax3 = plt.subplots(2, 3, figsize=(16, 10))

for i, panel in enumerate(ax3.flat):
    panel.scatter(*stepwise_sample[i].T, s=0.2)
    if i == 0:
        continue
    panel.set_title(f'Transformed by VF {i}')
    for j_sample in range(20):
        px, py = stepwise_sample_interm[i][j_sample].T
        panel.quiver(px[:-1], py[:-1], px[1:] - px[:-1], py[1:] - py[:-1],
                     scale_units='xy', angles='xy', scale=1.,
                     color=colors[j_sample])

ax3[0, 0].set_title('Base Distribution')
ax3[0, 0].scatter(*stepwise_sample_interm[0].T, c=colors)
plt.tight_layout()
plt.setp(ax3, xlim=(-3, 3), ylim=(-3, 3));

Für 1 VF (L2, NLL-Stopping) sieht man, warum dies mitunter keine sinnvolle Darstellung ist: Das Vektorfeld hat eine starke explizite Zeitabhängigkeit.

In [None]:
# Debug: number of vector fields evaluated on the quiver grid.
len(dxdy)

In [None]:
# Quiver plot of each (dt0-scaled) vector field on a 2x3 grid. At most five VFs
# are drawn; unused panels are removed (so models with < 5 VFs no longer crash).
n_panels = min(5, len(dxdy))
fig4, ax4 = plt.subplots(2, 3, figsize=(16, 10))

for i in range(n_panels):
    row, col = divmod(i, 3)
    ax4[row, col].quiver(*xy.T, *dxdy[i].T)
    ax4[row, col].set_title(f' VF {i+1}')

for i in range(n_panels, ax4.size):
    fig4.delaxes(ax4.flat[i])
plt.tight_layout()
# Apply the limits to this figure's axes (ax4), not to `ax` of an earlier figure.
plt.setp(ax4, xlim=(-3, 3), ylim=(-3, 3));

# Optional overlay of the tracked sample trajectories on the first panel:
# for j_sample in range(20):
#     x, y = stepwise_sample_interm[1][j_sample].T
#     ax4[0, 0].quiver(x[:-1], y[:-1], x[1:]-x[:-1], y[1:]-y[:-1],
#                      scale_units='xy', angles='xy', scale=1.,
#                      color=colors[j_sample])

## OTF

In [None]:
# 5000 samples per selected checkpoint, integrated with EVAL_SOLVER_STEPS steps.
def _draw(chosen):
    return MODEL.sample(5000, EVAL_SOLVER_STEPS, params=chosen)

y_best_loss = _draw(best_params_loss)
y_best_mmd = _draw(best_params_mmd)

In [None]:
res = 500
grid_x = np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res)
grid_y = np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res)
xx, yy = np.meshgrid(grid_x, grid_y)
xy = np.column_stack([xx.ravel(), yy.ravel()])

batched_log_pdf = vmap(MODEL.log_pdf_and_preimage, (0, None, None, None), 0)
# The first output is negated before exponentiating. NOTE(review): this suggests
# OTF returns a negative log-density here — confirm against OTF.log_pdf_and_preimage.
probs_best_loss = jnp.exp(-batched_log_pdf(xy, EVAL_SOLVER_STEPS, best_params_loss, False)[0])
probs_best_mmd = jnp.exp(-batched_log_pdf(xy, EVAL_SOLVER_STEPS, best_params_mmd, False)[0])

In [None]:
# TODO: look at the plot of the 64-bit run with 10 potentials ("64b 10pot")

In [None]:
# Left: loss-selected density, right: MMD-selected density, each with a colorbar.
extent = (plt_params['X_MIN'], plt_params['X_MAX'], plt_params['Y_MIN'], plt_params['Y_MAX'])
fig, ax = plt.subplots(1, 2, figsize=(20, 6))
density_nll, density_mmd = (
    panel.imshow(probs.reshape((res, res)), origin='lower', extent=extent, aspect='auto')
    for panel, probs in zip(ax, (probs_best_loss, probs_best_mmd))
)
for image in (density_nll, density_mmd):
    fig.colorbar(mappable=image)
plt.tight_layout();

In [None]:
def _show_density(panel, probs):
    """Draw a (res x res) density grid on `panel` over the training-data window."""
    return panel.imshow(
        probs.reshape((res, res)),
        origin='lower',
        extent=(plt_params['X_MIN'], plt_params['X_MAX'],
                plt_params['Y_MIN'], plt_params['Y_MAX']),
        aspect='auto',
    )

fig, ax = plt.subplots(1, 2, figsize=(20, 6))
density_nll = _show_density(ax[0], probs_best_loss)
density_mmd = _show_density(ax[1], probs_best_mmd)
fig.colorbar(mappable=density_nll)
fig.colorbar(mappable=density_mmd)
plt.tight_layout();

In [None]:
# Model samples on the same window as the density plots (loss-selected, MMD-selected).
x_window = (plt_params['X_MIN'], plt_params['X_MAX'])
y_window = (plt_params['Y_MIN'], plt_params['Y_MAX'])
fig2, ax2 = plt.subplots(1, 2, figsize=(20, 6))
for panel, samples in zip(ax2, (y_best_loss, y_best_mmd)):
    panel.scatter(*samples.T, s=0.1)
    panel.set_xlim(x_window)
    panel.set_ylim(y_window)

In [None]:
best_params = best_params_loss  # checkpoint used for the potential plots below

In [None]:
# Alternatives tried: 'hsv', 'nipy_spectral'.
cmap = colormaps['tab20']
max_colors = 20
colors = []
for idx in range(max_colors):
    colors.append(cmap(idx / max_colors))

In [None]:
# Per-potential snapshots for 1000 samples, plus intermediate states for 20 samples.
stepwise_sample = MODEL.sample_with_steps(1000, EVAL_SOLVER_STEPS, params=best_params)
stepwise_sample_interm = MODEL.sample_with_steps(20, EVAL_SOLVER_STEPS, params=best_params, intermed_y=True)

res = 30
axis_pts = np.linspace(-3, 3, res)
xx, yy = np.meshgrid(axis_pts, axis_pts)
xy = np.column_stack([xx.ravel(), yy.ravel()])

n_pot = len(MODEL.funcs)

# Potential k is evaluated at time t1 - k * (t1 - t0) / n_pot; vmap maps over
# grid points only, time and (func, params) are shared.
batched_dynamics = vmap(MODEL.forward_dynamics, (None, 0, None), 0)
pot_width = (MODEL.t1 - MODEL.t0) / n_pot
dxdy = [
    -MODEL.dt0 * batched_dynamics(MODEL.t1 - pot_width * k, xy, (MODEL.funcs[k], best_params[k]))
    for k in range(n_pot)
]

In [None]:
# Panel 0: base samples; panel i: samples after potential i, with arrows along the
# intermediate states of 20 tracked samples.
fig3, ax3 = plt.subplots(2, 3, figsize=(16, 10))
for i, panel in enumerate(ax3.flat):
    panel.scatter(*stepwise_sample[i].T, s=0.2)
    if i == 0:
        continue
    panel.set_title(f' VF {i}')
    for j_sample in range(20):
        px, py = stepwise_sample_interm[i][j_sample].T
        panel.quiver(px[:-1], py[:-1], px[1:] - px[:-1], py[1:] - py[:-1],
                     scale_units='xy', angles='xy', scale=1.,
                     color=colors[j_sample])

ax3[0, 0].set_title('Base Distribution')
ax3[0, 0].scatter(*stepwise_sample_interm[0].T, c=colors)
#fig3.delaxes(ax3[1, 0])
plt.tight_layout()
plt.setp(ax3, xlim=(-3, 3), ylim=(-3, 3));

In [None]:
# Quiver plots of the fields stored in `dxdy` (scaled dynamics of the first three
# potentials), one panel each.
fig4, ax4 = plt.subplots(1, 3, figsize=(18, 6))

for i, panel in enumerate(ax4):
    panel.quiver(*xy.T, *dxdy[i].T)
    panel.set_title(fr'$-\nabla \Phi_{i}$')

plt.tight_layout()
# Apply the limits to this figure's axes (ax4), not to `ax` of an earlier figure.
plt.setp(ax4, xlim=(-3, 3), ylim=(-3, 3));

In [None]:
# Contour plots of the negated network output of the first (up to) three potentials.
# Each column shows one potential at three times: the start of its time slice,
# then one third and two thirds of a slice earlier. The slice layout follows the
# quiver cell (potential i starts at t1 - i * (t1 - t0) / n_total_pot), so it is
# correct for models with any number of potentials and any t0. The previous code
# assumed exactly 3 potentials and t0 == 0.
n_total_pot = len(MODEL.funcs)
npot = min(3, n_total_pot)
delta = (MODEL.t1 - MODEL.t0) / n_total_pot

res = 500
xx, yy = np.meshgrid(np.linspace(-3, 3, res),
                     np.linspace(-3, 3, res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])

# squeeze=False keeps ax5 2-D even when only one potential is plotted.
fig5, ax5 = plt.subplots(3, npot, figsize=(6 * npot, 6 * 3), squeeze=False)

for i in range(npot):
    t_start = MODEL.t1 - delta * i
    for k in range(3):
        t = t_start - k * delta / 3
        # Append the time as a third input column: s has rows (x, y, t).
        s = jnp.vstack((xy.T, t * jnp.ones(len(xy)))).T
        E = -vmap(MODEL.funcs[i].apply, (None, 0), 0)(best_params[i], s)
        ax5[k, i].contourf(xx, yy, E.reshape((res, res)), levels=30, cmap='coolwarm')
        ax5[k, i].set_title(fr'$T={t:.2f}$')

plt.tight_layout()

In [None]:
# Training data with the two x-thresholds used for the colour labels below.
plt.scatter(*train.T)
for x_cut in (-0.7, 0.6):
    cut_line = plt.vlines(x=x_cut, ymin=-2, ymax=2)
cut_line

In [None]:
left_cut = -0.7
labels = (train[:, 0] >= left_cut).astype(int)  # 1 right of the left cut, else 0

In [None]:
# Recomputed from scratch so re-running this cell is idempotent (the former `+=`
# kept incrementing the labels on every run).
# Labels: 0 for x < -0.7, 1 for -0.7 <= x < 0.6, 2 for x >= 0.6.
labels = (train[:, 0] >= -0.7).astype(int) + (train[:, 0] >= 0.6).astype(int)

In [None]:
plt.scatter(train[:, 0], train[:, 1], c=labels)

In [None]:
# With the flag set to True the call returns four outputs; only the last (preimage) is kept.
batched_log_pdf = vmap(MODEL.log_pdf_and_preimage, (0, None, None, None), 0)
_, _, _, preimg = batched_log_pdf(train, EVAL_SOLVER_STEPS, best_params, True)

In [None]:
plt.scatter(preimg[:, 0], preimg[:, 1], c=labels)

## Plot Logo

In [None]:
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

In [None]:
# White-to-green colormap with anchors at values 1, 30 and 40 (RGB given in 0..255).
values = [1, 30, 40]
colors = [(255, 255, 255), (1, 93, 77), (0, 93, 77)]
norm = plt.Normalize(min(values), max(values))
anchors = [(norm(v), tuple(np.array(rgb) / 255)) for v, rgb in zip(values, colors)]
cmap = LinearSegmentedColormap.from_list('logo', anchors)

In [None]:
logo_extent = (plt_params['X_MIN'], plt_params['X_MAX'],
               plt_params['Y_MIN'], plt_params['Y_MAX'])
fig, ax = plt.subplots(1, 1, figsize=(16, 12))
density_nll = ax.imshow(probs_best_loss.reshape((res, res)), origin='lower',
                        extent=logo_extent, aspect='auto', cmap=cmap)
plt.axis('off')
fig.savefig("logo_flow_tight.png", bbox_inches='tight')

In [None]:
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

In [None]:
# White-to-black colormap with anchors at values 1, 30 and 40 (RGB given in 0..255).
values = [1, 30, 40]
colors = [(255, 255, 255), (1, 0, 0), (0, 0, 0)]
norm = plt.Normalize(min(values), max(values))
anchors = [(norm(v), tuple(np.array(rgb) / 255)) for v, rgb in zip(values, colors)]
cmap = LinearSegmentedColormap.from_list('logo', anchors)

In [None]:
logo_extent = (plt_params['X_MIN'], plt_params['X_MAX'],
               plt_params['Y_MIN'], plt_params['Y_MAX'])
fig, ax = plt.subplots(1, 1, figsize=(16, 12))
density_nll = ax.imshow(probs_best_loss.reshape((res, res)), origin='lower',
                        extent=logo_extent, aspect='auto', cmap=cmap)
plt.axis('off')
fig.savefig("logo_flow_tight_black.png", bbox_inches='tight')

In [None]:
# Derive independent subkeys: `key` was already passed to the dataset generators
# and the model constructors, and JAX keys must not be reused. Using the same key
# for both draws here would also make them correlated.
normal_key = random.fold_in(key, 1)
choice_key = random.fold_in(key, 2)

normal = random.normal(normal_key, shape=(10000, 2))
# 10000 distinct test points (drawn without replacement).
transformed = test[random.choice(choice_key, len(test), shape=(10000,), replace=False)]

In [None]:
# Larger fonts for the figures exported below.
plt.rcParams.update({
    'axes.labelsize': 15,
    'axes.titlesize': 18,
    'xtick.labelsize': 15,
    'ytick.labelsize': 15,
})

In [None]:
# Side-by-side: data samples (X) and standard-normal samples (Z).
fig, ax = plt.subplots(1, 2, figsize=(18, 6))
data_panel, latent_panel = ax

data_panel.scatter(*transformed.T, s=3.)
data_panel.set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
data_panel.set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
data_panel.set_title(r'Samples from $X$')
plt.setp(data_panel, xlabel='$x_1$ component', ylabel='$x_2$ component')

latent_panel.scatter(*normal.T, s=3.)
latent_panel.set_title(r'Samples from $Z$')
latent_panel.set_xlim((-4.2, 4.2))
latent_panel.set_ylim((-4.2, 4.2))
plt.setp(latent_panel, xlabel='$z_1$ component', ylabel='$z_2$ component')
plt.tight_layout(w_pad=15)

In [None]:
fig.savefig(fname='transformation.png', bbox_inches='tight')

In [None]:
plt.scatter(normal[:, 0], normal[:, 1], s=3.0)

In [None]:
# Debug: number of test points selected for the transformation figure.
len(transformed)

In [None]:
plt.scatter(transformed[:, 0], transformed[:, 1], s=3.0)