# Training the VAE

### Setup

Import dependencies

In [24]:
import os
from importlib import reload

# PyTorch only honours PYTORCH_ENABLE_MPS_FALLBACK if it is set before `torch`
# is imported; setting it in a later cell has no effect in the running kernel.
# With it enabled, operators the MPS backend does not implement run on the CPU.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from mnistVAE import MNISTVAE

Select the compute device: Apple MPS if available, otherwise CUDA, otherwise CPU.

In [3]:
# Pick the best available accelerator: Apple-silicon MPS first, then CUDA,
# and fall back to the CPU when neither is present.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

Using device: mps


Load MNIST as float tensors of shape (N, 1, 28, 28) with pixel values scaled to [0, 1].

In [15]:
# Load MNIST as float32 tensors of shape (N, 1, 28, 28) with pixels scaled to [0, 1].
#
# The raw uint8 images are read straight from the dataset's `.data` attribute,
# which bypasses the dataset's transform pipeline entirely. `transform` below
# (including its Normalize step) is therefore NOT applied to these tensors.
# It is kept only for reference.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# `.float()` converts the uint8 pixels explicitly. The legacy `torch.Tensor(...)`
# constructor is not a reliable dtype conversion for an existing tensor.
# download=True on both splits so a fresh checkout works regardless of which runs first.
train_dataset = datasets.MNIST('../data', train=True, download=True).data.unsqueeze(1).float() / 255.0
test_dataset = datasets.MNIST('../data', train=False, download=True).data.unsqueeze(1).float() / 255.0

# Sanity checks: MNIST has 60k train / 10k test single-channel 28x28 images.
assert train_dataset.shape == (60000, 1, 28, 28)
assert test_dataset.shape == (10000, 1, 28, 28)
assert train_dataset.min() >= 0.0 and train_dataset.max() <= 1.0

# NOTE(review): these loaders are not used by `vae.train`, which is given the
# tensors directly, and their batch size (128) differs from setup_training's.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)

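Enable PyTorch's CPU fallback for operations the MPS backend does not support. Note: PyTorch only honours `PYTORCH_ENABLE_MPS_FALLBACK` if it is set before `torch` is imported.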

In [5]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

### Training

Create VAE instance

In [25]:
# Create the VAE for 1x28x28 MNIST images with a 10-dimensional latent space.
vae = MNISTVAE(input_dim=(1, 28, 28), latent_dim=10)

# Configure training. The training config, checkpoints and final model are
# written under a timestamped subdirectory of `output_dir`.
# NOTE(review): batch_size=64 here differs from the DataLoaders' 128; confirm
# which one is intended (only this value reaches the trainer).
vae.setup_training(
    num_epochs=50,
    learning_rate=1e-3,
    batch_size=64,
    output_dir="mnist_vae_results"
)

Using device: mps


Train the VAE

In [26]:
# Train on the training tensor, using the test split as the evaluation set.
# Long-running: it runs for the num_epochs configured in setup_training (50).
# The recorded run of this cell was stopped manually (KeyboardInterrupt).
vae.train(train_dataset, test_dataset)

Preprocessing train data...
INFO:pythae.pipelines.training:Preprocessing train data...
Checking train dataset...
INFO:pythae.pipelines.training:Checking train dataset...
Preprocessing eval data...

INFO:pythae.pipelines.training:Preprocessing eval data...

Checking eval dataset...
INFO:pythae.pipelines.training:Checking eval dataset...
Using Base Trainer

INFO:pythae.pipelines.training:Using Base Trainer

Model passed sanity check !
Ready for training.

INFO:pythae.trainers.base_trainer.base_trainer:Model passed sanity check !
Ready for training.

Created mnist_vae_results/VAE_training_2025-01-22_12-49-34. 
Training config, checkpoints and final model will be saved here.

INFO:pythae.trainers.base_trainer.base_trainer:Created mnist_vae_results/VAE_training_2025-01-22_12-49-34. 
Training config, checkpoints and final model will be saved here.

Training params:
 - max_epochs: 50
 - per_device_train_batch_size: 64
 - per_device_eval_batch_size: 64
 - checkpoint saving every: None
Optimize

KeyboardInterrupt: 