# VAE Training

## imports

In [1]:
import tensorflow as tf

# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. This has to run before any GPU is initialized, and the setting must
# be identical on every GPU, so it is applied to each detected device.
# Iterating (rather than indexing [0]) also makes this a no-op on CPU-only
# machines instead of raising IndexError.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

2025-12-31 16:29:04.015233: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  if not hasattr(np, "object"):


In [2]:
import sys
import os

# Prepend three ancestor directories to the module search path (presumably so
# the project-local `src.*` / `utils.*` imports resolve — depends on the
# notebook's working directory). Each insert goes to index 0, so the entry
# inserted last ('..') ends up first on sys.path.
for ancestor in ('../../..', '../..', '..'):
    sys.path.insert(0, ancestor)

from src.models.VAE import VariationalAutoencoder
from src.utils.loaders import load_mnist

import wandb
from wandb.integration.keras import WandbMetricsLogger
from utils.callbacks import LRFinder, get_lr_scheduler, get_early_stopping, LRLogger
from utils.wandb_utils import init_wandb


In [3]:
# Global Configuration
BATCH_SIZE = 1024              # passed to vae.train(batch_size=...)
EPOCHS = 200                   # passed to vae.train(epochs=...)
PRINT_EVERY_N_BATCHES = 100    # passed to vae.train(print_every_n_batches=...)
INITIAL_EPOCH = 0              # passed to vae.train(initial_epoch=...)
INPUT_DIM = (28,28,1)          # input_dim handed to the VariationalAutoencoder
Z_DIM = 2                      # z_dim handed to the VariationalAutoencoder
OPTIMIZER_NAME = 'adam'        # NOTE(review): not referenced in the visible cells
DATASET_NAME = 'digits'        # used in the run folder name and the WandB run name
MODEL_TYPE = 'vae'             # logged to WandB as "model"

# Run Params
SECTION = 'vae'
RUN_ID = '0002'
RUN_FOLDER = '../run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATASET_NAME])

# Always make sure the run folder and each output sub-folder exist.
# exist_ok=True makes these calls idempotent, so no existence check is needed;
# guarding on RUN_FOLDER alone would skip the sub-folders whenever the run
# folder already exists but one of them is missing.
os.makedirs(RUN_FOLDER, exist_ok=True)
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' saves the freshly constructed model to RUN_FOLDER; any other value
# makes the architecture cell load previously saved weights instead.
MODE = 'build'


## data

In [4]:
# Load the dataset via the project helper (src.utils.loaders.load_mnist); the
# result is unpacked as a train pair and a test pair. Only x_train is consumed
# by the training cell in this notebook.
(x_train, y_train), (x_test, y_test) = load_mnist()

## architecture

In [5]:
# Construct the VAE. The encoder and decoder are each described by four
# per-layer entries (filters, kernel sizes, strides); the latent size is Z_DIM.
vae = VariationalAutoencoder(
    input_dim=INPUT_DIM,
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    decoder_conv_t_filters=[64, 64, 32, 1],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
    z_dim=Z_DIM,
)

# Any MODE other than 'build' resumes from the weights file in the run folder;
# 'build' saves the newly created model into RUN_FOLDER.
if MODE != 'build':
    vae.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.weights.h5'))
else:
    vae.save(RUN_FOLDER)

I0000 00:00:1767198550.188542    7748 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6094 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2070, pci bus id: 0000:01:00.0, compute capability: 7.5


In [6]:
# Show the encoder sub-model's summary (presumably the Keras layer-by-layer
# table — confirm against VariationalAutoencoder).
vae.encoder.summary()

In [7]:
# Show the decoder sub-model's summary (presumably the Keras layer-by-layer
# table — confirm against VariationalAutoencoder).
vae.decoder.summary()

## training

In [8]:
# Hand-picked learning rate; it is aliased to OPTIMAL_LR before being used.
LEARNING_RATE = 0.0005
# Passed as the second argument to vae.compile; presumably the weight applied
# to the reconstruction loss term — confirm against VariationalAutoencoder.compile.
R_LOSS_FACTOR = 1000

In [9]:
# VAE models cannot use LRFinder due to Lambda layer with custom 'sampling' function
# that isn't registered with Keras serialization. Using fixed learning rate instead.
# OPTIMAL_LR is the name the later cells read (WandB config and vae.compile), so it
# is simply set to the hand-picked LEARNING_RATE here.
OPTIMAL_LR = LEARNING_RATE
print(f"Using Learning Rate: {OPTIMAL_LR}")


Using Learning Rate: 0.0005


In [10]:
# Initialize the WandB run through the project helper, recording the
# hyperparameters that define this experiment. Z_DIM and R_LOSS_FACTOR are
# logged as well: the former determines the model and the latter is passed to
# vae.compile, so runs cannot be compared or reproduced without them.
run = init_wandb(
    name=f"vae_{DATASET_NAME}_{RUN_ID}",
    project="generative-deep-learning",
    config={
        "model": MODEL_TYPE,
        "dataset": DATASET_NAME,
        "learning_rate": OPTIMAL_LR,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "z_dim": Z_DIM,
        "r_loss_factor": R_LOSS_FACTOR,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mcataluna84[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# Compile the model with the chosen learning rate and R_LOSS_FACTOR
# (positional arguments; their exact meaning is defined by
# VariationalAutoencoder.compile, which is not visible here).
vae.compile(OPTIMAL_LR, R_LOSS_FACTOR)


In [12]:
# Callbacks supplied in addition to whatever vae.train configures itself:
# WandB metric logging, learning-rate logging, and an LR scheduler plus early
# stopping that both monitor the training 'loss'.
training_callbacks = [
    WandbMetricsLogger(),
    LRLogger(),
    get_lr_scheduler(monitor='loss', patience=5),
    get_early_stopping(monitor='loss', patience=10),
]

# Run training on x_train with the globally configured settings.
vae.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    initial_epoch=INITIAL_EPOCH,
    lr_decay=1,  # Disable internal scheduler to use external
    extra_callbacks=training_callbacks,
)



Epoch 1/200


2025-12-31 16:29:26.410527: I external/local_xla/xla/service/service.cc:163] XLA service 0x7c070c049a50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-12-31 16:29:26.410586: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 2070, Compute Capability 7.5
2025-12-31 16:29:26.579419: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-12-31 16:29:27.228238: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91701
2025-12-31 16:29:27.997546: I external/local_xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:546] Omitted potentially buggy algorithm eng14{k25=2} for conv (f32[1024,32,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1024,1,28,28]{3,2,1,0}, f32[32,1,3,3]{3,2,1,0}, f32[32]{0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call

[1m58/59[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 50ms/step - loss: 190.0872 - vae_r_loss: 188.4018

2025-12-31 16:29:40.320546: I external/local_xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:546] Omitted potentially buggy algorithm eng14{k25=2} for conv (f32[608,32,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[608,1,28,28]{3,2,1,0}, f32[32,1,3,3]{3,2,1,0}, f32[32]{0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}
2025-12-31 16:29:40.659153: I external/local_xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:546] Omitted potentially buggy algorithm eng14{k25=2} for conv (f32[608,64,7,7]{3,2,1,0}, u8[0]{0}) custom-call(f32[608,64,7,7]{3,2,1,0}, f32[64,64,3,3]{3,2,1,0}, f32[64]{0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cud

[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 170ms/step - loss: 189.1838 - vae_r_loss: 187.4832
Epoch 1: saving model to ../run/vae/0002_digits/weights/weights-001-136.79.weights.h5

Epoch 1: finished saving model to ../run/vae/0002_digits/weights/weights-001-136.79.weights.h5

Epoch 1: saving model to ../run/vae/0002_digits/weights/weights.weights.h5

Epoch 1: finished saving model to ../run/vae/0002_digits/weights/weights.weights.h5
Epoch 1: Learning Rate is 5.00e-04
[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 176ms/step - loss: 136.7860 - vae_r_loss: 134.2039 - learning_rate: 5.0000e-04
Epoch 2/200
[1m58/59[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 51ms/step - loss: 68.0788 - vae_r_loss: 65.4910
Epoch 2: saving model to ../run/vae/0002_digits/weights/weights-002-65.59.weights.h5

Epoch 2: finished saving model to ../run/vae/0002_digits/weights/weights-002-65.59.weights.h5

Epoch 2: saving model to ../run/vae/0002_digits/weights/we

In [13]:
# End the active WandB run now that training is complete.
wandb.finish()


0,1
epoch/epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇██
epoch/learning_rate,█████████████████████████████▁▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch/vae_r_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
learning_rate,██████████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch/epoch,199.0
epoch/learning_rate,0.00025
epoch/loss,44.3583
epoch/vae_r_loss,39.03543
learning_rate,0.00025


# Cleanup: Restart Kernel to Release GPU Memory

In [None]:
# ═══════════════════════════════════════════════════════════════════════════════
# CLEANUP: restart the kernel so GPU memory is fully released
# ═══════════════════════════════════════════════════════════════════════════════
# GPU memory held by TensorFlow/CUDA is not given back while the Python process
# is still running, so restarting the kernel is the only guaranteed way to free it.
# Execute this cell only once everything has finished and been saved.

import IPython

print("Restarting kernel to release GPU memory...")
kernel_app = IPython.Application.instance()
kernel_app.kernel.do_shutdown(restart=True)