In [None]:
import os

# Hide every CUDA device so JAX runs on the CPU backend (set before JAX is used).
os.environ.update(CUDA_VISIBLE_DEVICES='')

from jax import config

# Start in single precision: the first model evaluated is a float32 CNF checkpoint.
config.update("jax_enable_x64", False)

In [None]:
import OTF, CNF, RealNVP, DatasetGenerator
import numpy as np
import pandas as pd
import optax
import matplotlib.pyplot as plt

from jax import jit, grad, value_and_grad, random
from flax.training import checkpoints

# plotting -- TODO: move this to a separate notebook
from jax import vmap
import jax.numpy as jnp
from matplotlib import colormaps

# Generate Dataset

In [None]:
# dtype name used for the generated dataset and the normal reference sample.
PRECISION_PREFIX = "float32"

In [None]:
# Root PRNG key: used for data generation, the normal sample and model construction.
key = random.PRNGKey(42)

In [None]:
# Number of samples drawn for each split of the generated dataset.
training_size, validation_size, test_size = 10_000, 10_000, 20_000

In [None]:
# Batch sizes equal to the train / validation split sizes.
# NOTE(review): neither constant is referenced by the cells visible here.
BATCH_SIZE_TR, BATCH_SIZE_VAL_TE = training_size, validation_size

In [None]:
# Draw train/validation/test splits via DatasetGenerator.make_tuc_logo at PRECISION_PREFIX.
train, val, test = DatasetGenerator.make_tuc_logo(
    key, training_size, validation_size, test_size, dtype=PRECISION_PREFIX
)
# Standard-normal reference sample of shape (validation_size, 2); its key is
# derived with fold_in so it differs from the key given to the generator.
normal_sample = random.normal(
    random.fold_in(key, 42),
    shape=(validation_size, 2),
    dtype=PRECISION_PREFIX,
)

In [None]:
offset = 0.3
# Plotting window shared by every density image: the bounding box of `train`,
# rounded outward to one decimal place and then padded by `offset` on each side.
plt_params = dict(
    X_MIN=np.floor(10 * train[:, 0].min()) / 10 - offset,
    X_MAX=np.ceil(10 * train[:, 0].max()) / 10 + offset,
    Y_MIN=np.floor(10 * train[:, 1].min()) / 10 - offset,
    Y_MAX=np.ceil(10 * train[:, 1].max()) / 10 + offset,
)

## Same MMD, Different NLL

In [None]:
# Solver steps passed to the CNF/OTF density evaluations in this section.
EVAL_SOLVER_STEPS = 40

In [None]:
def eval_val(val_data, params, solver_steps, loss_func=None, loss_kwargs=None):
    """Compute evaluation metrics for a solver-based flow (CNF/OTF) on ``val_data``.

    NOTE(review): this definition is replaced by the RealNVP variant further down
    before it is ever called.

    Args:
        val_data: batch of samples to evaluate on.
        params: model parameters (e.g. a restored checkpoint).
        solver_steps: number of solver steps forwarded to the metric function.
        loss_func: metric function to call. Defaults to the module-level
            ``VAL_LOSS_FUNC``, looked up at call time; that name is only defined
            in a later cell, so it must exist before this is called.
        loss_kwargs: extra keyword arguments for ``loss_func``. Defaults to the
            module-level ``VAL_LOSS_KWARGS``, looked up at call time.

    Returns:
        Whatever ``loss_func`` returns (callers here unpack it as
        ``(loss, inv_error, mmd)``).
    """
    # Resolve the globals lazily. The old call form eval_val(data, params, steps)
    # keeps working, and passing both explicitly removes the hidden dependency.
    if loss_func is None:
        loss_func = VAL_LOSS_FUNC
    if loss_kwargs is None:
        loss_kwargs = VAL_LOSS_KWARGS
    return loss_func(params=params, batch=val_data, **loss_kwargs, solver_steps=solver_steps)

In [None]:
# Checkpoint stem for the CNF panel ('_5VF' matches num_blocks=5 in the model cell).
middle_CNF = "CNF_l2_float32_tuc-logo_5VF"

In [None]:
# Restore the '_bestparams_mmd' CNF parameters saved at step 20000 (no target
# structure is given, so the raw restored state is returned).
params_CNF = checkpoints.restore_checkpoint(
    f'checkpoints/finalfinal/{middle_CNF}_bestparams_mmd/checkpoint_20000/checkpoint',
    None,
)

In [None]:
# CNF on 2-D data. The hyper-parameters must match those of the restored checkpoint
# (num_blocks=5 corresponds to the '_5VF' stem).
# NOTE(review): exact_logp=True presumably selects an exact log-density term — confirm in CNF.
MODEL = CNF.CNF(
    key=key,
    input_dim=2,
    out_dim=2,
    hidden_dim=64,
    depth=3,
    num_blocks=5,
    f_theta_cls=CNF.f_theta,
    exact_logp=True,
    num_steps=20,
)

In [None]:
res = 500
# res x res evaluation grid spanning the plotting window.
xx, yy = np.meshgrid(
    np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res),
    np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res),
)
# Flatten into (res*res, 2) query points: x in column 0, y in column 1.
xy = np.column_stack((xx.ravel(), yy.ravel()))

# Map over the query points only; solver steps, params and the flag are shared.
# NOTE(review): unlike the OTF cell, the output is exponentiated directly (no [0],
# no negation) — presumably this CNF method returns the log-density; confirm.
probs_cnf = jnp.exp(
    vmap(MODEL.log_pdf_and_preimage, in_axes=(0, None, None, None), out_axes=0)(
        xy, EVAL_SOLVER_STEPS, params_CNF, False
    )
)

In [None]:
# Switch JAX to double precision for the float64 RealNVP/OTF checkpoints below.
# Arrays created earlier (e.g. params_CNF, probs_cnf) keep their existing dtype.
config.update('jax_enable_x64', True)

In [None]:
# Regenerate data in double precision from here on.
PRECISION_PREFIX = "float64"

In [None]:
# Re-create the root key from the same seed used for the float32 section.
key = random.PRNGKey(42)

In [None]:
# Same split sizes as the float32 section.
training_size, validation_size, test_size = 10_000, 10_000, 20_000

In [None]:
# Batch sizes equal to the train / validation split sizes.
# NOTE(review): neither constant is referenced by the cells visible here.
BATCH_SIZE_TR, BATCH_SIZE_VAL_TE = training_size, validation_size

In [None]:
# Redraw the splits at the current PRECISION_PREFIX (float64) with the same seed.
train, val, test = DatasetGenerator.make_tuc_logo(
    key, training_size, validation_size, test_size, dtype=PRECISION_PREFIX
)
# Standard-normal reference sample of shape (validation_size, 2); its key is
# derived with fold_in so it differs from the key given to the generator.
normal_sample = random.normal(
    random.fold_in(key, 42),
    shape=(validation_size, 2),
    dtype=PRECISION_PREFIX,
)

In [None]:
# Checkpoint stem for the RealNVP panel ('_16Bl' matches num_blocks=16 in the model cell).
middle_RealVP = "RealNVP_nol2_float64_tuc-logo_16Bl"

In [None]:
# Restore the '_bestparams_mmd' RealNVP parameters saved at step 20000.
params_RealVP = checkpoints.restore_checkpoint(
    f'checkpoints/finalfinal/{middle_RealVP}_bestparams_mmd/checkpoint_20000/checkpoint',
    None,
)

In [None]:
# RealNVP with 16 blocks of RealNVP.NVP_l and a Gaussian prior, no dropout.
# Hyper-parameters must match the restored '_16Bl' checkpoint.
MODEL = RealNVP.RealNVP(
    key=key,
    NVP_net=RealNVP.NVP_l,
    num_blocks=16,
    input_dim=2,
    hidden_dim=64,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None,
)

In [None]:
# Metric function and keyword arguments that eval_val reads from module scope.
VAL_LOSS_FUNC = jit(MODEL.metrics)
# NOTE(review): LOSS_KWARGS is not referenced by the cells visible here.
LOSS_KWARGS = dict()
VAL_LOSS_KWARGS = dict(normal_batch=normal_sample)

In [None]:
def eval_val(val_data, params, solver_steps, loss_func=None, loss_kwargs=None):
    """Compute evaluation metrics for the RealNVP model on ``val_data``.

    RealNVP's metrics are called without a solver-step argument, so
    ``solver_steps`` is accepted only to keep the call signature of the
    solver-based variant of this helper; it is ignored.

    Args:
        val_data: batch of samples to evaluate on.
        params: model parameters (e.g. a restored checkpoint).
        solver_steps: ignored.
        loss_func: metric function to call. Defaults to the module-level
            ``VAL_LOSS_FUNC``, looked up at call time.
        loss_kwargs: extra keyword arguments for ``loss_func``. Defaults to the
            module-level ``VAL_LOSS_KWARGS``, looked up at call time.

    Returns:
        Whatever ``loss_func`` returns (callers here unpack it as
        ``(loss, inv_error, mmd)``).
    """
    # Resolve the globals lazily so the existing call form keeps working.
    if loss_func is None:
        loss_func = VAL_LOSS_FUNC
    if loss_kwargs is None:
        loss_kwargs = VAL_LOSS_KWARGS
    return loss_func(params=params, batch=val_data, **loss_kwargs)

In [None]:
# Test-split metrics of the restored RealNVP (this eval_val ignores the solver steps).
test_loss, test_inv_error, test_mmd = eval_val(test, params_RealVP, EVAL_SOLVER_STEPS)
print('Best Model Loss --- Test Loss: {}, Test Inv Error: {}, Test MMD: {}'.format(
    test_loss, test_inv_error, test_mmd))

In [None]:
res = 500
# Same res x res grid over the plotting window as in the CNF cell.
xx, yy = np.meshgrid(
    np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res),
    np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res),
)
# Flatten into (res*res, 2) query points: x in column 0, y in column 1.
xy = np.column_stack((xx.ravel(), yy.ravel()))

# RealNVP's log_pdf_and_preimage is called on the whole batch directly (no vmap).
probs_realnvp = jnp.exp(MODEL.log_pdf_and_preimage(xy, params_RealVP, False))

In [None]:
# Checkpoint stem for the OTF panel ('_1pot' matches num_blocks=1 in the model cell).
# NOTE(review): 'OFT' looks like a transposition of 'OTF'; kept as-is because it
# presumably matches the checkpoint directory name on disk.
middle_OTF = "OFT_float64_tuc-logo_1pot"

In [None]:
# Restore the '_bestparams_mmd' OTF parameters saved at step 20000.
params_OTF = checkpoints.restore_checkpoint(
    f'checkpoints/finalfinal/{middle_OTF}_bestparams_mmd/checkpoint_20000/checkpoint',
    None,
)

In [None]:
# OTF flow with a single block (num_blocks=1) on 2-D data, time interval t0=0 to t1=1.
# Hyper-parameters must match the restored '_1pot' checkpoint.
MODEL = OTF.OTF(
    key=key,
    phi=OTF.Phi,
    input_dim=2,
    hidden_dim=64,
    resnet_depth=2,
    rank=20,
    num_blocks=1,
    alpha1=15.,
    alpha2=2.,
    t0=0.,
    t1=1.,
    num_steps=20,
)

In [None]:
# Evaluate OTF on the shared grid (vmap over points only). Element [0] of the
# output is negated before exponentiating.
# NOTE(review): presumably [0] is the negative log-density — confirm against OTF.
probs_otf = jnp.exp(
    -(
        vmap(MODEL.log_pdf_and_preimage, in_axes=(0, None, None, None), out_axes=0)(
            xy, EVAL_SOLVER_STEPS, params_OTF, False
        )[0]
    )
)

In [None]:
# Larger fonts for axis labels, titles and tick labels in the comparison figure.
plt.rcParams.update({
    'axes.labelsize': 15,
    'axes.titlesize': 18,
    'xtick.labelsize': 15,
    'ytick.labelsize': 15,
})

In [None]:
# Three-panel comparison of the learned densities on the shared res x res grid.
fig, ax = plt.subplots(1, 3, figsize=(24, 6))

# Image extent (left, right, bottom, top) = the plotting window, shared by all panels.
extent = (plt_params['X_MIN'], plt_params['X_MAX'],
          plt_params['Y_MIN'], plt_params['Y_MAX'])

# The metric values in the titles are hard-coded; keep them in sync with the evaluations.
# The LaTeX backslash is written as "\\cdot" because a bare "\c" is an invalid escape
# sequence in a normal string literal (SyntaxWarning since Python 3.12); "\n" is an
# intentional line break, so the strings cannot simply be made raw.
# NOTE(review): the CNF panel says "3 Vector Fields", but middle_CNF is the '_5VF'
# checkpoint and its model is built with num_blocks=5 — confirm which is correct.
panels = (
    (probs_realnvp, 'RealNVP, 16 Coupling Layers \nTest MMD$=0.095\\cdot 10^{-3}$, Test NLL$=2.400$'),
    (probs_cnf, 'Vanilla CNF + $L_2$, 3 Vector Fields \nTest MMD$=0.106\\cdot 10^{-3}$, Test NLL$=2.121$'),
    (probs_otf, 'OTF, 1 Potential \nTest MMD$=0.092\\cdot 10^{-3}$, Test NLL$=2.263$'),
)
for panel_ax, (density, title) in zip(ax, panels):
    panel_ax.imshow(density.reshape((res, res)), origin='lower', extent=extent, aspect='auto')
    panel_ax.set_title(title)

plt.setp(ax, xlabel='$x_1$ component', ylabel='$x_2$ component');

In [None]:
# Write the comparison figure, cropping surrounding whitespace.
# NOTE(review): 'comparision' is misspelled; kept so existing references to the file still work.
fig.savefig(
    'mmd_comparision_bigger.png',
    bbox_inches='tight',
)

# Same NLL, Different MMD

In [None]:
# Checkpoint stems of the CNFs with 3, 5 and 10 vector fields (no L2, float32).
middle_3, middle_5, middle_10 = (
    'CNF_nol2_float32_tuc-logo_3VF',
    'CNF_nol2_float32_tuc-logo_5VF',
    'CNF_nol2_float32_tuc-logo_10VF',
)

In [None]:
# Restore the '_bestparams_loss' checkpoints (step 20000) in the order 3, 5, 10.
params_3, params_5, params_10 = (
    checkpoints.restore_checkpoint(
        f'checkpoints/finalfinal/{stem}_bestparams_loss/checkpoint_20000/checkpoint',
        None,
    )
    for stem in (middle_3, middle_5, middle_10)
)

In [None]:
def _build_cnf(num_blocks):
    """Build this section's CNF with ``num_blocks`` blocks; all other settings are shared."""
    return CNF.CNF(
        input_dim=2,
        hidden_dim=64,
        out_dim=2,
        depth=3,
        num_blocks=num_blocks,
        key=key,
        f_theta_cls=CNF.f_theta,
        exact_logp=True,
        num_steps=20,
    )


# One model per checkpoint ('_3VF', '_5VF', '_10VF'), built in that order.
MODEL_3, MODEL_5, MODEL_10 = (_build_cnf(n) for n in (3, 5, 10))

In [None]:
# Keyword arguments that eval_val reads from module scope.
# NOTE(review): x64 is still enabled, and `normal_sample` / `test` are the float64
# arrays from the RealNVP section, while these CNF checkpoints are float32 —
# confirm that the mixed precision is intended.
LOSS_KWARGS = dict()
VAL_LOSS_KWARGS = dict(normal_batch=normal_sample)

In [None]:
# This section evaluates with 20 solver steps (the section above used 40).
EVAL_SOLVER_STEPS = 20

In [None]:
def eval_val(val_data, params, solver_steps, loss_func=None, loss_kwargs=None):
    """Compute evaluation metrics for a solver-based flow (CNF) on ``val_data``.

    Args:
        val_data: batch of samples to evaluate on.
        params: model parameters (e.g. a restored checkpoint).
        solver_steps: number of solver steps forwarded to the metric function.
        loss_func: metric function to call. Defaults to the module-level
            ``VAL_LOSS_FUNC``, looked up at call time (it is rebound per model in
            the evaluation cell below).
        loss_kwargs: extra keyword arguments for ``loss_func``. Defaults to the
            module-level ``VAL_LOSS_KWARGS``, looked up at call time.

    Returns:
        Whatever ``loss_func`` returns (callers here unpack it as
        ``(loss, inv_error, mmd)``).
    """
    # Resolve the globals lazily. The old call form eval_val(data, params, steps)
    # keeps working, and passing both explicitly removes the hidden dependency.
    if loss_func is None:
        loss_func = VAL_LOSS_FUNC
    if loss_kwargs is None:
        loss_kwargs = VAL_LOSS_KWARGS
    return loss_func(params=params, batch=val_data, **loss_kwargs, solver_steps=solver_steps)

In [None]:
# Evaluate each CNF on the test split. MODEL / VAL_LOSS_FUNC are rebound on every
# iteration because eval_val reads VAL_LOSS_FUNC from module scope; after the loop
# they (and test_loss, ...) refer to the last model, as before.
for ckpt_name, cnf_model, cnf_params in (
    (middle_3, MODEL_3, params_3),
    (middle_5, MODEL_5, params_5),
    (middle_10, MODEL_10, params_10),
):
    MODEL = cnf_model
    # NOTE(review): LOSS_FUNC is not used by any cell visible here; it is kept so the
    # module-level name still exists after this cell.
    LOSS_FUNC = jit(MODEL._loss)
    VAL_LOSS_FUNC = jit(MODEL.metrics)
    test_loss, test_inv_error, test_mmd = eval_val(test, cnf_params, EVAL_SOLVER_STEPS)
    # Prefix each line with the checkpoint name: the three result lines were
    # otherwise indistinguishable.
    print(f'[{ckpt_name}] Best Model Loss --- Test Loss: {test_loss}, '
          f'Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')