# Double Precision (float64)

In [None]:
import os
# Expose only GPU 1 to this process. Set before JAX is imported below so the
# CUDA runtime sees it when JAX initialises its backend.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from jax import config
# JAX defaults to 32-bit types; enable 64-bit so the float64 data and
# parameters requested in this notebook (PRECISION_PREFIX = 'float64') are
# not downcast to float32. This needs to run at startup, before arrays exist.
config.update("jax_enable_x64", True)

In [None]:
import RealNVP, DatasetGenerator
import numpy as np
import pandas as pd
import optax
import matplotlib.pyplot as plt

from jax import jit, grad, value_and_grad, random
from flax.training import checkpoints

# Plotting imports (TODO: move plotting to a separate notebook)
from jax import vmap
import jax.numpy as jnp
from matplotlib import colormaps

In [None]:
@jit
def step(opt_state, params, batch):
    """One optimizer update: gradient of ``LOSS_FUNC`` on ``batch``, then an optax update.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are notebook
    globals read when ``jit`` traces this function, so their values are baked
    into the compiled step. Re-executing this ``def`` after changing them
    creates a new function and therefore a fresh trace.
    Pattern adapted from the optax docs:
    https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns ``(opt_state, params, loss)``; ``loss`` is evaluated at the
    parameters *before* the update.
    """
    loss_and_grad = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grad(params, batch, **LOSS_KWARGS)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
def eval_val(val_data, params):
    """Return ``VAL_LOSS_FUNC`` evaluated for ``params`` on a held-out set.

    Used for both the validation and the test split. Extra keyword arguments
    come from the notebook-level ``VAL_LOSS_KWARGS``.
    """
    return VAL_LOSS_FUNC(batch=val_data, params=params, **VAL_LOSS_KWARGS)

In [None]:
def train_step(train_data, opt_state, params):
    """Apply one jitted ``step`` to the whole training batch.

    Returns ``(opt_state, params, loss)`` exactly as produced by ``step``.
    """
    updated_state, updated_params, batch_loss = step(opt_state, params, batch=train_data)
    return updated_state, updated_params, batch_loss

In [None]:
# Tags used to build report file names. PRECISION_PREFIX is also passed as the
# dtype to the dataset generators.
MODEL_PREFIX, PRECISION_PREFIX = 'RealNVP_l2', 'float64'

In [None]:
def run(train_re, params, opt_state, data_key, *, rolling_window=5, resample_every=25):
    """Train for ``EPOCHS`` epochs, keeping the parameters with the best rolling validation loss.

    Reads these notebook globals: ``EPOCHS``, ``val``, ``test``,
    ``training_size``, ``TRAIN_GENERATOR``, ``TRAIN_GEN_KWARGS``,
    ``PRECISION_PREFIX``, ``report_as_csv`` and ``report_filename``, and calls
    the ``train_step`` and ``eval_val`` helpers. Prints one line per epoch.

    Args:
        train_re: Training batch used for every step until it is resampled.
        params: Initial model parameters.
        opt_state: Initial optax optimizer state.
        data_key: JAX PRNG key from which training-data resampling keys are derived.
        rolling_window: Positive number of most recent validation losses averaged
            to select the best epoch.
        resample_every: Positive interval (in epochs) at which the training
            data is redrawn via ``TRAIN_GENERATOR``.

    Returns:
        ``(train_losses, val_losses, best_epoch, best_val_loss, params, best_params)``.
        ``params`` are the final parameters. ``best_val_loss`` is the
        single-epoch validation loss at ``best_epoch``. If fewer than
        ``rolling_window`` epochs ran, the final parameters are reported as best.
    """
    train_losses, val_losses = [], []
    best_val_loss, best_epoch, best_params = np.inf, None, None
    # Selection criterion, kept separate from best_val_loss. Comparing a rolling
    # mean against a single-epoch loss mixes two quantities, and the lagging
    # mean then rarely "improves", so the best checkpoint stalls.
    best_rolling_loss = np.inf
    if report_as_csv:
        report_df = pd.DataFrame({'Epoch': [], 'Training Loss': [],
                                  'Validation Loss': [], 'Best Epoch': [], 'Best Validation Loss': [],})

    for epoch in range(1, EPOCHS + 1):
        opt_state, params, train_loss = train_step(train_re, opt_state, params)
        train_losses.append(train_loss)

        val_loss = eval_val(val, params)
        print(f'Epoch {epoch} --- Train Loss: {np.mean(train_loss)}, Val Loss: {val_loss}')
        val_losses.append(val_loss)

        if epoch >= rolling_window:
            val_loss_rolling = np.mean(val_losses[-rolling_window:])
            if val_loss_rolling < best_rolling_loss:
                best_rolling_loss = val_loss_rolling
                best_epoch, best_params, best_val_loss = epoch, params, val_loss

        # Resample training data; fold_in derives a distinct key per resampling epoch.
        if epoch % resample_every == 0:
            data_key = random.fold_in(data_key, epoch)
            train_re = TRAIN_GENERATOR(data_key, training_size, dtype=PRECISION_PREFIX, **TRAIN_GEN_KWARGS)

        if report_as_csv:
            report_df.loc[epoch - 1] = [epoch, train_loss, val_loss, best_epoch, best_val_loss]
            if epoch % 10 == 0:
                report_df.to_csv(report_filename)

    if report_as_csv:
        # Flush rows added since the last periodic save.
        report_df.to_csv(report_filename)

    if best_params is None:
        # Fewer epochs than rolling_window: no rolling mean was ever computed,
        # so fall back to the final parameters instead of evaluating None.
        best_epoch = len(val_losses) if val_losses else None
        best_params = params
        best_val_loss = val_losses[-1] if val_losses else np.inf

    print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
    test_loss = eval_val(test, best_params)
    print(f'Best Model --- Test Loss: {test_loss}')

    return train_losses, val_losses, best_epoch, best_val_loss, params, best_params

## TUC Letters (no Spacing)

In [None]:
key = random.PRNGKey(seed=42)

In [None]:
training_size = 10000
validation_size = 10000
test_size = 20000
spacing = 0

In [None]:
BATCH_SIZE_TR = training_size
BATCH_SIZE_VAL_TE = validation_size

In [None]:
train, val, test = DatasetGenerator.make_tuc_letters(
    key, 
    training_size, 
    validation_size, 
    test_size, 
    dtype=PRECISION_PREFIX,
    spacing=spacing
)

In [None]:
DATASET_PREFIX = 'letters-nospacing'

In [None]:
TRAIN_GENERATOR = DatasetGenerator.make_tuc_letters_tr
TRAIN_GEN_KWARGS = {'spacing': spacing}

### alpha = 1e-5

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = RealNVP.RealNVP(
    NVP_net=RealNVP.NVP_l,
    num_blocks=12,
    key=key,
    input_dim=2,
    hidden_dim=32,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None
)
LOSS_KWARGS = {'alpha': 1e-5}
LOSS_FUNC = jit(MODEL._loss_l2)
VAL_LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_KWARGS = {}

In [None]:
@jit
def step(opt_state, params, batch):
    """One optimizer update: gradient of ``LOSS_FUNC`` on ``batch``, then an optax update.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are notebook
    globals read when ``jit`` traces this function, so their values are baked
    into the compiled step. Re-executing this ``def`` after changing them
    creates a new function and therefore a fresh trace.
    Pattern adapted from the optax docs:
    https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns ``(opt_state, params, loss)``; ``loss`` is evaluated at the
    parameters *before* the update.
    """
    loss_and_grad = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grad(params, batch, **LOSS_KWARGS)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=2e-3, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = False
report_filename = filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
train_losses, val_losses, best_epoch, best_val_loss, params, best_params = run(train, params, opt_state, data_key)

In [None]:
print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
test_loss = eval_val(test, best_params)
print(f'Best Model --- Test Loss: {test_loss}')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
plt.legend()

In [None]:
MODEL.key = key

In [None]:
TRAIN_DATA = train
MODEL.params = params

In [None]:
offset = 0.3
plt_params = {
    'X_MIN': np.floor(TRAIN_DATA[:, 0].min() * 10) / 10 - offset,
    'X_MAX': np.ceil(TRAIN_DATA[:, 0].max() * 10) / 10 + offset,
    'Y_MIN': np.floor(TRAIN_DATA[:, 1].min() * 10) / 10 - offset,
    'Y_MAX': np.ceil(TRAIN_DATA[:, 1].max() * 10) / 10 + offset, 
}

In [None]:
# 25L, 512Params, alpha=20
y_final, y_best = MODEL.sample(1000), MODEL.sample(1000, params=best_params)

res = 500
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res), 
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])
probs_final = jnp.exp(MODEL.log_pdf(data=xy))
probs_best = jnp.exp(MODEL.log_pdf(data=xy, params=best_params))

fig, ax = plt.subplots(2,2, figsize=(16,12))
ax[0,0].set_title('Best Model')
ax[0,0].scatter(y_best[:, 0], y_best[:, 1], s=5, color='red', label='samples')
ax[0,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,0].legend()

ax[0,1].set_title('Final Model')
ax[0,1].scatter(y_final[:, 0], y_final[:, 1], s=5, color='red', label='samples')
ax[0,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,1].legend()

ax[1,0].set_title('Best Model')
ax[1,0].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,0].legend()

ax[1,1].set_title('Final Model')
ax[1,1].imshow(probs_final.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,1].legend()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(16,6))
ax[0].scatter(*y_best.T, s=0.1)
ax[1].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')

### alpha = 1e-4

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = RealNVP.RealNVP(
    NVP_net=RealNVP.NVP_l,
    num_blocks=12,
    key=key,
    input_dim=2,
    hidden_dim=32,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None
)

LOSS_FUNC = jit(MODEL._loss_l2)
LOSS_KWARGS = {'alpha': 1e-4}
VAL_LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_KWARGS = {}

In [None]:
@jit
def step(opt_state, params, batch):
    """One optimizer update: gradient of ``LOSS_FUNC`` on ``batch``, then an optax update.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are notebook
    globals read when ``jit`` traces this function, so their values are baked
    into the compiled step. Re-executing this ``def`` after changing them
    creates a new function and therefore a fresh trace.
    Pattern adapted from the optax docs:
    https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns ``(opt_state, params, loss)``; ``loss`` is evaluated at the
    parameters *before* the update.
    """
    loss_and_grad = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grad(params, batch, **LOSS_KWARGS)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=2e-3, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = False
report_filename = filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
train_losses, val_losses, best_epoch, best_val_loss, params, best_params = run(train, params, opt_state, data_key)

In [None]:
print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
test_loss = eval_val(test, best_params)
print(f'Best Model --- Test Loss: {test_loss}')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
plt.legend()

In [None]:
MODEL.key = key

In [None]:
TRAIN_DATA = train
MODEL.params = params

In [None]:
offset = 0.3
plt_params = {
    'X_MIN': np.floor(TRAIN_DATA[:, 0].min() * 10) / 10 - offset,
    'X_MAX': np.ceil(TRAIN_DATA[:, 0].max() * 10) / 10 + offset,
    'Y_MIN': np.floor(TRAIN_DATA[:, 1].min() * 10) / 10 - offset,
    'Y_MAX': np.ceil(TRAIN_DATA[:, 1].max() * 10) / 10 + offset, 
}

In [None]:
# 25L, 512Params, alpha=20
y_final, y_best = MODEL.sample(1000), MODEL.sample(1000, params=best_params)

res = 500
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res), 
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])
probs_final = jnp.exp(MODEL.log_pdf(data=xy))
probs_best = jnp.exp(MODEL.log_pdf(data=xy, params=best_params))

fig, ax = plt.subplots(2,2, figsize=(16,12))
ax[0,0].set_title('Best Model')
ax[0,0].scatter(y_best[:, 0], y_best[:, 1], s=5, color='red', label='samples')
ax[0,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,0].legend()

ax[0,1].set_title('Final Model')
ax[0,1].scatter(y_final[:, 0], y_final[:, 1], s=5, color='red', label='samples')
ax[0,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,1].legend()

ax[1,0].set_title('Best Model')
ax[1,0].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,0].legend()

ax[1,1].set_title('Final Model')
ax[1,1].imshow(probs_final.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,1].legend()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(16,6))
ax[0].scatter(*y_best.T, s=0.1)
ax[1].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')

### alpha = 1e-3

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = RealNVP.RealNVP(
    NVP_net=RealNVP.NVP_l,
    num_blocks=12,
    key=key,
    input_dim=2,
    hidden_dim=32,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None
)
LOSS_FUNC = jit(MODEL._loss_l2)
VAL_LOSS_FUNC = jit(MODEL._loss)
LOSS_KWARGS = {'alpha': 1e-3}
VAL_LOSS_KWARGS = {}

In [None]:
@jit
def step(opt_state, params, batch):
    """One optimizer update: gradient of ``LOSS_FUNC`` on ``batch``, then an optax update.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are notebook
    globals read when ``jit`` traces this function, so their values are baked
    into the compiled step. Re-executing this ``def`` after changing them
    creates a new function and therefore a fresh trace.
    Pattern adapted from the optax docs:
    https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns ``(opt_state, params, loss)``; ``loss`` is evaluated at the
    parameters *before* the update.
    """
    loss_and_grad = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grad(params, batch, **LOSS_KWARGS)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=2e-3, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = False
report_filename = filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
train_losses, val_losses, best_epoch, best_val_loss, params, best_params = run(train, params, opt_state, data_key)

In [None]:
print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
test_loss = eval_val(test, best_params)
print(f'Best Model --- Test Loss: {test_loss}')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
plt.legend()

In [None]:
MODEL.key = key

In [None]:
TRAIN_DATA = train
MODEL.params = params

In [None]:
offset = 0.3
plt_params = {
    'X_MIN': np.floor(TRAIN_DATA[:, 0].min() * 10) / 10 - offset,
    'X_MAX': np.ceil(TRAIN_DATA[:, 0].max() * 10) / 10 + offset,
    'Y_MIN': np.floor(TRAIN_DATA[:, 1].min() * 10) / 10 - offset,
    'Y_MAX': np.ceil(TRAIN_DATA[:, 1].max() * 10) / 10 + offset, 
}

In [None]:
# 25L, 512Params, alpha=20
y_final, y_best = MODEL.sample(1000), MODEL.sample(1000, params=best_params)

res = 500
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res), 
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])
probs_final = jnp.exp(MODEL.log_pdf(data=xy))
probs_best = jnp.exp(MODEL.log_pdf(data=xy, params=best_params))

fig, ax = plt.subplots(2,2, figsize=(16,12))
ax[0,0].set_title('Best Model')
ax[0,0].scatter(y_best[:, 0], y_best[:, 1], s=5, color='red', label='samples')
ax[0,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,0].legend()

ax[0,1].set_title('Final Model')
ax[0,1].scatter(y_final[:, 0], y_final[:, 1], s=5, color='red', label='samples')
ax[0,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,1].legend()

ax[1,0].set_title('Best Model')
ax[1,0].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,0].legend()

ax[1,1].set_title('Final Model')
ax[1,1].imshow(probs_final.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,1].legend()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(16,6))
ax[0].scatter(*y_best.T, s=0.1)
ax[1].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')

### alpha = 1e-2

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = RealNVP.RealNVP(
    NVP_net=RealNVP.NVP_l,
    num_blocks=12,
    key=key,
    input_dim=2,
    hidden_dim=32,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None
)
LOSS_FUNC = jit(MODEL._loss_l2)
VAL_LOSS_FUNC = jit(MODEL._loss)
LOSS_KWARGS = {'alpha': 1e-2}
VAL_LOSS_KWARGS = {}

In [None]:
@jit
def step(opt_state, params, batch):
    """One optimizer update: gradient of ``LOSS_FUNC`` on ``batch``, then an optax update.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are notebook
    globals read when ``jit`` traces this function, so their values are baked
    into the compiled step. Re-executing this ``def`` after changing them
    creates a new function and therefore a fresh trace.
    Pattern adapted from the optax docs:
    https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns ``(opt_state, params, loss)``; ``loss`` is evaluated at the
    parameters *before* the update.
    """
    loss_and_grad = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grad(params, batch, **LOSS_KWARGS)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=2e-3, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = False
report_filename = filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
train_losses, val_losses, best_epoch, best_val_loss, params, best_params = run(train, params, opt_state, data_key)

In [None]:
print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
test_loss = eval_val(test, best_params)
print(f'Best Model --- Test Loss: {test_loss}')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
plt.legend()

In [None]:
MODEL.key = key

In [None]:
TRAIN_DATA = train
MODEL.params = params

In [None]:
offset = 0.3
plt_params = {
    'X_MIN': np.floor(TRAIN_DATA[:, 0].min() * 10) / 10 - offset,
    'X_MAX': np.ceil(TRAIN_DATA[:, 0].max() * 10) / 10 + offset,
    'Y_MIN': np.floor(TRAIN_DATA[:, 1].min() * 10) / 10 - offset,
    'Y_MAX': np.ceil(TRAIN_DATA[:, 1].max() * 10) / 10 + offset, 
}

In [None]:
# 25L, 512Params, alpha=20
y_final, y_best = MODEL.sample(1000), MODEL.sample(1000, params=best_params)

res = 500
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res), 
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])
probs_final = jnp.exp(MODEL.log_pdf(data=xy))
probs_best = jnp.exp(MODEL.log_pdf(data=xy, params=best_params))

fig, ax = plt.subplots(2,2, figsize=(16,12))
ax[0,0].set_title('Best Model')
ax[0,0].scatter(y_best[:, 0], y_best[:, 1], s=5, color='red', label='samples')
ax[0,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,0].legend()

ax[0,1].set_title('Final Model')
ax[0,1].scatter(y_final[:, 0], y_final[:, 1], s=5, color='red', label='samples')
ax[0,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,1].legend()

ax[1,0].set_title('Best Model')
ax[1,0].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,0].legend()

ax[1,1].set_title('Final Model')
ax[1,1].imshow(probs_final.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,1].legend()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(16,6))
ax[0].scatter(*y_best.T, s=0.1)
ax[1].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')

### alpha = 1e-1

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = RealNVP.RealNVP(
    NVP_net=RealNVP.NVP_l,
    num_blocks=12,
    key=key,
    input_dim=2,
    hidden_dim=32,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None
)
LOSS_FUNC = jit(MODEL._loss_l2)
VAL_LOSS_FUNC = jit(MODEL._loss)
LOSS_KWARGS = {'alpha': 1e-1}
VAL_LOSS_KWARGS = {}

In [None]:
@jit
def step(opt_state, params, batch):
    """One optimizer update: gradient of ``LOSS_FUNC`` on ``batch``, then an optax update.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are notebook
    globals read when ``jit`` traces this function, so their values are baked
    into the compiled step. Re-executing this ``def`` after changing them
    creates a new function and therefore a fresh trace.
    Pattern adapted from the optax docs:
    https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns ``(opt_state, params, loss)``; ``loss`` is evaluated at the
    parameters *before* the update.
    """
    loss_and_grad = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grad(params, batch, **LOSS_KWARGS)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=2e-3, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = False
report_filename = filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
train_losses, val_losses, best_epoch, best_val_loss, params, best_params = run(train, params, opt_state, data_key)

In [None]:
print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
test_loss = eval_val(test, best_params)
print(f'Best Model --- Test Loss: {test_loss}')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
plt.legend()

In [None]:
MODEL.key = key

In [None]:
TRAIN_DATA = train
MODEL.params = params

In [None]:
offset = 0.3
plt_params = {
    'X_MIN': np.floor(TRAIN_DATA[:, 0].min() * 10) / 10 - offset,
    'X_MAX': np.ceil(TRAIN_DATA[:, 0].max() * 10) / 10 + offset,
    'Y_MIN': np.floor(TRAIN_DATA[:, 1].min() * 10) / 10 - offset,
    'Y_MAX': np.ceil(TRAIN_DATA[:, 1].max() * 10) / 10 + offset, 
}

In [None]:
# 25L, 512Params, alpha=20
y_final, y_best = MODEL.sample(1000), MODEL.sample(1000, params=best_params)

res = 500
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res), 
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])
probs_final = jnp.exp(MODEL.log_pdf(data=xy))
probs_best = jnp.exp(MODEL.log_pdf(data=xy, params=best_params))

fig, ax = plt.subplots(2,2, figsize=(16,12))
ax[0,0].set_title('Best Model')
ax[0,0].scatter(y_best[:, 0], y_best[:, 1], s=5, color='red', label='samples')
ax[0,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,0].legend()

ax[0,1].set_title('Final Model')
ax[0,1].scatter(y_final[:, 0], y_final[:, 1], s=5, color='red', label='samples')
ax[0,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,1].legend()

ax[1,0].set_title('Best Model')
ax[1,0].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,0].legend()

ax[1,1].set_title('Final Model')
ax[1,1].imshow(probs_final.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,1].legend()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(16,6))
ax[0].scatter(*y_best.T, s=0.1)
ax[1].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')

### alpha = 1

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = RealNVP.RealNVP(
    NVP_net=RealNVP.NVP_l,
    num_blocks=12,
    key=key,
    input_dim=2,
    hidden_dim=32,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None
)
LOSS_FUNC = jit(MODEL._loss_l2)
VAL_LOSS_FUNC = jit(MODEL._loss)
LOSS_KWARGS = {'alpha': 1}
VAL_LOSS_KWARGS = {}

In [None]:
@jit
def step(opt_state, params, batch):
    """One optimizer update: gradient of ``LOSS_FUNC`` on ``batch``, then an optax update.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are notebook
    globals read when ``jit`` traces this function, so their values are baked
    into the compiled step. Re-executing this ``def`` after changing them
    creates a new function and therefore a fresh trace.
    Pattern adapted from the optax docs:
    https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns ``(opt_state, params, loss)``; ``loss`` is evaluated at the
    parameters *before* the update.
    """
    loss_and_grad = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grad(params, batch, **LOSS_KWARGS)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=2e-3, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = False
report_filename = filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
train_losses, val_losses, best_epoch, best_val_loss, params, best_params = run(train, params, opt_state, data_key)

In [None]:
print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
test_loss = eval_val(test, best_params)
print(f'Best Model --- Test Loss: {test_loss}')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
plt.legend()

In [None]:
MODEL.key = key

In [None]:
TRAIN_DATA = train
MODEL.params = params

In [None]:
offset = 0.3
plt_params = {
    'X_MIN': np.floor(TRAIN_DATA[:, 0].min() * 10) / 10 - offset,
    'X_MAX': np.ceil(TRAIN_DATA[:, 0].max() * 10) / 10 + offset,
    'Y_MIN': np.floor(TRAIN_DATA[:, 1].min() * 10) / 10 - offset,
    'Y_MAX': np.ceil(TRAIN_DATA[:, 1].max() * 10) / 10 + offset, 
}

In [None]:
# 25L, 512Params, alpha=20
y_final, y_best = MODEL.sample(1000), MODEL.sample(1000, params=best_params)

res = 500
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res), 
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
xy = np.hstack([e.reshape(-1, 1) for e in [xx, yy]])
probs_final = jnp.exp(MODEL.log_pdf(data=xy))
probs_best = jnp.exp(MODEL.log_pdf(data=xy, params=best_params))

fig, ax = plt.subplots(2,2, figsize=(16,12))
ax[0,0].set_title('Best Model')
ax[0,0].scatter(y_best[:, 0], y_best[:, 1], s=5, color='red', label='samples')
ax[0,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,0].legend()

ax[0,1].set_title('Final Model')
ax[0,1].scatter(y_final[:, 0], y_final[:, 1], s=5, color='red', label='samples')
ax[0,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
ax[0,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[0,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[0,1].legend()

ax[1,0].set_title('Best Model')
ax[1,0].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,0].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,0].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,0].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,0].legend()

ax[1,1].set_title('Final Model')
ax[1,1].imshow(probs_final.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')
ax[1,1].scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='samples')
ax[1,1].set_xlim((plt_params['X_MIN'], plt_params['X_MAX']))
ax[1,1].set_ylim((plt_params['Y_MIN'], plt_params['Y_MAX']))
ax[1,1].legend()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(16,6))
ax[0].scatter(*y_best.T, s=0.1)
ax[1].imshow(probs_best.reshape((res, res)),
           origin='lower', extent=(plt_params['X_MIN'], plt_params['X_MAX'], 
                                   plt_params['Y_MIN'], plt_params['Y_MAX']), aspect='auto')

### alpha = 10

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = RealNVP.RealNVP(
    NVP_net=RealNVP.NVP_l,
    num_blocks=12,
    key=key,
    input_dim=2,
    hidden_dim=32,
    prior_type='gaussian',
    prior_args=None,
    use_dropout=False,
    dropout_proba=None
)
LOSS_FUNC = jit(MODEL._loss_l2)
VAL_LOSS_FUNC = jit(MODEL._loss)
LOSS_KWARGS = {'alpha': 10}
VAL_LOSS_KWARGS = {}

In [None]:
@jit
def step(opt_state, params, batch):
    """Run one optimizer update on ``batch``.

    ``LOSS_FUNC``, ``LOSS_KWARGS`` and ``gradient_transform`` are read as
    globals when jit traces this function; presumably that is why this cell is
    re-executed after they are redefined (a fresh jit wrapper re-traces).
    Pattern follows https://optax.readthedocs.io/en/latest/gradient_accumulation.html

    Returns:
        (new_opt_state, new_params, loss) where ``loss`` is evaluated at the
        parameters *before* the update.
    """
    loss_and_grad_fn = value_and_grad(LOSS_FUNC)
    loss_value, gradients = loss_and_grad_fn(params, batch, **LOSS_KWARGS)

    updates, new_opt_state = gradient_transform.update(gradients, opt_state, params=params)
    new_params = optax.apply_updates(params, updates)
    return new_opt_state, new_params, loss_value

In [None]:
# Learning-rate schedule: starts at 2e-3 and is multiplied by 0.6 for every
# 1000 optimizer steps, beginning at step 0.
scheduler = optax.exponential_decay(
    init_value=2e-3,
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6,
)

# Hand-assembled Adam: Adam moment scaling, then the scheduled learning rate,
# then a sign flip because optax.apply_updates adds updates to the params and
# we want gradient *descent*.
gradient_transform = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)

In [None]:
# Start training from the parameters currently held by MODEL and build the
# matching optimizer state for gradient_transform.
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
# run() checks report_as_csv before building its per-epoch report DataFrame;
# it is off for this sweep point. The file name (also bound as `filestr`)
# joins the model, precision, dataset and NPOT prefixes.
report_as_csv = False
filestr = report_filename = '_'.join(
    (MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX)
) + '.csv'

In [None]:
# Train for EPOCHS epochs; `params` becomes the final-epoch parameters and
# `best_params` those with the best validation score seen by run().
(train_losses, val_losses,
 best_epoch, best_val_loss,
 params, best_params) = run(train_re=train, params=params, opt_state=opt_state, data_key=data_key)

In [None]:
# Report the best-validation epoch, then score those parameters on the
# held-out test split with the validation loss function.
print(f'\nBest Epoch: {best_epoch} --- Val Loss: {best_val_loss}')
test_loss = eval_val(val_data=test, params=best_params)
print(f'Best Model --- Test Loss: {test_loss}')

In [None]:
# Loss curves on a log scale: the full history (left) and the history after
# the first `loss_zoom_start` epochs (right). Both panels use real epoch
# numbers on the x-axis (run() counts epochs from 1), so the zoomed panel no
# longer restarts at 0.
loss_zoom_start = 1000
loss_epochs = np.arange(1, len(train_losses) + 1)

fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(loss_epochs, train_losses, label='train loss')
ax[0].semilogy(loss_epochs, val_losses, label='validation loss')
ax[1].semilogy(loss_epochs[loss_zoom_start:], train_losses[loss_zoom_start:], label='train loss')
ax[1].semilogy(loss_epochs[loss_zoom_start:], val_losses[loss_zoom_start:], label='validation loss')
ax[0].set_xlabel('epoch')
ax[1].set_xlabel('epoch')
# plt.legend() would only decorate the current axes (ax[1]); add a legend
# to each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
# Re-assign the model's PRNG key to `key`, the same key passed to the
# RealNVP constructor above. NOTE(review): presumably this resets the key
# consumed by MODEL.sample so the plots below are reproducible — confirm in
# RealNVP.
MODEL.key = key

In [None]:
# Install the final-epoch parameters on the model (used by calls that do not
# pass params= explicitly) and alias the training split for the plots below.
MODEL.params = params
TRAIN_DATA = train

In [None]:
# Plot window: the training-data bounding box rounded outward to one decimal
# place, then padded by `offset` on every side. Keys are produced in the
# order X_MIN, X_MAX, Y_MIN, Y_MAX.
offset = 0.3
plt_params = {
    f'{axis_name}_{bound}': (
        np.floor(TRAIN_DATA[:, col].min() * 10) / 10 - offset
        if bound == 'MIN'
        else np.ceil(TRAIN_DATA[:, col].max() * 10) / 10 + offset
    )
    for col, axis_name in enumerate(('X', 'Y'))
    for bound in ('MIN', 'MAX')
}

In [None]:
# Compare the final-epoch model (MODEL.params) with the best-validation
# parameters. This run: 12 coupling blocks, hidden_dim=32, alpha=10.
y_final, y_best = MODEL.sample(1000), MODEL.sample(1000, params=best_params)

# Evaluate both densities on a res x res grid covering the plot window.
# Grid points are flattened row-major, so each block of `res` consecutive
# points shares one y value.
res = 500
xx, yy = np.meshgrid(np.linspace(plt_params['X_MIN'], plt_params['X_MAX'], res), 
                     np.linspace(plt_params['Y_MIN'], plt_params['Y_MAX'], res))
xy = np.column_stack([xx.ravel(), yy.ravel()])
probs_final = jnp.exp(MODEL.log_pdf(data=xy))
probs_best = jnp.exp(MODEL.log_pdf(data=xy, params=best_params))


def _plot_model_panels(ax_samples, ax_density, title, samples, probs):
    """Draw one model's sample scatter and density heat map.

    Args:
        ax_samples: axes for model samples overlaid on the training data.
        ax_density: axes for the density image with the training data on top.
        title: panel title shared by both axes.
        samples: model samples; columns 0 and 1 are plotted as x and y.
        probs: flat density values on the res x res grid built above.
    """
    x_lim = (plt_params['X_MIN'], plt_params['X_MAX'])
    y_lim = (plt_params['Y_MIN'], plt_params['Y_MAX'])

    ax_samples.set_title(title)
    ax_samples.scatter(samples[:, 0], samples[:, 1], s=5, color='red', label='samples')
    ax_samples.scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=1, color='blue', label='training data')
    ax_samples.set_xlim(x_lim)
    ax_samples.set_ylim(y_lim)
    ax_samples.legend()

    ax_density.set_title(title)
    # origin='lower' puts the first grid row (smallest y) at the bottom.
    ax_density.imshow(probs.reshape((res, res)),
                      origin='lower', extent=x_lim + y_lim, aspect='auto')
    # The overlaid points are the training data, not model samples.
    ax_density.scatter(TRAIN_DATA[:, 0], TRAIN_DATA[:, 1], s=0.1, color='red', label='training data')
    ax_density.set_xlim(x_lim)
    ax_density.set_ylim(y_lim)
    ax_density.legend()


fig, ax = plt.subplots(2,2, figsize=(16,12))
# Left column: best-validation parameters; right column: final-epoch parameters.
_plot_model_panels(ax[0,0], ax[1,0], 'Best Model', y_best, probs_best)
_plot_model_panels(ax[0,1], ax[1,1], 'Final Model', y_final, probs_final)

In [None]:
# Side-by-side view of the best-validation model: its samples (left) and its
# density evaluated on the plotting grid (right).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))
ax[0].scatter(y_best[:, 0], y_best[:, 1], s=0.1)
ax[1].imshow(
    probs_best.reshape((res, res)),
    origin='lower',
    extent=(plt_params['X_MIN'], plt_params['X_MAX'],
            plt_params['Y_MIN'], plt_params['Y_MAX']),
    aspect='auto',
)