# Training a CNN classifier with VAE-augmented data

### Imports

In [1]:
import mlflow
import torch
from utils.mlflow import backend_stores
from utils.trainer import TrainingArguments
from vae.generation import augment_dataset_using_per_class_vaes, augment_dataset_using_single_vae, augmentations
from vae.models import VAEConfig
from evaluation.models import CNNMNIST
from evaluation.train_utils import train_model
from utils.data import load_datasets

### Seeding, loading data & setting up mlflow logging

In [2]:
# Experiment-wide constants shared by the cells below.
DATASET = "MNIST"
SEED = 1337
# NOTE(review): DATASET_LIMIT is not passed to load_datasets in this cell; it is
# only logged as an mlflow param in the training cell. Confirm the loader
# actually applies the same limit.
DATASET_LIMIT = 50

# Look up the mlflow backend store registered under the dataset's name.
mlflow.set_tracking_uri(getattr(backend_stores, DATASET))
# Seed torch's global RNG before any data is loaded.
torch.manual_seed(SEED)
# Unpack the four splits returned by the loader.
train_dataset, vae_train_dataset, val_dataset, test_dataset = load_datasets(DATASET)

### The parameters for the classification task

In [3]:
# Settings handed to train_model for the CNN classification run.
training_args = TrainingArguments(
    # reproducibility: reuse the notebook-wide seed
    seed=SEED,
    # training length and batch size
    total_steps=5000,
    batch_size=32,
    # validation interval and the metric used to pick the best model
    validation_intervall=200,
    save_best_metric="best_acc",
    # model saving is switched off for this run
    save_model=False,
    # early stopping is disabled; a window is still supplied
    early_stopping=False,
    early_stopping_window=20,
)

### Parameters for data augmentation

In [4]:
# VAE hyperparameters (Z_DIM and BETA feed VAEConfig in the training cell).
Z_DIM = 3
BETA = 1.0
VAE_EPOCHS = 150
# True -> augment with per-class VAEs, False -> augment with a single VAE.
MULTI_VAE = True
# Augmentation strategy; setting this to None makes the training cell run the
# un-augmented "baseline".
AUGMENTATION = augmentations.RANDOM_NOISE
# Extra parameters forwarded to the augmentation function.
augmentation_params = dict(k=9)

## Training the CNN

In [None]:
# Build the VAE configuration from the notebook-level hyperparameters.
vae_config = VAEConfig(z_dim=Z_DIM, beta=BETA)
# Select the mlflow experiment for this latent dimension.
mlflow.set_experiment(f"CNN Z_DIM {Z_DIM}")
# The conditional must be parenthesised: a bare `a + b if c else ""` parses as
# `(a + b) if c else ""`, which printed an empty line for baseline runs.
print("Training CNN" + (f" augmented with {Z_DIM=}, {BETA=}, {VAE_EPOCHS=}" if AUGMENTATION is not None else ""))
# The run is named after the augmentation, or "baseline" when none is configured.
with mlflow.start_run(run_name=AUGMENTATION or "baseline"):
    # log dataset limit
    mlflow.log_param("dataset_limit", DATASET_LIMIT)
    # perform data augmentation if specified
    if AUGMENTATION is not None:
        # log MULTI_VAE flag
        mlflow.log_param("multi_vae", MULTI_VAE)
        # log vae epochs
        mlflow.log_param("vae_epochs", VAE_EPOCHS)
        # NOTE: both branches rebind train_dataset, so re-running this cell
        # augments the already augmented data. Re-run the data-loading cell
        # first to start from the original training set.
        # if MULTI_VAE, augment data of each label separately
        if MULTI_VAE:
            train_dataset = augment_dataset_using_per_class_vaes(
                train_dataset, vae_config, VAE_EPOCHS, AUGMENTATION, augmentation_params, seed=SEED
            )
        # else: conventional vae data augmentation
        else:
            train_dataset = augment_dataset_using_single_vae(
                train_dataset, vae_config, VAE_EPOCHS, AUGMENTATION, augmentation_params, seed=SEED
            )
    # train the CNN on the (possibly augmented) training set
    results = train_model(
        model=CNNMNIST(),
        training_args=training_args,
        train_dataset=train_dataset,
        dev_dataset=val_dataset,
        test_dataset=test_dataset,
        seed=SEED,
    )
    # print the results
    print(results)

Training CNN augmented with Z_DIM=3, BETA=1.0, VAE_EPOCHS=150


                                               

  LABEL ORIGINAL AUGMENTED
0     0        5        50
1     1        5        50
2     2        4        40
3     3        4        40
4     4        7        70
5     5        8        80
6     6        4        40
7     7        5        50
9     9        8        80


Training epoch 138:  44%|████▍     | 2199/5000 [00:54<01:05, 42.62it/s, loss=0.02]