# Section 4.2.2. VAE Implementation

In [None]:
import os
import matplotlib.pyplot as plt
from torchsummary import summary

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import *

# Seed PyTorch's global RNG (torch.manual_seed also seeds all CUDA devices) so weight
# initialization and the later random draws (reparameterization noise, sampled latents)
# are reproducible between runs, to the extent the backend kernels are deterministic.
torch.manual_seed(0)

In [None]:
# -------------------------------
# Define the VAE model
# -------------------------------
class VAE(nn.Module):
    """Convolutional variational autoencoder for square images.

    Encoder: three stride-2 ``Conv2d`` stages (32 -> 64 -> 128 channels) shrink an
    ``img_size`` x ``img_size`` input to a 128 x (img_size // 8) x (img_size // 8)
    feature map. Two linear heads map that map to the mean and log-variance of a
    diagonal Gaussian posterior q(z|x).

    Decoder: a linear layer followed by three stride-2 ``ConvTranspose2d`` stages
    mirrors the encoder back to ``input_dim`` x ``img_size`` x ``img_size``. The
    final ``Tanh`` keeps outputs in (-1, 1).
    NOTE(review): inputs are presumably scaled to a compatible range; confirm
    against ``load_and_preprocess_data``.

    Args:
        z_dim: Size of the latent vector.
        input_dim: Number of image channels (1 for the single-channel data used here).
        img_size: Height and width of the square input. Must be a positive multiple
            of 8. Defaults to 64, matching the original hard-coded architecture.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 8.
    """

    def __init__(self, z_dim, input_dim, img_size=64):
        super(VAE, self).__init__()
        if img_size <= 0 or img_size % 8 != 0:
            raise ValueError(f"img_size must be a positive multiple of 8, got {img_size}")
        self.z_dim = z_dim
        self.img_size = img_size
        # Each kernel-4 / stride-2 / padding-1 conv halves H and W, and each matching
        # transposed conv doubles them. The bottleneck is therefore img_size // 8 per side.
        self._feat_channels = 128
        self._feat_size = img_size // 8
        flat_dim = self._feat_channels * self._feat_size * self._feat_size

        # Encoder
        self.enc_conv = nn.Sequential(
            nn.Conv2d(input_dim, 32, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.enc_fc_mu = nn.Linear(flat_dim, z_dim)
        self.enc_fc_logvar = nn.Linear(flat_dim, z_dim)

        # Decoder
        self.dec_fc = nn.Sequential(
            nn.Linear(z_dim, flat_dim),
            nn.ReLU(True)
        )
        self.dec_conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, input_dim, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        """Map images (B, input_dim, img_size, img_size) to ``(mu, logvar)``, each (B, z_dim)."""
        h = self.enc_conv(x)
        h = torch.flatten(h, start_dim=1)
        return self.enc_fc_mu(h), self.enc_fc_logvar(h)

    def decode(self, z):
        """Map latent vectors (B, z_dim) to images (B, input_dim, img_size, img_size)."""
        h = self.dec_fc(z)
        h = h.view(h.size(0), self._feat_channels, self._feat_size, self._feat_size)
        return self.dec_conv(h)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) as mu + eps * sigma, so gradients reach mu and logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample a latent with the reparameterization trick, and decode.

        Returns:
            Tuple ``(recon_x, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

## 1. Training Stage

In [None]:
# -------------------------------
# Loss function
# -------------------------------
def loss_fn(recon_x, x, mu, logvar):
    """
    Negative ELBO of the VAE, summed (not averaged) over the batch.

    Combines the summed squared reconstruction error with the closed-form
    KL divergence KL(N(mu, exp(logvar)) || N(0, I)) between the encoder's
    diagonal Gaussian posterior and the standard normal prior.

    Args:
        recon_x: Decoder output, same shape as ``x``.
        x: Target batch.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.

    Returns:
        Scalar tensor equal to reconstruction loss + KL divergence.
    """
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    reconstruction_term = F.mse_loss(recon_x, x, reduction='sum')
    return reconstruction_term + kl_term

# -------------------------------
# Training
# -------------------------------
def train_vae_model(model, dataloader, optimizer, num_epochs, device):
    """Train ``model`` with ``loss_fn`` and periodically print the per-sample loss.

    Args:
        model: VAE whose ``forward`` returns ``(recon, mu, logvar)``.
        dataloader: Iterable of batches exposing ``.dataset`` (used for the sample
            count). Each batch is a tensor, or a list/tuple whose first element is
            the input tensor (as yielded by ``TensorDataset``-based loaders).
        optimizer: Optimizer over ``model.parameters()``.
        num_epochs: Number of passes over ``dataloader``.
        device: Device each batch is moved to; should match the model's device.

    Returns:
        The same ``model`` object, trained in place.
    """
    model.train()
    total_samples = len(dataloader.dataset)
    for epoch in range(num_epochs):
        train_loss = 0.0
        for data in dataloader:
            if isinstance(data, (list, tuple)):
                data = data[0]
            optimizer.zero_grad()
            data = data.to(device)
            recon_data, mu, logvar = model(data)
            loss = loss_fn(recon_data, data, mu, logvar)
            loss.backward()
            optimizer.step()
            # loss_fn already sums over every sample in the batch, so accumulate it
            # as-is. Multiplying by the batch size would count each sample
            # batch_size times and inflate the reported average.
            train_loss += loss.item()
        avg_loss = train_loss / total_samples
        if (epoch+1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1:02}/{num_epochs}], Loss: {avg_loss:>11.4f}")
    return model

![VAE diagram](img/VAE_edit.png "VAE diagram")

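The VAE is trained by maximizing the evidence lower bound (ELBO). The loss minimized in code is its negative:

- $$ \begin{aligned} \mathcal{L}(\mathbf{x};\theta,\phi) &= \mathbb{E}_{\mathbf{z} \sim q_{\phi}(\mathbf{z}|\mathbf{x})} \left[ \log p_{\theta}(\mathbf{x}|\mathbf{z}) \right] - D_{KL}\left( q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z}) \right) \\ -\mathcal{L}(\mathbf{x};\theta,\phi) &= \text{Reconstruction Loss} + \text{Regularization (KL divergence)} \end{aligned} $$

The reconstruction loss is the summed squared error between input and reconstruction. Up to a constant factor and an additive constant, this is the negative log-likelihood of a fixed-variance Gaussian decoder $p_{\theta}(\mathbf{x}|\mathbf{z})$.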

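For Gaussian distributions the KL divergence has a closed form:

- $D_{KL}\left(\mathcal N(\mu_1, \sigma_1^2) \,\|\, \mathcal N(\mu_2, \sigma_2^2)\right) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$
- $D_{KL}\left(\mathcal N(\mu_1, \sigma_1^2) \,\|\, \mathcal N(0, 1)\right) = -\frac{1}{2}\left(1 + 2\log \sigma_1 - \mu_1^2 - \sigma_1^2\right)$

With `logvar` $= \log \sigma_1^2 = 2\log \sigma_1$, the second expression becomes $-\frac{1}{2}\left(1 + \texttt{logvar} - \mu_1^2 - e^{\texttt{logvar}}\right)$. Summed over the latent dimensions and the batch, this is the KL term computed in `loss_fn`.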

In [None]:
# Data location and training hyperparameters
data_path = '../dataset/i24_normalized.pt'
batch_size, num_epochs = 250, 50
z_dim = 64
lr, beta1 = 0.0005, 0.8  # Adam learning rate and first-moment decay (beta1)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset, data_loader = load_and_preprocess_data(data_path, batch_size)

# Channel count is the second axis of the dataset tensor.
input_dim = dataset.shape[1]  # dataset.shape: (40000, 1, 64, 64) → input_dim = 1

model = VAE(z_dim, input_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.98))

In [None]:
# Print per-layer output shapes and parameter counts for inputs of shape (input_dim, 64, 64).
# NOTE(review): torchsummary runs a forward pass on random data. The freshly built model
# is still in training mode here, so BatchNorm running statistics get updated; confirm
# this is acceptable.
print("VAE summary:")
summary(model, input_size=(input_dim, 64, 64))

In [None]:
model = train_vae_model(
    model,
    data_loader,
    optimizer,
    num_epochs=num_epochs,
    device=device,
)

# Persist only the learned parameters. The testing stage rebuilds the
# architecture with VAE(...) and restores them with load_state_dict.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/vae_model.pth")
print("VAE model saved.")

## 2. Testing Stage

In [None]:
def visualize_generated_samples(model, z_dim, device, nrows=8, ncols=8, save_path="img/VAE_result.png"):
    """Decode random latent vectors and save/show them as an ``nrows`` x ``ncols`` grid.

    Latent codes are drawn from the standard normal prior and decoded with
    ``model.decode``. The first channel of each sample is drawn with ``imshow``.

    Args:
        model: Trained VAE exposing ``decode(z)``. It is switched to eval mode here.
        z_dim: Latent dimensionality expected by ``model.decode``.
        device: Device on which the latents are decoded.
        nrows: Number of grid rows.
        ncols: Number of grid columns; ``nrows * ncols`` samples are generated.
        save_path: Output image path. Missing parent directories are created; a
            bare filename is saved in the current working directory.
    """
    model.eval()
    with torch.no_grad():
        z = torch.randn(nrows * ncols, z_dim).to(device)
        samples = model.decode(z).cpu().numpy()

    # squeeze=False keeps `ax` two-dimensional even when nrows or ncols is 1.
    fig, ax = plt.subplots(nrows, ncols, figsize=(8, 8), squeeze=False)
    for i in range(nrows):
        for j in range(ncols):
            idx = i * ncols + j
            ax[i, j].imshow(samples[idx][0, :, :], origin="lower", cmap="viridis")
            ax[i, j].axis("off")
    fig.suptitle("Generated Samples", fontsize=16)

    # os.path.dirname("file.png") is "", which os.makedirs would reject.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path, dpi=500)
    plt.show()
    plt.close(fig)

In [None]:
# Inference runs on the CPU too, but a GPU is highly recommended.
# To force CPU inference, replace the next line with: device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture, then restore the weights written by the training stage.
# NOTE(review): torch.load unpickles arbitrary objects. On PyTorch >= 1.13, passing
# weights_only=True restricts it to tensors/state_dicts; consider it.
model = VAE(z_dim, input_dim)
model.to(device)
model.load_state_dict(torch.load("models/vae_model.pth", map_location=device))
model.eval()

os.makedirs("img", exist_ok=True)
visualize_generated_samples(
    model,
    z_dim,
    device,
    nrows=8,
    ncols=8,
    save_path="img/VAE_result.png",
)