In [None]:
# Automatically re-import edited modules (e.g. the local `src` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import logging

from jax import random
import optax
import wandb
from ml_collections import config_dict

import src.models as models
from src.models import make_VAE_loss, make_VAE_eval
from src.data import get_image_dataset, NumpyLoader, METADATA
from src.utils.training import TrainState, train_loop

In [None]:
# W&B cannot work out the notebook's name when it runs inside VS Code,
# so tell it explicitly via the environment.
os.environ.update(WANDB_NOTEBOOK_NAME='train_vae.ipynb')

In [None]:
# Run configuration. The whole ConfigDict is handed to `train_loop`, so every
# tunable for this experiment lives here.
config = config_dict.ConfigDict()

# Seed for the run's root JAX PRNG key; kept in the config so it is recorded
# alongside the other hyperparameters.
config.seed = 0

# --- Data ---
config.dataset_name = 'MNIST'
# NOTE(review): the name says "percent" but the value is a fraction (0.1);
# confirm how get_image_dataset interprets it.
config.val_percent = 0.1
config.batch_size = 512

# --- Optimisation ---
# Warmup-cosine schedule: starts at `learning_rate`, warms up to
# `peak_learning_rate` over the first `peak_lr_percent` of all training steps,
# then cosine-decays to `final_learning_rate`.
config.learning_rate = 1e-4
config.peak_learning_rate = 3 * config.learning_rate
config.final_learning_rate = 0.1 * config.learning_rate
config.peak_lr_percent = 0.2
config.weight_decay = 1e-4  # AdamW weight decay
config.epochs = 100

# --- Model ---
# `model_name` is looked up as a class in `src.models`; `config.model` is passed
# to that class as keyword arguments.
config.model_name = 'VAE'
config.model = config_dict.ConfigDict()
config.model.latent_dim = 128
config.model.learn_prior = False
# Also decides whether images are flattened when the dataset is loaded.
config.model.convolutional = True

config.model.encoder = config_dict.ConfigDict()
config.model.encoder.posterior = 'hetero-diag-normal'
config.model.encoder.hidden_dims = [64, 128, 256]
config.model.encoder.act_fn = 'gelu'

config.model.decoder = config_dict.ConfigDict()
config.model.decoder.likelihood = 'iso-normal'
# The decoder mirrors the encoder's layer widths in reverse order.
config.model.decoder.hidden_dims = list(reversed(config.model.encoder.hidden_dims))
config.model.decoder.act_fn = 'gelu'
config.model.decoder.image_shape = METADATA['image_shape'][config.dataset_name]

# Authenticate with Weights & Biases before training.
wandb.login()

In [None]:
# Silence the noisy "WARNING:root:The use of `check_types` is deprecated and does
# not have any effect." message produced by tfp on the root logger.
# `logging.getLogger()` (no name) is used on purpose: `getLogger('root')` only
# resolves to the real root logger on Python >= 3.9; on older versions it creates
# an unrelated child logger named "root" and the filter would never fire.
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Drop log records whose message mentions `check_types`."""

    def filter(self, record):
        # False suppresses the record; True lets it through.
        return 'check_types' not in record.getMessage()


# Re-running this cell redefines the class, so match stale filters by class name
# (not isinstance) and remove them; otherwise each re-run stacks another copy.
for stale in [f for f in logger.filters if type(f).__name__ == 'CheckTypesFilter']:
    logger.removeFilter(stale)
logger.addFilter(CheckTypesFilter())

In [None]:
# Root PRNG key for the whole run. The seed is taken from the config when present
# (so it is logged with the other hyperparameters), falling back to 0.
rng = random.PRNGKey(config.get('seed', 0))

In [None]:
# get_image_dataset returns the splits in (train, test, val) order -- note test
# comes before val. Convolutional models keep images in their native shape;
# otherwise they are flattened.
train_dataset, test_dataset, val_dataset = get_image_dataset(
    dataset_name=config.dataset_name,
    val_percent=config.val_percent,
    flatten_img=(not config.model.convolutional),
)

# One batching loader per split, all sharing the configured batch size.
# NOTE(review): no shuffle flag is passed; confirm NumpyLoader (or train_loop)
# shuffles the training data between epochs.
train_loader, val_loader, test_loader = (
    NumpyLoader(split, config.batch_size)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Instantiate the model class named in the config with the `config.model` kwargs.
model_cls = getattr(models, config.model_name)
model = model_cls(**config.model.to_dict())

# Draw separate keys for parameter initialisation and for the model's own
# sampling during init, so the `rng` later handed to `train_loop` has not already
# been consumed (JAX keys must not be reused).
init_rng, rng = random.split(rng)
init_sample_rng, rng = random.split(rng)
# NOTE(review): the third argument to `model.init` is presumably the VAE's
# sampling key, and the model is assumed to accept a single un-batched example;
# confirm against src.models.
init_data = train_dataset[0][0]
variables = model.init(init_rng, init_data, init_sample_rng)
# Split the collections into trainable `params` and everything else.
# NOTE(review): tuple-returning `.pop` is flax FrozenDict behaviour; if
# `model.init` returns a plain dict (newer flax), use `flax.core.pop(variables, 'params')`.
model_state, params = variables.pop('params')
del variables

# Warmup-cosine LR schedule over the whole run. Per optax, `decay_steps` is the
# total schedule length *including* warmup, hence `total_steps` here.
# Assumes len(train_loader) is the number of batches per epoch.
total_steps = config.epochs * len(train_loader)
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=config.learning_rate, peak_value=config.peak_learning_rate,
    warmup_steps=int(total_steps * config.peak_lr_percent),
    decay_steps=total_steps, end_value=config.final_learning_rate
)

# Bundle apply_fn, params, AdamW optimiser and non-param model state.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adamw(lr_schedule, weight_decay=config.weight_decay),
    model_state=model_state
)

In [None]:
# Train the VAE; the result replaces `state`.
# Optional extras that can be appended to the call:
#   test_loader                        -- presumably also evaluates on the test split
#   wandb_kwargs={'mode': 'offline'}   -- presumably forwarded to W&B to log offline
state = train_loop(
    model,
    state,
    config,
    rng,
    make_VAE_loss,
    make_VAE_eval,
    train_loader,
    val_loader,
)