In [None]:
# Reload imported modules (e.g. src.*, experiments.*) before each cell runs, so edits
# to those files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import logging

from jax import random
import wandb

from src.models import make_invVAE_loss, make_invVAE_eval
from src.data import get_image_dataset, NumpyLoader
from src.utils.training import setup_training, train_loop
from experiments.inv_vae_mnist_default import get_config

In [None]:
# W&B cannot work out the notebook's name when it runs inside VS Code, so tell it
# explicitly before authenticating.
os.environ.update(WANDB_NOTEBOOK_NAME='train_vae.ipynb')

wandb.login()

In [None]:
# Silence the noisy "WARNING:root:The use of `check_types` is deprecated and does not have any effect."
# warning produced by tfp. The `root` in that prefix is the logger name, i.e. the record is
# emitted on the root logger, so that is where the filter has to be attached.
# `logging.getLogger()` (no argument) always returns the root logger; `getLogger('root')` only
# does so on Python >= 3.9 and creates an unrelated child logger named 'root' on older versions.
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Logging filter that drops any record whose message mentions ``check_types``."""

    def filter(self, record):
        # False -> record is discarded; True -> record is passed on unchanged.
        return 'check_types' not in record.getMessage()


# Re-running this cell redefines the class, so an isinstance() check would not recognise the
# filter added by a previous run. Match by class name instead to keep the cell idempotent
# (otherwise every re-run stacks one more filter on the root logger).
for _stale_filter in [f for f in logger.filters if type(f).__name__ == CheckTypesFilter.__name__]:
    logger.removeFilter(_stale_filter)
logger.addFilter(CheckTypesFilter())

In [None]:
# Root PRNG key for the whole run; a fixed seed keeps Restart & Run All reproducible.
SEED = 0
rng = random.PRNGKey(SEED)

In [None]:
# Experiment configuration from experiments/inv_vae_mnist_default.py. Fields read in this
# notebook include dataset_name, val_percent, model.architecture and batch_size.
config = get_config()

In [None]:
# MLP models consume flattened images; other architectures keep the image layout.
use_flat_images = config.model.architecture == 'MLP'

# NOTE(review): the splits come back in train / test / val order (not train / val / test);
# confirm this matches get_image_dataset's return order.
train_dataset, test_dataset, val_dataset = get_image_dataset(
    dataset_name=config.dataset_name,
    val_percent=config.val_percent,
    flatten_img=use_flat_images,
)

# One batched loader per split, all built with the configured batch size.
# NOTE(review): no shuffle option is passed for the training loader; check NumpyLoader's
# defaults if per-epoch shuffling is intended.
train_loader, val_loader, test_loader = (
    NumpyLoader(split, config.batch_size)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Reserve a dedicated key for model/state setup. `rng` is rebound to the remaining key,
# which the training cell consumes -- so re-running this cell on its own advances `rng`
# and changes the downstream run.
setup_rng, rng = random.split(rng, num=2)

# Input part of the first training example; presumably used by setup_training only to
# infer shapes/dtypes for initialisation -- TODO confirm.
first_example = train_dataset[0]
init_data = first_example[0]

model, state = setup_training(config, setup_rng, init_data)

In [None]:
# Train the invariant VAE. The returned `state` replaces the initial one, so re-running
# this cell starts from the already-trained state rather than from a fresh initialisation.
# NOTE(review): test_loader is built above but not passed here; add it if train_loop
# should also evaluate on the test split.
state = train_loop(
    model,
    state,
    config,
    rng,
    make_invVAE_loss,
    make_invVAE_eval,
    train_loader,
    val_loader,
    # Extra W&B run options, presumably forwarded to wandb.init (e.g. mode='offline', notes='...').
    wandb_kwargs=dict(),
)