# Training a VAE

### Imports

In [1]:
import torch
import mlflow
import vae
from utils.mlflow import backend_stores
from utils.trainer import TrainingArguments
from vae.models import VAEConfig
from vae.models.architectures import VAEModelV1
from vae import train_vae_on_dataset

from utils.data import load_datasets
vae.models.base.model_store = "pretrained_models/MNIST"

### Seeding, loading data & setting up mlflow logging

In [2]:
DATASET = "MNIST"
DATASET_LIMIT = 50
SEED = 1337

# set the backend store uri of mlflow
mlflow.set_tracking_uri(getattr(backend_stores, DATASET))
# seed torch
torch.manual_seed(SEED)
# load datasets
train_dataset, vae_train_dataset, val_dataset, test_dataset = load_datasets(DATASET)
print(len(train_dataset), len(vae_train_dataset), len(test_dataset))

50 51000 10000


### VAE Parameters

In [3]:
EPOCHS = 100
Z_DIM = 20
BETA = 1.0

## Training the VAE

In [None]:
# set mlflow experiment
mlflow.set_experiment(f"Z_DIM {Z_DIM}")
# create a vae config
vae_config = VAEConfig(z_dim=Z_DIM, beta=BETA)
# train vae
train_vae_on_dataset(
    training_args=TrainingArguments(EPOCHS, seed=SEED, batch_size=64),
    train_dataset=vae_train_dataset,
    test_dataset=val_dataset,
    vae_config=vae_config,
    model_architecture=VAEModelV1,
    save_every_n_epochs=25,
    seed=SEED,
)

INFO: 'Z_DIM 20' does not exist. Creating a new experiment
Training VAE on dataset


Training epoch 46:  45%|████▌     | 36129/79700 [09:35<09:40, 75.05it/s, bce_l=81.13, kl_l=31.43]  