# Training a VAE

### Imports

In [1]:
import torch
import mlflow
from utils.mlflow import backend_stores
from utils.trainer import TrainingArguments
from vae.models import VAEConfig
from vae.models.architectures import VAEModelV2
from vae import train_vae_on_dataset

from utils.data import load_datasets

### Seeding, loading data & setting up MLflow logging

In [2]:
# --- Run configuration -------------------------------------------------------
# Dataset name: used both to look up the mlflow backend store on
# `backend_stores` and as the argument to `load_datasets`.
DATASET = "CIFAR10"
# NOTE(review): not referenced by any cell visible in this notebook -- confirm
# whether it is still needed.
DATASET_LIMIT = 50
# Seed shared by torch and the training call further down.
SEED = 1337

# Point mlflow at the backend store registered under the dataset's name.
mlflow.set_tracking_uri(getattr(backend_stores, DATASET))

# Seed torch's global RNG before any data is loaded.
torch.manual_seed(SEED)

# `load_datasets` hands back four datasets for the chosen dataset name.
(
    train_dataset,
    vae_train_dataset,
    val_dataset,
    test_dataset,
) = load_datasets(DATASET)

### VAE Parameters

In [3]:
# Epoch count; passed as the first positional argument of TrainingArguments.
EPOCHS: int = 500
# Forwarded to VAEConfig as `z_dim`; also embedded in the mlflow experiment name.
Z_DIM: int = 50
# Forwarded to VAEConfig as `beta`.
BETA: float = 1.0

## Training the VAE

In [None]:
# Group runs in mlflow under an experiment named after the z_dim setting.
experiment_name = f"Z_DIM {Z_DIM}"
mlflow.set_experiment(experiment_name)

# Model configuration built from the parameters cell.
vae_config = VAEConfig(z_dim=Z_DIM, beta=BETA)

# Trainer settings: epoch count, seed and batch size.
training_args = TrainingArguments(EPOCHS, seed=SEED, batch_size=64)

# Launch training. The validation dataset is what gets passed as
# `test_dataset`; `save_every_n_epochs` is set to 25.
train_vae_on_dataset(
    training_args=training_args,
    train_dataset=vae_train_dataset,
    test_dataset=val_dataset,
    vae_config=vae_config,
    model_architecture=VAEModelV2,
    save_every_n_epochs=25,
    seed=SEED,
)

Training VAE on dataset


Training epoch 179:  36%|███▌      | 119008/332500 [31:41<52:26, 67.84it/s, bce_l=1803.66, kl_l=51.87]  