In [None]:
import os
# Expose only CUDA device 0 to this process. This is set before anything from
# JAX is imported (NOTE(review): presumably so the backend only sees that GPU — confirm).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from jax import config
# Leave JAX's 64-bit mode switched off; the data below is generated as 'float32'.
config.update("jax_enable_x64", False)

In [None]:
import OTF, CNF, RealNVP, DatasetGenerator
import numpy as np
import pandas as pd
import optax
import matplotlib.pyplot as plt

from jax import jit, grad, value_and_grad, random
from flax.training import checkpoints

# plotting, move this to a separate notebook
from jax import vmap
import jax.numpy as jnp
from matplotlib import colormaps

In [None]:
# NOTE(review): stray scratch expression — the value is only displayed, never
# assigned, so nothing else in the notebook reads it. Candidate for removal.
jnp.exp(17)

# Generate Dataset

In [None]:
# dtype name handed to the dataset generator and to the Gaussian reference sample.
PRECISION_PREFIX = 'float32'

In [None]:
# Root PRNG key (fixed seed). It is passed to the dataset generator, folded in for
# the Gaussian reference sample, and handed to the CNF constructor.
key = random.PRNGKey(seed=42)

In [None]:
# Sample counts requested for the training, validation and test splits.
training_size, validation_size, test_size = 10_000, 10_000, 20_000

In [None]:
# Batch sizes equal to the full training / validation split sizes.
# Neither constant is referenced again in the visible part of this notebook.
BATCH_SIZE_TR, BATCH_SIZE_VAL_TE = training_size, validation_size

In [None]:
# Draw the three TUC-logo splits from the shared key, in train/val/test order.
split_sizes = (training_size, validation_size, test_size)
train, val, test = DatasetGenerator.make_tuc_logo(key, *split_sizes, dtype=PRECISION_PREFIX)

# 2-D standard-normal reference batch; it uses a key folded from the root key so
# it does not reuse the exact key given to the dataset generator.
normal_key = random.fold_in(key, 42)
normal_sample = random.normal(normal_key, shape=(validation_size, 2), dtype=PRECISION_PREFIX)

# Same NLL, Different MMD

In [None]:
# Run identifier; it forms the middle part of the checkpoint directory names below.
middle = 'CNF_l2_float32_tuc-logo_1VF'

In [None]:
# Restore two parameter sets of the same run: the one stored under
# '_bestparams_loss' and the one stored under '_params'.
ckpt_prefix = f'checkpoints/finalfinal/{middle}'
params_best = checkpoints.restore_checkpoint(
    f'{ckpt_prefix}_bestparams_loss/checkpoint_20000/checkpoint', None
)
params_final = checkpoints.restore_checkpoint(
    f'{ckpt_prefix}_params/checkpoint_20000/checkpoint', None
)

In [None]:
# Architecture and solver settings of the flow, gathered in one mapping so the
# constructor call stays short.
cnf_settings = dict(
    input_dim=2,
    hidden_dim=64,
    out_dim=2,
    depth=3,
    num_blocks=1,
    key=key,
    f_theta_cls=CNF.f_theta,
    exact_logp=True,
    num_steps=20,
)
MODEL = CNF.CNF(**cnf_settings)

# Jit-compiled training objective and validation metrics.
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)

# Extra keyword arguments forwarded to the two callables above.
LOSS_KWARGS = dict()
VAL_LOSS_KWARGS = dict(normal_batch=normal_sample)

In [None]:
# Baseline solver step count for evaluation; later cells also evaluate with
# twice this value to compare discretizations.
EVAL_SOLVER_STEPS = 20

In [None]:
def eval_val(val_data, params, solver_steps):
    """Run the jitted validation metrics on one batch.

    Args:
        val_data: batch forwarded as ``batch``.
        params: model parameters forwarded as ``params``.
        solver_steps: step count forwarded as ``solver_steps``.

    Returns:
        Whatever ``VAL_LOSS_FUNC`` returns; callers in this notebook unpack it
        as three values (loss, inverse error, MMD).
    """
    metrics = VAL_LOSS_FUNC(
        params=params,
        batch=val_data,
        solver_steps=solver_steps,
        **VAL_LOSS_KWARGS,
    )
    return metrics

In [None]:
# Training log of the run named by `middle`. The file name is derived from the
# run identifier so the history and the restored checkpoints cannot drift apart.
train_history = pd.read_csv(f'{middle}.csv', sep=',', index_col=0)

In [None]:
# Font sizes applied to every figure created from here on.
plt.rcParams.update({
    'axes.labelsize': 15,
    'axes.titlesize': 18,
    'xtick.labelsize': 15,
    'ytick.labelsize': 15,
})

In [None]:
# Left panel: training vs. validation NLL. Right panel: validation MMD (log y-axis).
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
nll_ax, mmd_ax = ax
epochs_in_thousands = train_history['Epoch'] / 1000

# The [0:] slices start at the first row; raise the start to crop early epochs.
for column, curve_label in (('Training Loss', 'Training'), ('Validation Loss', 'Validation')):
    nll_ax.plot(epochs_in_thousands[0:], train_history[column][0:], label=curve_label)
nll_ax.set_xlabel('Epoch / 1000')
nll_ax.set_ylabel('NLL')
nll_ax.legend()

mmd_ax.semilogy(epochs_in_thousands, train_history['Validation MMD'])
mmd_ax.set_xlabel('Epoch / 1000')
mmd_ax.set_ylabel('Validation MMD');

In [None]:
# Save the loss/MMD figure created in the previous cell.
fig.savefig('discretization_metrics.png', bbox_inches='tight')

In [None]:
# Test-split metrics of the best-loss parameters at the baseline step count.
best_test_metrics = eval_val(test, params_best, EVAL_SOLVER_STEPS)
test_loss, test_inv_error, test_mmd = best_test_metrics
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Same evaluation with twice the baseline number of solver steps. The count is
# derived from EVAL_SOLVER_STEPS (as in the histogram cells) instead of a literal 40,
# so all "finer discretization" comparisons move together.
test_loss, test_inv_error, test_mmd = eval_val(test, params_best, 2 * EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Map the per-point routine over the rows of `test`; the step count, parameters
# and trailing flag are shared by every row (in_axes None).
per_point_logpdf = vmap(MODEL.log_pdf_and_preimage, in_axes=(0, None, None, None), out_axes=0)
ll, preimage = per_point_logpdf(test, EVAL_SOLVER_STEPS, params_best, True)

In [None]:
# Negative log-likelihood per test point (sign flip of `ll` from the previous cell).
nll = -ll

In [None]:
# Histogram of the per-point test NLL currently stored in `nll`.
fig, ax = plt.subplots()
ax.hist(nll, bins=50, alpha=0.8, edgecolor="black")
ax.set(xlabel="Test NLL", ylabel="Count");

In [None]:
# Per-point evaluation on `test` again, now with twice the baseline step count.
doubled_steps = 2 * EVAL_SOLVER_STEPS
per_point_logpdf = vmap(MODEL.log_pdf_and_preimage, in_axes=(0, None, None, None), out_axes=0)
ll, preimage = per_point_logpdf(test, doubled_steps, params_best, True)

In [None]:
# Negative log-likelihood per test point for the doubled step count computed above.
nll = -ll

In [None]:
# Histogram of whatever per-point NLL is currently held in `nll`.
fig, ax = plt.subplots()
hist_style = dict(bins=50, alpha=0.8, edgecolor="black")
ax.hist(nll, **hist_style)
ax.set(xlabel="Test NLL", ylabel="Count");

In [None]:
# Recompute log-likelihood and preimage at the baseline step count, then flip
# the sign to obtain the per-point NLL.
per_point_logpdf = vmap(MODEL.log_pdf_and_preimage, in_axes=(0, None, None, None), out_axes=0)
ll, preimage = per_point_logpdf(test, EVAL_SOLVER_STEPS, params_best, True)
nll = -ll

In [None]:
# Side-by-side NLL histograms: left uses the `nll` left by the preceding cell
# (baseline step count), right re-evaluates with twice as many solver steps.
# The titles report the mean of the plotted values rather than hard-coded
# numbers, so they stay correct when data, checkpoint or step count change.
fig, ax = plt.subplots(1,2, figsize=(16,6))
ax[0].hist(nll, bins=50, alpha=0.8, edgecolor = "black",)
ax[0].set_xlabel("Test NLL")
ax[0].set_ylabel("Count");
ax[0].set_title(rf"Mean Test NLL$={float(nll.mean()):.3f}$")

ll, preimage = vmap(MODEL.log_pdf_and_preimage,  
                    (0, None, None, None), 0)(test, 2 * EVAL_SOLVER_STEPS, params_best, True)
nll = -ll
ax[1].hist(nll, bins=50, alpha=0.8, edgecolor = "black",)
ax[1].set_xlabel("Test NLL")
ax[1].set_ylabel("Count");
ax[1].set_title(rf"Mean Test NLL$={float(nll.mean()):.3f}$")

In [None]:
# Save the two-panel NLL histogram figure created in the previous cell.
fig.savefig('discretization_hist_bigger.png', bbox_inches='tight')

In [None]:
# Back to the baseline step count: `ll` and `preimage` feed the faulty /
# non-faulty split in the cells below.
per_point_logpdf = vmap(MODEL.log_pdf_and_preimage, in_axes=(0, None, None, None), out_axes=0)
ll, preimage = per_point_logpdf(test, EVAL_SOLVER_STEPS, params_best, True)

In [None]:
# Per-point NLL at the baseline step count; thresholded in the next cells.
nll = -ll

In [None]:
# Split the test points by NLL: below -7 counts as "faulty", the rest as regular.
# NOTE(review): the -7 cut-off is a bare literal — presumably read off the histogram; confirm.
faulty_mask = nll < -7
regular_mask = nll >= -7
faulty = test[faulty_mask]
not_faulty = test[regular_mask]

In [None]:
# Mean NLL over the points that are not flagged as faulty (NLL >= -7); displayed only.
nll[nll >= -7].mean()

In [None]:
# Apply the same NLL threshold to the preimages of the test points.
faulty_mask = nll < -7
regular_mask = nll >= -7
faulty_pre = preimage[faulty_mask]
not_faulty_pre = preimage[regular_mask]

In [None]:
# Re-apply the figure font sizes (same values as set earlier in the notebook).
font_sizes = {
    'axes.labelsize': 15,
    'axes.titlesize': 18,
    'xtick.labelsize': 15,
    'ytick.labelsize': 15,
}
plt.rcParams.update(font_sizes)

In [None]:
# Left: test points; right: their preimages. Regular points are drawn first in
# the default colour, faulty points on top in red.
fig, ax = plt.subplots(1, 2, figsize=(16, 6))

panels = ((not_faulty, faulty), (not_faulty_pre, faulty_pre))
for panel, (regular_pts, faulty_pts) in zip(ax, panels):
    panel.scatter(*regular_pts.T, s=1.)
    panel.scatter(*faulty_pts.T, c='red', s=1.)

plt.setp(ax, xlabel='$x_1$ component', ylabel='$x_2$ component')

In [None]:
# Save the faulty/regular scatter figure created in the previous cell.
fig.savefig('discretization_scatter.png', bbox_inches='tight')

In [None]:
# Plot window: training-data extent snapped outward to one decimal place and
# then widened by `offset` on every side.
offset = 0.3
first_coord, second_coord = train[:, 0], train[:, 1]
plt_params = dict(
    X_MIN=np.floor(first_coord.min() * 10) / 10 - offset,
    X_MAX=np.ceil(first_coord.max() * 10) / 10 + offset,
    Y_MIN=np.floor(second_coord.min() * 10) / 10 - offset,
    Y_MAX=np.ceil(second_coord.max() * 10) / 10 + offset,
)

In [None]:
# Draw the same number of model samples under each restored parameter set
# (best-loss parameters first, then the final parameters).
n_draws = 5000
y_best = MODEL.sample(n_draws, params=params_best)
y_final = MODEL.sample(n_draws, params=params_final)

In [None]:
# Evaluate the model on a res x res grid covering the plot window.
res = 423
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res),
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
# Flatten the grid to one (x, y) row per grid point.
xy = np.column_stack([xx.ravel(), yy.ravel()])

# Use the notebook-wide baseline step count instead of a literal 20 so this map
# is evaluated with the same discretization as the other baseline cells.
# NOTE(review): with the trailing flag False the result is used as a single array;
# the saved figure name suggests it holds log-densities — confirm.
probs_best = (
    vmap(MODEL.log_pdf_and_preimage, (0, None, None, None), 0)(xy, EVAL_SOLVER_STEPS, params_best, False)
)

In [None]:
# Show the grid evaluation as an image spanning the plot window.
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
window = tuple(plt_params[bound] for bound in ('X_MIN', 'X_MAX', 'Y_MIN', 'Y_MAX'))
grid_values = probs_best.reshape((res, res))
density_best = ax.imshow(grid_values, origin='lower', extent=window, aspect='auto')
fig.colorbar(mappable=density_best)
plt.tight_layout();
plt.setp(ax, xlabel='$x_1$ component', ylabel='$x_2$ component')

In [None]:
# Save the density image created in the previous cell.
fig.savefig('discretization_logpdf.png', bbox_inches='tight')

In [None]:
# Build a list of `max_colors` colours by sampling the colormap at evenly spaced
# fractions in [0, 1).
cmap = colormaps['tab20']  #('hsv') #('nipy_spectral')
max_colors = 20
colors = []
for color_number in range(max_colors):
    colors.append(cmap(color_number / max_colors))

In [None]:
# Samples recorded per vector field: a large cloud for the scatter plots and 20
# samples with intermediate states for the trajectory arrows.
stepwise_sample = MODEL.sample_with_steps(1000, params=params_best)
stepwise_sample_interm = MODEL.sample_with_steps(20, params=params_best, intermed_y=True)

# Coarse grid for the vector field. It gets its own resolution name so the `res`
# used by the density-image cell is not overwritten (re-running that cell after
# this one would otherwise reshape with the wrong size).
vf_res = 30
xx, yy = np.meshgrid(np.linspace(-3, 3, vf_res), 
                     np.linspace(-3, 3, vf_res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])

n_vf = len(MODEL.funcs)

# One displacement field per vector field: dt0 times the field evaluated on the
# grid at the start time of that field's equal share of [t0, t1].
dxdy = [MODEL.dt0 * MODEL.funcs[i].apply(params_best[i], t=MODEL.t0 + (MODEL.t1 - MODEL.t0)/ n_vf * i, y=xy) 
        for i in range(n_vf)]

In [None]:
# Three panels: base samples, samples after the first vector field with
# per-sample trajectories, and a (currently empty) vector-field panel.
fig3, ax3 = plt.subplots(1, 3, figsize=(24, 6))
base_ax, flow_ax, field_ax = ax3

base_ax.set_title('Base Distribution')
base_ax.scatter(*stepwise_sample_interm[0].T, c=colors)
base_ax.scatter(*stepwise_sample[0].T, s=0.2)
flow_ax.scatter(*stepwise_sample[1].T, s=0.2)

flow_ax.set_title(f'Transformed by VF {1}')
# One arrow chain per tracked sample: each arrow joins consecutive intermediate states.
for j_sample in range(20):
    path_x, path_y = stepwise_sample_interm[1][j_sample].T
    step_dx = path_x[1:] - path_x[:-1]
    step_dy = path_y[1:] - path_y[:-1]
    flow_ax.quiver(path_x[:-1], path_y[:-1], step_dx, step_dy,
                   scale_units='xy',
                   angles='xy',
                   scale=1.,
                   color=colors[j_sample])

#ax3[2].quiver(*xy.T,*dxdy[0].T)
field_ax.set_title(' Vector Field')
plt.tight_layout()
plt.setp(ax3, xlim=(-2.5, 2.5), ylim=(-2.5,2.5));