In [None]:
# Load IPython's autoreload extension so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [None]:
import os
import logging

from jax import random
import optax
import wandb
from ml_collections import config_dict

import src.models as models
from src.models import make_VAE_loss, make_VAE_eval
from src.data import get_image_dataset, NumpyLoader
from src.utils.training import TrainState, train_loop

In [None]:
# Experiment-level settings: dataset, validation split and optimisation schedule.
config = config_dict.ConfigDict({
    'dataset_name': 'MNIST',
    'valid_percent': 0.1,
    'learning_rate': 3e-3,
    'epochs': 30,
    'batch_size': 512,
    'model_name': 'VAE',
})

# Model hyper-parameters; this sub-config is later expanded into the model constructor.
config.model = config_dict.ConfigDict({
    'latent_dim': 50,
    'learn_prior': False,
    'convolutional': False,
})

# Encoder network settings.
config.model.encoder = config_dict.ConfigDict({
    'posterior': 'hetero-diag-normal',
    'hidden_dims': [128, 256, 512],
    'act_fn': 'gelu',
})

# Decoder network settings. The output image shape is chosen from the dataset name,
# so an unlisted dataset raises KeyError here.
# TODO: test the posterior / likelihood choices above!
config.model.decoder = config_dict.ConfigDict({
    'likelihood': 'iso-normal',
    'hidden_dims': [128, 256, 512],
    'act_fn': 'gelu',
    'image_shape': {
        'MNIST': (28, 28, 1),
        'FashionMNIST': (28, 28, 1),
        'KMNIST': (28, 28, 1),
        'SVHN': (32, 32, 3),
        'CIFAR10': (32, 32, 3),
        'CIFAR100': (32, 32, 3),
    }[config.dataset_name],
})

# W&B doesn't know how to handle VS Code notebooks, so name the notebook explicitly
# before logging in.
os.environ['WANDB_NOTEBOOK_NAME'] = 'train_vae.ipynb'

wandb.login()

In [None]:
# Fix for annoying "WARNING:root:The use of `check_types` is deprecated and does not have any effect."
# message produced by tfp.
#
# Fetch the root logger with a no-argument call: before Python 3.9,
# logging.getLogger('root') returned a separate, non-root logger that merely
# happened to be named "root", so a filter attached to it never saw records
# logged on the real root logger.
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Logging filter that drops every record whose message mentions ``check_types``."""

    def filter(self, record):
        """Return False (drop the record) when 'check_types' occurs in its formatted message."""
        return 'check_types' not in record.getMessage()


# Re-running this cell redefines the class, so earlier filter instances would not
# compare equal to new ones; match them by class name and remove them so that the
# logger carries exactly one copy of the filter.
for stale_filter in [f for f in logger.filters if type(f).__name__ == 'CheckTypesFilter']:
    logger.removeFilter(stale_filter)

logger.addFilter(CheckTypesFilter())

In [None]:
# Root PRNG key for the notebook; the cells below derive their keys from it.
rng = random.PRNGKey(seed=0)

In [None]:
# Load the dataset splits. Images are flattened unless the model is convolutional.
# NOTE(review): the (train, test, valid) unpacking order is kept from the original
# call -- confirm it matches get_image_dataset's return order.
train_dataset, test_dataset, valid_dataset = get_image_dataset(
    dataset_name=config.dataset_name,
    valid_percent=config.valid_percent,
    flatten_img=not config.model.convolutional
)
train_loader = NumpyLoader(train_dataset, config.batch_size)
# Serve the whole validation split as a single batch. The batch size is taken from
# the split itself rather than from a hard-coded MNIST training-set size, so it
# stays correct for every dataset and any valid_percent.
# Assumes the dataset object implements __len__ (it is already indexed below in
# this notebook) -- TODO confirm.
valid_loader = NumpyLoader(valid_dataset, len(valid_dataset))
# test_loader = NumpyLoader(test_dataset, len(test_dataset))

In [None]:
# Look up the model class by name in src.models and build it from the model sub-config.
model_cls = getattr(models, config.model_name)
model = model_cls(**config.model.to_dict())

# Derive three independent keys: one for model.init, one for the extra key that
# init forwards to the model, and one that stays in `rng` for the training loop.
# Passing `rng` itself to init and then reusing it for training would hand the
# same key to two consumers, which JAX's PRNG design says to avoid.
init_rng, sample_rng, rng = random.split(rng, 3)
# A single training example is used to initialise the model variables.
init_data = train_dataset[0][0]
variables = model.init(init_rng, init_data, sample_rng)
# Separate the trainable 'params' collection from the remaining collections.
# NOTE(review): this relies on `pop` returning a (remaining, popped) pair rather
# than only the popped value like a plain dict -- confirm against the installed
# Flax version.
model_state, params = variables.pop('params')
del variables

# Bundle the apply function, parameters, optimiser and non-trainable state.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adamw(config.learning_rate),
    model_state=model_state
)

In [None]:
# Train the model and keep the state returned by the loop.
# Note that `state` is both an input and the output, so re-running this cell
# passes the previously returned state back into train_loop.
state = train_loop(
    model,
    state,
    config,
    rng,
    make_VAE_loss,
    make_VAE_eval,
    train_loader,
    valid_loader,
    dict(mode='disabled'),  # presumably W&B run settings -- verify against train_loop
)