Single Precision

In [None]:
import os
# Expose only the GPU with index 1 to this process; set before jax is imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from jax import config
# Keep 64-bit types disabled so the notebook runs in single precision.
config.update("jax_enable_x64", False)

In [None]:
import OTF, DatasetGenerator
import numpy as np
import pandas as pd
import optax
import matplotlib.pyplot as plt

from jax import jit, grad, value_and_grad, random
from flax.training import checkpoints

# Plotting helpers; move these to a separate notebook.
from jax import vmap
import jax.numpy as jnp
from matplotlib import colormaps

In [None]:
@jit
def step(opt_state, params, batch, solver_steps):
    """Apply a single optimizer update computed from ``batch``.

    The loss and its gradient with respect to ``params`` come from the
    module-level ``LOSS_FUNC`` (called with ``LOSS_KWARGS`` and
    ``solver_steps``); the gradients are passed through the module-level
    ``gradient_transform`` and the resulting updates are added to ``params``.
    The update pattern follows the optax documentation.

    NOTE(review): this function is jit-compiled and reads ``LOSS_FUNC``,
    ``LOSS_KWARGS`` and ``gradient_transform`` as globals, which later cells
    rebind. A compiled trace keeps the values seen when it was traced, so a
    rebinding is only picked up when a call triggers a retrace -- confirm
    this is acceptable for experiments that reuse an identical model
    configuration.

    Returns:
        Tuple ``(opt_state, params, loss)`` with the updated optimizer state,
        the updated parameters and the loss evaluated before the update.
    """
    loss_and_grads = value_and_grad(LOSS_FUNC)
    loss, grads = loss_and_grads(params, batch, **LOSS_KWARGS, solver_steps=solver_steps)
    updates, next_opt_state = gradient_transform.update(grads, opt_state, params=params)
    next_params = optax.apply_updates(params, updates)
    return next_opt_state, next_params, loss

In [None]:
def eval_val(val_data, params, solver_steps):
    """Evaluate the module-level ``VAL_LOSS_FUNC`` on ``val_data``.

    ``VAL_LOSS_KWARGS`` is forwarded unchanged together with ``solver_steps``.
    Callers in this file unpack the result as ``(loss, inverse error, MMD)``.
    """
    metrics = VAL_LOSS_FUNC(
        batch=val_data,
        params=params,
        solver_steps=solver_steps,
        **VAL_LOSS_KWARGS,
    )
    return metrics

In [None]:
def train_step(train_data, opt_state, params, solver_steps):
    """Run one ``step`` update using ``train_data`` as the batch.

    Returns:
        Tuple ``(opt_state, params, loss)`` exactly as produced by ``step``.
    """
    return step(opt_state, params, batch=train_data, solver_steps=solver_steps)

In [None]:
# Filename components for the CSV reports and checkpoint directories.
# NOTE(review): 'OFT' looks like a transposition of 'OTF' (the imported module
# name). It is baked into every report/checkpoint filename, so confirm that
# nothing downstream depends on the current spelling before renaming.
MODEL_PREFIX = 'OFT'
# Also passed as the dtype to the dataset generators and to random.normal.
PRECISION_PREFIX = 'float32'

In [None]:
# Solver step counts handed to the loss during training (SOLVER_STEPS) and to
# the metrics during validation/test evaluation (EVAL_SOLVER_STEPS).
SOLVER_STEPS = EVAL_SOLVER_STEPS = 20

In [None]:
def run(train_re, params, opt_state, data_key):
    """Train for ``EPOCHS`` epochs, track the best parameters and test them.

    Reads module-level configuration: ``EPOCHS``, ``SOLVER_STEPS``,
    ``EVAL_SOLVER_STEPS``, ``val``, ``test``, ``training_size``,
    ``PRECISION_PREFIX``, ``TRAIN_GENERATOR``, ``TRAIN_GEN_KWARGS``,
    ``report_as_csv`` and ``report_filename``.

    Model selection uses a 5-epoch rolling mean of the validation loss and of
    the validation MMD. A new best is accepted when the rolling mean drops
    below the best rolling mean seen so far; the reported ``best_val_*`` value
    is the metric of the selected epoch itself. (Previously the rolling mean
    was compared against that single-epoch value, mixing smoothed and raw
    numbers in one comparison.)

    Args:
        train_re: Initial training set; redrawn every 25 epochs via
            ``TRAIN_GENERATOR``.
        params: Initial model parameters.
        opt_state: Initial optimizer state matching ``params``.
        data_key: PRNG key from which the resampling keys are derived.

    Returns:
        Tuple ``(train_losses, val_losses, val_mmds, val_inv_errors,
        best_epoch_mmd, best_val_mmd, best_params_mmd,
        best_epoch_loss, best_val_loss, best_params_loss, params)``.
        The ``best_epoch_*`` / ``best_params_*`` entries are ``None`` when no
        epoch was ever selected (fewer than 5 epochs, or non-finite metrics).
    """
    rolling_window = 5   # epochs averaged when judging validation improvement
    resample_every = 25  # epochs between redraws of the training set
    report_every = 10    # epochs between CSV flushes

    train_losses = []
    val_losses, val_inv_errors, val_mmds = [], [], []
    best_val_mmd, best_epoch_mmd, best_params_mmd = np.inf, None, None
    best_val_loss, best_epoch_loss, best_params_loss = np.inf, None, None
    # Best rolling means seen so far; these drive the selection.
    best_rolling_loss, best_rolling_mmd = np.inf, np.inf

    if report_as_csv:
        report_df = pd.DataFrame({'Epoch': [],
                                  'Training Loss': [],
                                  'Validation Loss': [],
                                  'Validation Inverse Error': [],
                                  'Validation MMD': [],
                                  'Best Epoch Loss': [],
                                  'Best Epoch MMD': [],
                                  'Best Validation Loss': [],
                                  'Best Validation MMD': [],})

    for epoch in range(1, EPOCHS + 1):
        opt_state, params, train_loss = train_step(train_re, opt_state, params, SOLVER_STEPS)
        train_losses.append(train_loss)

        val_loss, val_inv_error, val_mmd = eval_val(val, params, EVAL_SOLVER_STEPS)
        print(f'Epoch {epoch} --- Train Loss: {np.mean(train_loss)}, Val Loss: {val_loss}, Val Inv Error: {val_inv_error}, Val MMD: {val_mmd}')
        val_losses.append(val_loss)
        val_inv_errors.append(val_inv_error)
        val_mmds.append(val_mmd)

        if epoch >= rolling_window:
            # Compare smoothed value against smoothed value.
            val_loss_rolling = np.mean(val_losses[-rolling_window:])
            if val_loss_rolling < best_rolling_loss:
                best_rolling_loss = val_loss_rolling
                best_epoch_loss, best_params_loss, best_val_loss = epoch, params, val_loss

            val_mmd_rolling = np.mean(val_mmds[-rolling_window:])
            if val_mmd_rolling < best_rolling_mmd:
                best_rolling_mmd = val_mmd_rolling
                best_epoch_mmd, best_params_mmd, best_val_mmd = epoch, params, val_mmd

        # Resample the training data.
        if epoch % resample_every == 0:
            data_key = random.fold_in(data_key, epoch)
            train_re = TRAIN_GENERATOR(data_key, training_size, dtype=PRECISION_PREFIX, **TRAIN_GEN_KWARGS)

        if report_as_csv:
            report_df.loc[epoch - 1] = [epoch, train_loss,
                                        val_loss, val_inv_error, val_mmd,
                                        best_epoch_loss, best_epoch_mmd, best_val_loss, best_val_mmd]
            if epoch % report_every == 0:
                report_df.to_csv(report_filename)

    # Flush the rows written since the last periodic save so the report is
    # complete even when EPOCHS is not a multiple of the flush interval.
    if report_as_csv and EPOCHS % report_every != 0:
        report_df.to_csv(report_filename)

    print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
    if best_params_mmd is None:
        # Nothing was selected; evaluating ``None`` parameters would fail.
        print('Best Model MMD --- skipped: no parameters were selected')
    else:
        test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
        print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

    print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
    if best_params_loss is None:
        print('Best Model Loss --- skipped: no parameters were selected')
    else:
        test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
        print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

    ret = (train_losses, val_losses, val_mmds, val_inv_errors,
           best_epoch_mmd, best_val_mmd, best_params_mmd,
           best_epoch_loss, best_val_loss, best_params_loss, params)

    return ret

# TUC Letters (no Spacing)

In [None]:
# Base PRNG key (seed 42) for generating this dataset.
key = random.PRNGKey(42)

In [None]:
# Number of samples per split, and the spacing passed to the TUC-letters
# generators (0 for this "no spacing" variant).
training_size, validation_size, test_size = 10000, 10000, 20000
spacing = 0

In [None]:
# Batch sizes equal the full training and validation set sizes.
BATCH_SIZE_TR, BATCH_SIZE_VAL_TE = training_size, validation_size

In [None]:
# Generate the train/validation/test splits of the TUC-letters dataset with
# `key`, the configured split sizes, dtype and spacing.
train, val, test = DatasetGenerator.make_tuc_letters(
    key, 
    training_size, 
    validation_size, 
    test_size, 
    dtype=PRECISION_PREFIX,
    spacing=spacing
)
# Standard-normal sample of shape (validation_size, 2), drawn from a key
# derived with fold_in; later passed to the validation metrics as `normal_batch`.
normal_sample = random.normal(random.fold_in(key, 42), shape=(validation_size, 2), dtype=PRECISION_PREFIX)

In [None]:
# Dataset tag used in the report and checkpoint filenames.
DATASET_PREFIX = 'letters-nospacing'

In [None]:
# Generator and extra keyword arguments that run() uses to redraw the
# training set during training.
TRAIN_GENERATOR = DatasetGenerator.make_tuc_letters_tr
TRAIN_GEN_KWARGS = {'spacing': spacing}

## 1 Potential

In [None]:
# Tag for the number of potentials; used in report and checkpoint filenames.
NPOT_PREFIX = '1pot'

In [None]:
# Split a fresh seed-42 key into the key that drives training-set resampling
# (data_key) and the key used to initialise the model (key).
data_key, key = random.split(random.PRNGKey(42))

In [None]:
# Training length and the OTF model for this experiment
# (num_blocks=1, matching the '1pot' tag).
EPOCHS = 20000
MODEL = OTF.OTF(
    key=key,
    phi=OTF.Phi,
    input_dim=2,
    hidden_dim=32,
    resnet_depth=2,
    rank=10,
    num_blocks=1,
    alpha1=15.,
    alpha2=2.,
    t0=0.,
    t1=1.,
    num_steps=20,
)

# Jit-compiled training loss and validation metrics of this model, plus the
# extra keyword arguments step() and eval_val() forward to them.
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)
LOSS_KWARGS, VAL_LOSS_KWARGS = {}, {'normal_batch': normal_sample}

In [None]:
# Exponential decay of the learning rate.
# Learning-rate schedule: starts at 5e-2 and decays by a factor of 0.6 per
# 1000 update steps, beginning at step 0.
scheduler = optax.exponential_decay(
    init_value=5e-2,
    decay_rate=0.6,
    transition_steps=1000,
    transition_begin=0,
)

# Adam-scaled gradients, multiplied by the scheduled learning rate and then
# negated: optax.apply_updates adds the updates, and we want to descend.
gradient_transform = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)

In [None]:
# Start from the model's own parameters and build the matching optimizer state.
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
# Have run() write a per-epoch CSV report; the filename encodes model,
# precision, dataset and potential count.
report_as_csv = True
report_filename = f'{MODEL_PREFIX}_{PRECISION_PREFIX}_{DATASET_PREFIX}_{NPOT_PREFIX}.csv'

In [None]:
# Train the model, then unpack the 11-element result tuple of run():
# loss/metric histories, the MMD-selected best, the loss-selected best and
# the final parameters.
ret = run(train, params, opt_state, data_key)
train_losses, val_losses, val_mmds, val_inv_errors = ret[:4]
best_epoch_mmd, best_val_mmd, best_params_mmd = ret[4:7]
best_epoch_loss, best_val_loss, best_params_loss, params = ret[7:]

In [None]:
# Evaluate both selected parameter sets on the test split.
# NOTE: run() prints the same two summaries at the end of training; this cell
# recomputes them from the returned values.
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
# Common stem for this experiment's checkpoint directories.
filestr = f'{MODEL_PREFIX}_{PRECISION_PREFIX}_{DATASET_PREFIX}_{NPOT_PREFIX}'

# Save the parameters from the end of training. The step label is hard-coded
# and matches EPOCHS = 20000 set for this experiment.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',
    target=params,
    step=20000,
    overwrite=True,
)

In [None]:
# Save the parameters selected by the validation-loss criterion in run().
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',
    target=best_params_loss,
    step=20000,
    overwrite=True,
)

In [None]:
# Save the parameters selected by the validation-MMD criterion in run().
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',
    target=best_params_mmd,
    step=20000,
    overwrite=True,
)

## 3 Potential

In [None]:
NPOT_PREFIX = '3pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=3,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',  
    target=params,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',  
    target=best_params_loss,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',  
    target=best_params_mmd,  
    overwrite=True,   
    step=20000
)

## 5 Potential

In [None]:
NPOT_PREFIX = '5pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=5,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',  
    target=params,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',  
    target=best_params_loss,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',  
    target=best_params_mmd,  
    overwrite=True,   
    step=20000
)

# TUC Letters (with Spacing)

In [None]:
key = random.PRNGKey(seed=42)

In [None]:
training_size = 10000
validation_size = 10000
test_size = 20000
spacing = 40

In [None]:
BATCH_SIZE_TR = training_size
BATCH_SIZE_VAL_TE = validation_size

In [None]:
train, val, test = DatasetGenerator.make_tuc_letters(
    key, 
    training_size, 
    validation_size, 
    test_size, 
    dtype=PRECISION_PREFIX,
    spacing=spacing
)
normal_sample = random.normal(random.fold_in(key, 42), shape=(validation_size, 2), dtype=PRECISION_PREFIX)

In [None]:
DATASET_PREFIX = 'letters-spacing'

In [None]:
TRAIN_GENERATOR = DatasetGenerator.make_tuc_letters_tr
TRAIN_GEN_KWARGS = {'spacing': spacing}

## 1 Potential

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=1,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {} 
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',  
    target=params,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',  
    target=best_params_loss,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',  
    target=best_params_mmd,  
    overwrite=True,   
    step=20000
)

## 3 Potential

In [None]:
NPOT_PREFIX = '3pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=3,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {} 
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',  
    target=params,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',  
    target=best_params_loss,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',  
    target=best_params_mmd,  
    overwrite=True,   
    step=20000
)

## 5 Potential

In [None]:
NPOT_PREFIX = '5pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=5,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {} 
VAL_LOSS_KWARGS = {'normal_batch': normal_sample} 

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',  
    target=params,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',  
    target=best_params_loss,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',  
    target=best_params_mmd,  
    overwrite=True,   
    step=20000
)

# Checkerboard

In [None]:
key = random.PRNGKey(seed=42)

In [None]:
training_size = 10000
validation_size = 10000
test_size = 20000

In [None]:
BATCH_SIZE_TR = training_size
BATCH_SIZE_VAL_TE = validation_size

In [None]:
train, val, test = DatasetGenerator.make_checkerboard(
    key, 
    training_size, 
    validation_size, 
    test_size, 
    dtype=PRECISION_PREFIX,
)
normal_sample = random.normal(random.fold_in(key, 42), shape=(validation_size, 2), dtype=PRECISION_PREFIX)

In [None]:
DATASET_PREFIX = 'checkerboards'

In [None]:
TRAIN_GENERATOR = DatasetGenerator.make_checkerboard_tr
TRAIN_GEN_KWARGS = {}

## 1 Potential

In [None]:
NPOT_PREFIX = '1pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=1,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {} 
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',  
    target=params,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',  
    target=best_params_loss,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',  
    target=best_params_mmd,  
    overwrite=True,   
    step=20000
)

## 3 Potential

In [None]:
NPOT_PREFIX = '3pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=3,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {} 
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_mmd, EVAL_SOLVER_STEPS)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(test, best_params_loss, EVAL_SOLVER_STEPS)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Training vs. validation loss on log-scaled y axes: the full history (left)
# and the history with the first 1000 epochs dropped (right).
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), so request
# a legend on each panel explicitly.
ax[0].legend()
ax[1].legend()

In [None]:
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',  
    target=params,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',  
    target=best_params_loss,  
    overwrite=True,   
    step=20000
)

In [None]:
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',  
    target=best_params_mmd,  
    overwrite=True,   
    step=20000
)

## 5 Potential

In [None]:
NPOT_PREFIX = '5pot'

In [None]:
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
EPOCHS = 20000
MODEL = OTF.OTF(
        input_dim=2, 
        hidden_dim=32,
        resnet_depth=2,
        rank=10,
        key=key, 
        phi=OTF.Phi,
        alpha1=15.,
        alpha2=2.,
        num_blocks=5,
        t0=0.,
        t1=1.,
        num_steps=20
)
LOSS_FUNC = jit(MODEL._loss)
VAL_LOSS_FUNC = jit(MODEL.metrics)
LOSS_KWARGS = {} 
VAL_LOSS_KWARGS = {'normal_batch': normal_sample} 

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=5e-2, 
    transition_steps=1000,
    transition_begin=0,
    decay_rate=0.6)

gradient_transform = optax.chain(
    optax.scale_by_adam(),              # Use the updates from adam.
    optax.scale_by_schedule(scheduler), # Adapt LR
    optax.scale(-1.0)                   # Scale updates by -1 since optax.apply_updates 
                                        # is additive and we want to descend on the loss.
)

In [None]:
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
# Test-set report for both selected parameter sets. eval_val forwards to
# VAL_LOSS_FUNC; its result is unpacked as (loss, inverse error, MMD).

# Parameters selected on validation MMD.
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_mmd,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

# Parameters selected on validation loss.
print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_loss,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Loss curves on a log scale: the full history (left) and the history with the
# first 1000 recorded entries dropped (right), so the initial transient does
# not dominate the y-range.
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), which left
# the first panel without a legend; label each panel explicitly instead.
ax[0].legend()
ax[1].legend()

In [None]:
# Persist the parameters returned at the end of training. The directory name
# encodes model, precision, dataset and number of potentials.
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',
    target=params,
    overwrite=True,
    # Label the checkpoint with the configured epoch count rather than a
    # repeated literal, so it stays in sync if EPOCHS is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_loss.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',
    target=best_params_loss,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_mmd.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',
    target=best_params_mmd,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

# TUC Logo

In [None]:
# Seed for generating the TUC-logo dataset and the Gaussian reference sample.
key = random.PRNGKey(seed=42)

In [None]:
# Number of samples per split, passed to DatasetGenerator.make_tuc_logo below.
training_size = 10000
validation_size = 10000
test_size = 20000
# NOTE(review): spacing is not used in the cells visible here — presumably a
# plotting/grid parameter read elsewhere; confirm.
spacing = 40

In [None]:
# Batch sizes equal the full training / validation split sizes
# (presumably read by run(); confirm where these are consumed).
BATCH_SIZE_TR = training_size
BATCH_SIZE_VAL_TE = validation_size

In [None]:
# Draw the three TUC-logo splits in the working precision.
train, val, test = DatasetGenerator.make_tuc_logo(
    key, training_size, validation_size, test_size, dtype=PRECISION_PREFIX
)

# Gaussian reference batch handed to the metrics through VAL_LOSS_KWARGS. Its
# key is derived with fold_in, so it differs from the key given to the dataset
# generator.
# NOTE(review): the batch has validation_size rows but is also used when
# evaluating `test` (test_size rows) — confirm MODEL.metrics accepts batches of
# unequal size.
normal_sample = random.normal(
    random.fold_in(key, 42),
    shape=(validation_size, 2),
    dtype=PRECISION_PREFIX,
)

In [None]:
# Dataset tag; it becomes part of the report and checkpoint file names.
DATASET_PREFIX = 'tuc-logo'

In [None]:
# Generator for training samples and its keyword arguments
# (presumably called inside run() to draw training batches — confirm).
TRAIN_GENERATOR = DatasetGenerator.make_tuc_logo_tr
TRAIN_GEN_KWARGS = {}

## 1 Potential

In [None]:
# Tag for this configuration (num_blocks=1 below); it becomes part of the
# report and checkpoint file names.
NPOT_PREFIX = '1pot'

In [None]:
# Fixed seed for reproducibility. The split yields data_key (handed to run())
# and a fresh key that is used to construct the model.
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
# Training length for this configuration (presumably read by run() as a global).
EPOCHS = 20000

# OTF model with a single potential block.
MODEL = OTF.OTF(
    input_dim=2,
    hidden_dim=64,
    rank=20,
    resnet_depth=2,
    num_blocks=1,
    phi=OTF.Phi,
    alpha1=15.,
    alpha2=2.,
    t0=0.,
    t1=1.,
    num_steps=20,
    key=key,
)

# Jitted training objective and evaluation metrics; step() and eval_val() look
# these up as globals.
# NOTE(review): step() is itself @jit-compiled and resolves LOSS_FUNC while
# tracing, so a trace cached from an earlier section would be reused if the
# argument shapes matched — confirm each section triggers a retrace.
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)

# Extra keyword arguments forwarded to the two functions above.
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Learning-rate schedule: starts at 5e-2 and decays exponentially by a factor
# of 0.6 per 1000 optimizer steps, beginning at step 0.
scheduler = optax.exponential_decay(
    init_value=5e-2,
    decay_rate=0.6,
    transition_steps=1000,
    transition_begin=0,
)

# Adam spelled out as an explicit chain: Adam-normalised gradients, multiplied
# by the scheduled learning rate, then negated because optax.apply_updates adds
# the update to the parameters and we want to descend on the loss.
gradient_transform = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)

In [None]:
# Start from the model's initial parameters and build the optimizer state
# that matches their structure.
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
# report_as_csv is checked inside run(); the file name encodes model,
# precision, dataset and number of potentials (presumably consumed by run()).
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
# Train with the globals configured above and unpack the result into the
# loss/metric histories, the best-epoch bookkeeping for both selection
# criteria (MMD and loss), and the final parameters.
# NOTE(review): val_mmds is unpacked before val_inv_errors, whereas run()
# initialises them as (val_losses, val_inv_errors, val_mmds) — confirm this
# order matches run()'s return statement.
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
# Test-set report for both selected parameter sets. eval_val forwards to
# VAL_LOSS_FUNC; its result is unpacked as (loss, inverse error, MMD).

# Parameters selected on validation MMD.
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_mmd,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

# Parameters selected on validation loss.
print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_loss,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Loss curves on a log scale: the full history (left) and the history with the
# first 1000 recorded entries dropped (right), so the initial transient does
# not dominate the y-range.
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), which left
# the first panel without a legend; label each panel explicitly instead.
ax[0].legend()
ax[1].legend()

In [None]:
# Persist the parameters returned at the end of training. The directory name
# encodes model, precision, dataset and number of potentials.
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',
    target=params,
    overwrite=True,
    # Label the checkpoint with the configured epoch count rather than a
    # repeated literal, so it stays in sync if EPOCHS is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_loss.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',
    target=best_params_loss,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_mmd.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',
    target=best_params_mmd,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

## 3 Potential

In [None]:
# Tag for this configuration (num_blocks=3 below); it becomes part of the
# report and checkpoint file names.
NPOT_PREFIX = '3pot'

In [None]:
# Fixed seed for reproducibility. The split yields data_key (handed to run())
# and a fresh key that is used to construct the model.
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
# Training length for this configuration (presumably read by run() as a global).
EPOCHS = 20000

# OTF model with 3 potential blocks.
MODEL = OTF.OTF(
    input_dim=2,
    hidden_dim=64,
    rank=20,
    resnet_depth=2,
    num_blocks=3,
    phi=OTF.Phi,
    alpha1=15.,
    alpha2=2.,
    t0=0.,
    t1=1.,
    num_steps=20,
    key=key,
)

# Jitted training objective and evaluation metrics; step() and eval_val() look
# these up as globals.
# NOTE(review): step() is itself @jit-compiled and resolves LOSS_FUNC while
# tracing, so a trace cached from an earlier section would be reused if the
# argument shapes matched — confirm each section triggers a retrace.
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)

# Extra keyword arguments forwarded to the two functions above.
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Learning-rate schedule: starts at 5e-2 and decays exponentially by a factor
# of 0.6 per 1000 optimizer steps, beginning at step 0.
scheduler = optax.exponential_decay(
    init_value=5e-2,
    decay_rate=0.6,
    transition_steps=1000,
    transition_begin=0,
)

# Adam spelled out as an explicit chain: Adam-normalised gradients, multiplied
# by the scheduled learning rate, then negated because optax.apply_updates adds
# the update to the parameters and we want to descend on the loss.
gradient_transform = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)

In [None]:
# Start from the model's initial parameters and build the optimizer state
# that matches their structure.
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
# report_as_csv is checked inside run(); the file name encodes model,
# precision, dataset and number of potentials (presumably consumed by run()).
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
# Train with the globals configured above and unpack the result into the
# loss/metric histories, the best-epoch bookkeeping for both selection
# criteria (MMD and loss), and the final parameters.
# NOTE(review): val_mmds is unpacked before val_inv_errors, whereas run()
# initialises them as (val_losses, val_inv_errors, val_mmds) — confirm this
# order matches run()'s return statement.
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
# Test-set report for both selected parameter sets. eval_val forwards to
# VAL_LOSS_FUNC; its result is unpacked as (loss, inverse error, MMD).

# Parameters selected on validation MMD.
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_mmd,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

# Parameters selected on validation loss.
print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_loss,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Loss curves on a log scale: the full history (left) and the history with the
# first 1000 recorded entries dropped (right), so the initial transient does
# not dominate the y-range.
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), which left
# the first panel without a legend; label each panel explicitly instead.
ax[0].legend()
ax[1].legend()

In [None]:
# Persist the parameters returned at the end of training. The directory name
# encodes model, precision, dataset and number of potentials.
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',
    target=params,
    overwrite=True,
    # Label the checkpoint with the configured epoch count rather than a
    # repeated literal, so it stays in sync if EPOCHS is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_loss.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',
    target=best_params_loss,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_mmd.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',
    target=best_params_mmd,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

## 5 Potential

In [None]:
# Tag for this configuration (num_blocks=5 below); it becomes part of the
# report and checkpoint file names.
NPOT_PREFIX = '5pot'

In [None]:
# Fixed seed for reproducibility. The split yields data_key (handed to run())
# and a fresh key that is used to construct the model.
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
# Training length for this configuration (presumably read by run() as a global).
EPOCHS = 20000

# OTF model with 5 potential blocks.
MODEL = OTF.OTF(
    input_dim=2,
    hidden_dim=64,
    rank=20,
    resnet_depth=2,
    num_blocks=5,
    phi=OTF.Phi,
    alpha1=15.,
    alpha2=2.,
    t0=0.,
    t1=1.,
    num_steps=20,
    key=key,
)

# Jitted training objective and evaluation metrics; step() and eval_val() look
# these up as globals.
# NOTE(review): step() is itself @jit-compiled and resolves LOSS_FUNC while
# tracing, so a trace cached from an earlier section would be reused if the
# argument shapes matched — confirm each section triggers a retrace.
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)

# Extra keyword arguments forwarded to the two functions above.
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Learning-rate schedule: starts at 5e-2 and decays exponentially by a factor
# of 0.6 per 1000 optimizer steps, beginning at step 0.
scheduler = optax.exponential_decay(
    init_value=5e-2,
    decay_rate=0.6,
    transition_steps=1000,
    transition_begin=0,
)

# Adam spelled out as an explicit chain: Adam-normalised gradients, multiplied
# by the scheduled learning rate, then negated because optax.apply_updates adds
# the update to the parameters and we want to descend on the loss.
gradient_transform = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)

In [None]:
# Start from the model's initial parameters and build the optimizer state
# that matches their structure.
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
# report_as_csv is checked inside run(); the file name encodes model,
# precision, dataset and number of potentials (presumably consumed by run()).
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
# Train with the globals configured above and unpack the result into the
# loss/metric histories, the best-epoch bookkeeping for both selection
# criteria (MMD and loss), and the final parameters.
# NOTE(review): val_mmds is unpacked before val_inv_errors, whereas run()
# initialises them as (val_losses, val_inv_errors, val_mmds) — confirm this
# order matches run()'s return statement.
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
# Test-set report for both selected parameter sets. eval_val forwards to
# VAL_LOSS_FUNC; its result is unpacked as (loss, inverse error, MMD).

# Parameters selected on validation MMD.
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_mmd,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

# Parameters selected on validation loss.
print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_loss,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Loss curves on a log scale: the full history (left) and the history with the
# first 1000 recorded entries dropped (right), so the initial transient does
# not dominate the y-range.
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), which left
# the first panel without a legend; label each panel explicitly instead.
ax[0].legend()
ax[1].legend()

In [None]:
# Persist the parameters returned at the end of training. The directory name
# encodes model, precision, dataset and number of potentials.
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',
    target=params,
    overwrite=True,
    # Label the checkpoint with the configured epoch count rather than a
    # repeated literal, so it stays in sync if EPOCHS is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_loss.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',
    target=best_params_loss,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_mmd.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',
    target=best_params_mmd,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

## 10 Potential

In [None]:
# Tag for this configuration (num_blocks=10 below); it becomes part of the
# report and checkpoint file names.
NPOT_PREFIX = '10pot'

In [None]:
# Fixed seed for reproducibility. The split yields data_key (handed to run())
# and a fresh key that is used to construct the model.
key = random.PRNGKey(42)
data_key, key = random.split(key, 2)

In [None]:
# Training length for this configuration (presumably read by run() as a global).
EPOCHS = 20000

# OTF model with 10 potential blocks.
MODEL = OTF.OTF(
    input_dim=2,
    hidden_dim=64,
    rank=20,
    resnet_depth=2,
    num_blocks=10,
    phi=OTF.Phi,
    alpha1=15.,
    alpha2=2.,
    t0=0.,
    t1=1.,
    num_steps=20,
    key=key,
)

# Jitted training objective and evaluation metrics; step() and eval_val() look
# these up as globals.
# NOTE(review): step() is itself @jit-compiled and resolves LOSS_FUNC while
# tracing, so a trace cached from an earlier section would be reused if the
# argument shapes matched — confirm each section triggers a retrace.
LOSS_FUNC, VAL_LOSS_FUNC = jit(MODEL._loss), jit(MODEL.metrics)

# Extra keyword arguments forwarded to the two functions above.
LOSS_KWARGS = {}
VAL_LOSS_KWARGS = {'normal_batch': normal_sample}

In [None]:
# Learning-rate schedule: starts at 5e-2 and decays exponentially by a factor
# of 0.6 per 1000 optimizer steps, beginning at step 0.
scheduler = optax.exponential_decay(
    init_value=5e-2,
    decay_rate=0.6,
    transition_steps=1000,
    transition_begin=0,
)

# Adam spelled out as an explicit chain: Adam-normalised gradients, multiplied
# by the scheduled learning rate, then negated because optax.apply_updates adds
# the update to the parameters and we want to descend on the loss.
gradient_transform = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)

In [None]:
# Start from the model's initial parameters and build the optimizer state
# that matches their structure.
params = MODEL.params
opt_state = gradient_transform.init(params)

In [None]:
# report_as_csv is checked inside run(); the file name encodes model,
# precision, dataset and number of potentials (presumably consumed by run()).
report_as_csv = True
report_filename = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX]) + '.csv'

In [None]:
# Train with the globals configured above and unpack the result into the
# loss/metric histories, the best-epoch bookkeeping for both selection
# criteria (MMD and loss), and the final parameters.
# NOTE(review): val_mmds is unpacked before val_inv_errors, whereas run()
# initialises them as (val_losses, val_inv_errors, val_mmds) — confirm this
# order matches run()'s return statement.
ret = run(train, params, opt_state, data_key)
(train_losses, val_losses, val_mmds, val_inv_errors, 
 best_epoch_mmd, best_val_mmd, best_params_mmd, 
 best_epoch_loss, best_val_loss, best_params_loss, params) = ret

In [None]:
# Test-set report for both selected parameter sets. eval_val forwards to
# VAL_LOSS_FUNC; its result is unpacked as (loss, inverse error, MMD).

# Parameters selected on validation MMD.
print(f'\nBest Epoch MMD: {best_epoch_mmd} --- Val MMD: {best_val_mmd}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_mmd,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model MMD --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

# Parameters selected on validation loss.
print(f'\nBest Epoch Loss: {best_epoch_loss} --- Val Loss: {best_val_loss}')
test_loss, test_inv_error, test_mmd = eval_val(
    val_data=test,
    params=best_params_loss,
    solver_steps=EVAL_SOLVER_STEPS,
)
print(f'Best Model Loss --- Test Loss: {test_loss}, Test Inv Error: {test_inv_error}, Test MMD: {test_mmd}')

In [None]:
# Loss curves on a log scale: the full history (left) and the history with the
# first 1000 recorded entries dropped (right), so the initial transient does
# not dominate the y-range.
fig, ax = plt.subplots(1, 2, figsize=(16,6))
ax[0].semilogy(train_losses, label='train loss')
ax[0].semilogy(val_losses, label='validation loss')
ax[1].semilogy(train_losses[1000:], label='train loss')
ax[1].semilogy(val_losses[1000:], label='validation loss')
# plt.legend() only annotates the current axes (the right panel), which left
# the first panel without a legend; label each panel explicitly instead.
ax[0].legend()
ax[1].legend()

In [None]:
# Persist the parameters returned at the end of training. The directory name
# encodes model, precision, dataset and number of potentials.
filestr = '_'.join([MODEL_PREFIX, PRECISION_PREFIX, DATASET_PREFIX, NPOT_PREFIX])
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_params',
    target=params,
    overwrite=True,
    # Label the checkpoint with the configured epoch count rather than a
    # repeated literal, so it stays in sync if EPOCHS is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_loss.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_loss',
    target=best_params_loss,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)

In [None]:
# Persist the parameters that run() returned as best_params_mmd.
checkpoints.save_checkpoint(
    ckpt_dir=f'checkpoints/finalfinal/{filestr}_bestparams_mmd',
    target=best_params_mmd,
    overwrite=True,
    # Step label follows EPOCHS instead of a repeated literal, so it stays in
    # sync if the training length is changed.
    step=EPOCHS,
)