# VAE Training

## imports

In [None]:
import tensorflow as tf
# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
# Iterating over the device list (instead of indexing [0]) means this cell is a
# no-op on CPU-only machines rather than an IndexError, and it applies the same
# setting to all GPUs, which TensorFlow requires on multi-GPU hosts.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
import sys
import os
# Push the three nearest ancestor directories onto the front of the module
# search path. Inserting the farthest ancestor first leaves the nearest one
# ('..') at index 0, exactly as three separate insert calls would.
for ancestor in ('../../..', '../..', '..'):
    sys.path.insert(0, ancestor)

from src.models.VAE import VariationalAutoencoder
from src.utils.loaders import load_mnist

import wandb
from wandb.integration.keras import WandbMetricsLogger
from utils.callbacks import LRFinder, get_lr_scheduler, get_early_stopping, LRLogger
from utils.wandb_utils import init_wandb


In [None]:
# Global Configuration
BATCH_SIZE = 1024              # passed to vae.train as batch_size
EPOCHS = 200                   # passed to vae.train as epochs
PRINT_EVERY_N_BATCHES = 100    # passed to vae.train as print_every_n_batches
INITIAL_EPOCH = 0              # passed to vae.train as initial_epoch
INPUT_DIM = (28,28,1)          # passed to VariationalAutoencoder as input_dim
Z_DIM = 2                      # passed to VariationalAutoencoder as z_dim
OPTIMIZER_NAME = 'adam'        # not referenced by the visible cells of this notebook
DATASET_NAME = 'digits'        # used in the run folder name and the W&B run name/config
MODEL_TYPE = 'vae'             # logged to W&B as the "model" config entry

# Run Params
SECTION = 'vae'
RUN_ID = '0002'
# Artifacts for this run live under run/<SECTION>/<RUN_ID>_<DATASET_NAME>.
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATASET_NAME])

# Create the run folder and its sub-folders unconditionally. exist_ok=True
# already makes this idempotent; guarding on the existence of RUN_FOLDER would
# skip the sub-folders whenever the parent exists but is incomplete (e.g. after
# an interrupted or manually cleaned run).
os.makedirs(RUN_FOLDER, exist_ok=True)
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' saves the freshly constructed model to RUN_FOLDER; any other value
# makes the architecture cell load previously saved weights instead.
MODE = 'build'


## data

In [None]:
# load_mnist() yields two splits; each split unpacks into an (inputs, labels) pair.
train_split, test_split = load_mnist()
x_train, y_train = train_split
x_test, y_test = test_split

## architecture

In [None]:
# Hyper-parameters for the encoder and decoder stacks (four entries per list)
# plus the input shape and latent dimensionality from the global configuration.
architecture = {
    'input_dim': INPUT_DIM,
    'encoder_conv_filters': [32, 64, 64, 64],
    'encoder_conv_kernel_size': [3, 3, 3, 3],
    'encoder_conv_strides': [1, 2, 2, 1],
    'decoder_conv_t_filters': [64, 64, 32, 1],
    'decoder_conv_t_kernel_size': [3, 3, 3, 3],
    'decoder_conv_t_strides': [1, 2, 2, 1],
    'z_dim': Z_DIM,
}
vae = VariationalAutoencoder(**architecture)

# Any mode other than 'build' resumes from the weights stored in the run folder;
# 'build' saves the newly constructed model there instead.
if MODE != 'build':
    vae.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.weights.h5'))
else:
    vae.save(RUN_FOLDER)

In [None]:
# Show the summary of the encoder sub-model for a quick architecture check.
vae.encoder.summary()

In [None]:
# Show the summary of the decoder sub-model for a quick architecture check.
vae.decoder.summary()

## training

In [None]:
# Fixed optimiser learning rate; reused below as OPTIMAL_LR.
LEARNING_RATE = 5e-4
# Second positional argument to vae.compile (presumably the weight applied to
# the reconstruction loss - confirm in VariationalAutoencoder.compile).
R_LOSS_FACTOR = 1_000

In [None]:
# No learning-rate search is run for this model: VAE models cannot use LRFinder
# because their Lambda layer wraps a custom 'sampling' function that isn't
# registered with Keras serialization. The fixed LEARNING_RATE is therefore
# used as the "optimal" rate by the W&B config and by vae.compile.
OPTIMAL_LR = LEARNING_RATE
print(f"Using Learning Rate: {OPTIMAL_LR}")


In [None]:
# Hyper-parameters recorded with the W&B run; the learning rate logged here is
# the one actually handed to vae.compile.
wandb_config = {
    "model": MODEL_TYPE,
    "dataset": DATASET_NAME,
    "learning_rate": OPTIMAL_LR,
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
}

# Start the W&B run through the project helper.
run = init_wandb(
    name=f"vae_{DATASET_NAME}_{RUN_ID}",
    project="generative-deep-learning",
    config=wandb_config,
)

In [None]:
# Compile the model with the selected learning rate and R_LOSS_FACTOR.
# NOTE(review): R_LOSS_FACTOR is presumably the reconstruction-loss weight;
# confirm against VariationalAutoencoder.compile.
vae.compile(OPTIMAL_LR, R_LOSS_FACTOR)


In [None]:
# External callbacks: W&B metric logging, learning-rate logging, and a
# scheduler / early-stopping pair that both monitor the training loss.
training_callbacks = [
    WandbMetricsLogger(),
    LRLogger(),
    get_lr_scheduler(monitor='loss', patience=5),
    get_early_stopping(monitor='loss', patience=10),
]

# Fit on the training images. lr_decay=1 disables the model's internal
# scheduler so the external one above controls the learning rate.
vae.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    initial_epoch=INITIAL_EPOCH,
    lr_decay=1,
    extra_callbacks=training_callbacks,
)



In [None]:
# Mark the active W&B run as finished now that training is done.
wandb.finish()
