# Beta VAE  

## Steps  
1. Import libraries  
2. Prepare data  
   ```Download  |  Transform  |  Dataloader```  
3. Define parameters  
   ```Model  |  Optimizer  |  Loss  |  Training  ```
4. Build Model  
   ```Components  ```
5. Training loop  
6. Visualize results  

## Import libraries  

In [28]:
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F    # New import
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import os
import time
import datetime

print(f"Imports completed at {datetime.datetime.now()}")

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Versions
print(f"Torch: {torch.__version__}, TorchVision: {torchvision.__version__}")

Imports completed at 2025-08-22 17:57:10.830933
Using device: cuda
Torch: 2.6.0+cu124, TorchVision: 0.21.0+cu124


## Define parameters  

In [29]:
# Set seed for PyTorch
seed = 42
torch.manual_seed(seed)

# Data prep params
batch_size = 128

# Model params
latent_dim = 20
beta_vae = 4.0

# Optimizer params
learning_rate = 0.0002
# beta1 = 0.5  # Adam optimizer beta1

# # Loss params
# criterion = nn.BCELoss()

# Training params
num_epochs = 100

## Prepare data  

In [30]:
# Transform
transform = transforms.ToTensor()
print(f"Transform to be applied:\n{transform}\n")

# Load MNIST
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
print(f"Train data:\n{train_dataset}\n")

# Dataloader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) # Last batch will be 96 (6000 % 128) instead of 128, so drop to avoid problem with batch norm
print(f"Data loader:\n{train_loader}")

Transform to be applied:
ToTensor()

Train data:
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

Data loader:
<torch.utils.data.dataloader.DataLoader object at 0x7d3106979350>


## Build Model  
```Beta VAE```  

**Probabilistic Neural Networks**  

- Encoder
    - Takes image, outputs mean and log var of latent space dim  
- Reparametrization (z-sampler)
    - Takes the mean and log var output by the Encoder and creates new z using epsilon (latent dim) from std normal  
- Decoder
    - Takes the z and outputs mean of x_hat (note that the var of x_hat is 1)  

In [31]:
class BetaVAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(BetaVAE, self).__init__()
        self.latent_dim = latent_dim

        #--------------Encoder--------------
        self.encoder = nn.Sequential(
            # Input: (B, 1, 28, 28)
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),    # [B, 32, 14, 14]
            nn.ReLU(True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),    # [B, 64, 7, 7]
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),    # [B, 128, 4, 4]
            nn.ReLU(True),
        )
        self.flatten = nn.Flatten()
        self.mu = nn.Linear(128 * 4 * 4, self.latent_dim)
        self.logvar = nn.Linear(128 * 4 * 4, self.latent_dim)

        #--------------Decoder--------------
        self.fc_decode = nn.Linear(self.latent_dim, 128*4*4)
        self.decoder = nn.Sequential(
            # Input: z (B, latent_dim)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),    # [B, 64, 8, 8]
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # [B, 32, 16, 16]
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),    # [B, 1, 32, 32]
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        h_flat = self.flatten(h)
        mu = self.mu(h_flat)
        logvar = self.logvar(h_flat)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        eps = torch.randn_like(mu)
        std = torch.exp(0.5 * logvar)
        z = mu + std * eps
        return z

    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(-1, 128, 4, 4)
        x_hat = self.decoder(h)
        return x_hat[:, :, 28, 28]    # Crop to MNIST size
        
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

In [32]:
model = BetaVAE(latent_dim)
model

BetaVAE(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU(inplace=True)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mu): Linear(in_features=2048, out_features=20, bias=True)
  (logvar): Linear(in_features=2048, out_features=20, bias=True)
  (fc_decode): Linear(in_features=20, out_features=2048, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): Sigmoid()
  )
)

## Set up Optimizers  

In [33]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0002
    maximize: False
    weight_decay: 0
)

## Set up Loss functions  

In [34]:
def beta_vae_loss(x, x_hat, mu, logvar, beta=4.0):
    # Recon loss + beta * KL Div

    # Reconstruction loss: binary cross entropy
    recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')

    # KL Divergence loss: between q(z|x) and N(0,I)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total loss
    loss = recon_loss + beta * kl_div

    return loss, recon_loss, kl_div

## Training loop  

In [35]:
def vae_trainer(num_epochs, latent_dim, beta):
    losses, recon_losses, kl_divs = [], [], []
    for epoch in range(num_epochs):
        model.train()
        total_loss, total_recon, total_kl = 0, 0, 0

        for batch in train_loader:
            x, _ = batch
            x = x.to(device)

            x_hat, mu, logvar = model(x)
            loss, recon, kl = beta_vae_loss(x, x_hat, mu, logvar, beta)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon.item()
            total_kl += kl.item()

        losses.append(total_loss / len(train_loader.dataset))
        recon_losses.append(total_recon / len(train_loader.dataset))
        kl_divs.append(total_kl / len(train_loader.dataset))

        print(f"Epoch {epoch+1}, Loss: {losses[-1]:.2f}, Recon: {recon_losses[-1]:.2f}, KL: {kl_divs[-1]:.2f}")

    return losses, recon_losses, kl_divs

## Training  

In [36]:
start = time.perf_counter()
print(f"Training started at {start}")

Training started at 5939.28236697


In [37]:
losses, recon_losses, kl_divs = vae_trainer(num_epochs=num_epochs, 
                                            latent_dim=latent_dim, 
                                            beta=beta_vae
                                           )

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

## Code for visualizations  

In [None]:
# Check training time
elapsed = time.perf_counter() - start
print(f"Training time: {datetime.timedelta(seconds=int(elapsed))} ({elapsed:.2f} s)")