# Training a CNN classifier with VAE-augmented data

### Imports

In [1]:
import mlflow
import torch
import vae
from utils.mlflow import backend_stores
from utils.trainer import TrainingArguments
from vae.models import VAEConfig
from evaluation.models import CNNMNIST
from evaluation.train_utils import train_model
from utils.data import get_dataset, load_splitted_datasets
from vae.generation import augmentations
from vae.generation2 import DataAugmentation
from torch.utils.data import ConcatDataset

### Seeding, model store & MLflow tracking setup

In [2]:
# Single seed shared by torch, the VAE augmentation and the CNN trainer.
SEED = 1337
# Dataset name; it also selects the mlflow backend store and the
# VAE model-store directory below.
DATASET = "MNIST"

# Seed torch's global RNG.
torch.manual_seed(SEED)

# Point the vae package's model store at the directory for the selected
# dataset. Derived from DATASET (previously hard-coded to "MNIST") so that
# changing DATASET does not keep pointing at the MNIST directory.
vae.models.base.model_store = f"pretrained_models/{DATASET}"

# Route mlflow logging to the backend store registered for this dataset.
mlflow.set_tracking_uri(getattr(backend_stores, DATASET))

### Parameters for data augmentation

In [3]:
# --- VAE used to produce the synthetic samples ---
MULTI_VAE = False  # forwarded to DataAugmentation(multi_vae=...)
VAE_EPOCHS = 100   # forwarded to DataAugmentation(vae_epochs=...)
Z_DIM = 10         # latent size, VAEConfig(z_dim=...)
BETA = 1.0         # VAEConfig(beta=...); presumably the KL weight of a beta-VAE — confirm

# --- augmentation strategy ---
# Setting AUGMENTATION to None trains the non-augmented baseline instead.
AUGMENTATION = augmentations.REPARAMETRIZATION
# Passed as K= to augment_datasets; in the recorded run K=450 yielded
# 450 generated samples in total.
K = 450
# Extra keyword arguments forwarded verbatim to augment_datasets.
augmentation_params = dict()

## Data Augmentation

In [4]:
# Load the (small) training splits; the printed sizes below show how many
# samples each split holds.
datasets, dataset_info = load_splitted_datasets(DATASET)

# One flag drives the experiment name, the run name AND the choice of training
# data. Previously the experiment name used truthiness while the branch used
# `is not None`, so a non-None falsy AUGMENTATION would have been logged as
# "CNN Baseline" while still training on augmented data.
use_augmentation = AUGMENTATION is not None

mlflow.set_experiment(f"CNN Z_DIM {Z_DIM}" if use_augmentation else "CNN Baseline")
# The run is closed when this block exits; it is reopened later via run.info.run_id.
with mlflow.start_run(run_name=AUGMENTATION if use_augmentation else "baseline") as run:
    if use_augmentation:
        da = DataAugmentation(
            vae_config=VAEConfig(z_dim=Z_DIM, beta=BETA),
            vae_epochs=VAE_EPOCHS,
            multi_vae=MULTI_VAE,
            seed=SEED,
        )

        # Report per-split and total sizes of the original data.
        original_sizes = [len(ds) for ds in datasets]
        print(", ".join(map(str, original_sizes)))
        print(sum(original_sizes))

        generated_datasets = da.augment_datasets(datasets, dataset_info, AUGMENTATION, K=K, **augmentation_params)

        # Report per-split and total sizes of the generated data.
        generated_sizes = [len(ds) for ds in generated_datasets]
        print(", ".join(map(str, generated_sizes)))
        print(sum(generated_sizes))

        # Train on the original samples plus the generated ones.
        augmented_dataset = ConcatDataset([*datasets, *generated_datasets])
    else:
        # Baseline: train on the original splits only.
        augmented_dataset = ConcatDataset(list(datasets))

test_dataset = get_dataset(DATASET, train=False)

5, 5, 5, 5, 5, 5, 5, 5, 5, 5
50


                                               

n = 10
x = tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])
L = tensor([45, 45, 45, 45, 45, 45, 45, 45, 45, 45])
45, 45, 45, 45, 45, 45, 45, 45, 45, 45
450




### Training parameters for the CNN classifier

In [5]:
# Step-based training configuration for the CNN classifier.
training_args = TrainingArguments(
    # schedule
    total_steps=5000,
    batch_size=32,
    seed=SEED,
    # evaluation / checkpointing
    # ("intervall" is the keyword spelling TrainingArguments expects here)
    validation_intervall=200,
    save_best_metric="best_acc",
    save_model=False,
    # early stopping is turned off; the window presumably only matters
    # once early_stopping is enabled — confirm in TrainingArguments
    early_stopping=False,
    early_stopping_window=20,
)

## Training the CNN

In [6]:
# create a vae config
vae_config = VAEConfig(z_dim=Z_DIM, beta=BETA)
# start mlflow run in experiment
with mlflow.start_run(run.info.run_id):
    # train cnn
    results = train_model(
        model=CNNMNIST(),
        training_args=training_args,
        train_dataset=augmented_dataset,
        dev_dataset=test_dataset,
        test_dataset=test_dataset,
        seed=SEED,
    )
    # print the results
    print(results)

Training epoch 313: 100%|██████████| 5000/5000 [01:46<00:00, 47.10it/s, loss=0.01]

Evaluating model.





{'eval_loss': 0.9655145559310913, 'eval_acc': 0.8398, 'eval_f1': 0.836845494129255}
