In [None]:
# Imports

import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms


In [None]:
# Configuration and hyperparameters

# Fix the RNG seed so that weight initialisation, DataLoader shuffling,
# VAE reparameterisation noise and diffusion noise are repeatable across
# "Restart Kernel -> Run All". torch.manual_seed seeds the CPU and CUDA
# generators.
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# VAE hyperparameters
vae_latent_dim = 16       # size of the latent code
vae_hidden_dim = 400      # width of the encoder/decoder hidden layer
vae_epochs = 50
vae_batch_size = 128
vae_lr = 1e-3

# Diffusion hyperparameters
diffusion_T = 250         # number of diffusion steps
diffusion_epochs = 50
diffusion_batch_size = 128
diffusion_lr = 2e-4


In [None]:
# MNIST dataset

# Images are converted to tensors only; no further normalisation is applied.
transform = transforms.ToTensor()

# Both splits are stored under ./data and downloaded on first use.
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Loaders for the VAE stage: shuffle the training split, keep the test split in order.
train_loader_vae = DataLoader(
    train_dataset,
    batch_size=vae_batch_size,
    shuffle=True,
)
test_loader_vae = DataLoader(
    test_dataset,
    batch_size=vae_batch_size,
    shuffle=False,
)

print("Train size:", len(train_dataset), "Test size:", len(test_dataset))


## 1. VAE: MNIST → Latent → MNIST

We use a simple fully-connected VAE:
- Encoder: 784 → 400 → (μ, logσ²) in 16D latent space
- Decoder: 16 → 400 → 784


In [None]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code. The decoder maps a latent code
    back to `input_dim` values squashed between 0 and 1 by a sigmoid.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=16):
        super().__init__()
        # Encoder: a shared hidden layer feeding two parallel linear heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent -> hidden -> input_dim.
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return `(mu, logvar)` for an already-flattened batch `x`."""
        hidden = F.relu(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw `mu + eps * std` with `eps ~ N(0, I)` and `std = exp(logvar / 2)`."""
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        """Map latent codes `z` to flat outputs between 0 and 1."""
        hidden = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(hidden))

    def forward(self, x):
        """Flatten `x` per sample, encode, sample a latent, and decode it.

        Returns:
            `(x_hat, mu, logvar)`, where `x_hat` is the flat reconstruction.
        """
        mu, logvar = self.encode(x.view(x.size(0), -1))
        x_hat = self.decode(self.reparameterize(mu, logvar))
        return x_hat, mu, logvar


In [None]:
def vae_loss(x, x_hat, mu, logvar):
    """Negative ELBO for the VAE, summed (not averaged) over the whole batch.

    Args:
        x: input batch; flattened per sample before comparison with `x_hat`.
        x_hat: flat reconstruction with values between 0 and 1.
        mu: posterior means.
        logvar: posterior log-variances.

    Returns:
        `(total, bce, kl)`: binary cross-entropy reconstruction term, the KL
        divergence of N(mu, exp(logvar)) from N(0, I), and their sum.
    """
    target = x.view(x.size(0), -1)
    recon_term = F.binary_cross_entropy(x_hat, target, reduction="sum")
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return recon_term + kl_term, recon_term, kl_term


In [None]:
# Build the VAE on the selected device and attach an Adam optimizer to it.
vae = VAE(
    input_dim=784,
    hidden_dim=vae_hidden_dim,
    latent_dim=vae_latent_dim,
).to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=vae_lr)

print(vae)


In [None]:
# Train VAE

vae.train()
for epoch in range(1, vae_epochs + 1):
    # Running sums over the epoch, in the order (total, BCE, KL).
    # vae_loss returns batch sums, so dividing by the sample count below
    # gives per-sample averages.
    running = [0.0, 0.0, 0.0]
    n_samples = 0

    for x, _ in train_loader_vae:
        x = x.to(device)

        vae_optimizer.zero_grad()
        x_hat, mu, logvar = vae(x)
        loss, bce, kl = vae_loss(x, x_hat, mu, logvar)
        loss.backward()
        vae_optimizer.step()

        for idx, term in enumerate((loss, bce, kl)):
            running[idx] += term.item()
        n_samples += x.size(0)

    avg_loss, avg_bce, avg_kl = (total / n_samples for total in running)

    print(f"[VAE] Epoch {epoch:02d} | Loss: {avg_loss:.4f} (BCE {avg_bce:.4f}, KL {avg_kl:.4f})")


In [None]:
# Visualize some VAE reconstructions

vae.eval()
x_batch, y_batch = next(iter(test_loader_vae))
x_batch = x_batch.to(device)

with torch.no_grad():
    x_hat, mu, logvar = vae(x_batch)

# Move everything to the CPU and restore the image layout for plotting.
x_batch = x_batch.cpu()
x_hat = x_hat.cpu().view(-1, 1, 28, 28)

n = 8
fig, axes = plt.subplots(2, n, figsize=(2 * n, 4))

# Top row: originals with their labels; bottom row: reconstructions.
for i in range(n):
    ax_orig, ax_recon = axes[0, i], axes[1, i]

    ax_orig.imshow(x_batch[i, 0].numpy(), cmap="gray")
    ax_orig.axis("off")
    ax_orig.set_title(f"Orig: {y_batch[i].item()}")

    ax_recon.imshow(x_hat[i, 0].numpy(), cmap="gray")
    ax_recon.axis("off")
    ax_recon.set_title("Recon")

fig.tight_layout()
plt.show()


## 2. Latent Diffusion: DDPM in VAE Latent Space

We freeze the VAE and train a diffusion model on the latent vectors μ(x).


In [None]:
# The VAE is only used for inference from here on.
vae.eval()

@torch.no_grad()
def encode_to_latent(x):
    """Encode a batch of images to the VAE posterior means, without gradients.

    `x` is moved to `device` and flattened per sample; the log-variance
    returned by the encoder is discarded.
    """
    flat = x.to(device).view(x.size(0), -1)
    mu, _ = vae.encode(flat)
    return mu


In [None]:
# Diffusion utilities

T = diffusion_T

# Linear beta schedule. The endpoints 1e-4 and 0.02 are the DDPM values for a
# 1000-step chain. Used unchanged with T = 250 they leave alphas_cumprod[-1]
# at roughly 0.08, so z_T still contains a sizeable fraction of z0 while
# sampling starts from pure N(0, I). Scaling both endpoints by 1000 / T keeps
# the total injected noise comparable for any T (alphas_cumprod[-1] < 1e-4
# at T = 250). The upper endpoint is capped below 1 so that very small T
# cannot produce a non-positive alpha.
beta_scale = 1000.0 / T
beta_start = beta_scale * 1e-4
beta_end = min(beta_scale * 0.02, 0.999)

betas = torch.linspace(beta_start, beta_end, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

betas = betas.to(device)
alphas = alphas.to(device)
alphas_cumprod = alphas_cumprod.to(device)

def q_sample(z0, t, noise=None):
    """Sample z_t from the forward process q(z_t | z0).

    Computes sqrt(alpha_bar_t) * z0 + sqrt(1 - alpha_bar_t) * noise.

    Args:
        z0: clean latents; the first dimension is the batch.
        t: timestep indices into `alphas_cumprod`, one per sample.
        noise: optional noise with the shape of `z0`; drawn from N(0, I)
            when omitted.

    Returns:
        Noised latents with the same shape as `z0`.
    """
    if noise is None:
        noise = torch.randn_like(z0)
    # Reshape the per-sample coefficients to (batch, 1, ..., 1) so that they
    # broadcast over z0 of any rank; for 2-D z0 this is the usual (-1, 1).
    coef_shape = (-1,) + (1,) * (z0.dim() - 1)
    sqrt_alpha_bar = torch.sqrt(alphas_cumprod[t]).view(coef_shape)
    sqrt_one_minus = torch.sqrt(1.0 - alphas_cumprod[t]).view(coef_shape)
    return sqrt_alpha_bar * z0 + sqrt_one_minus * noise


In [None]:
class TimeEmbedding(nn.Module):
    """Learned lookup table holding one `dim`-sized vector per timestep index."""

    def __init__(self, T, dim):
        super().__init__()
        self.emb = nn.Embedding(T, dim)

    def forward(self, t):
        """Return the embedding rows selected by the integer indices `t`."""
        embedded = self.emb(t)
        return embedded


class LatentDenoiser(nn.Module):
    """Three-layer MLP mapping a noisy latent and its timestep to a noise estimate.

    The timestep is embedded with `TimeEmbedding` and concatenated to the
    latent along dim 1 before the first linear layer.
    """

    def __init__(self, latent_dim, time_dim=32, hidden_dim=128, T=100):
        super().__init__()
        self.time_emb = TimeEmbedding(T, time_dim)
        self.fc1 = nn.Linear(latent_dim + time_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, latent_dim)

    def forward(self, z_t, t):
        """Return an output with the same trailing size as `z_t` (`latent_dim`)."""
        features = torch.cat((z_t, self.time_emb(t)), dim=1)
        # Two ReLU hidden layers followed by a linear output layer.
        for layer in (self.fc1, self.fc2):
            features = F.relu(layer(features))
        return self.fc3(features)


In [None]:
# Denoiser over the VAE latent space, sized to the T-step schedule, plus its optimizer.
diffusion_model = LatentDenoiser(latent_dim=vae_latent_dim, time_dim=32, hidden_dim=128, T=T)
diffusion_model = diffusion_model.to(device)

diffusion_optimizer = torch.optim.Adam(
    diffusion_model.parameters(),
    lr=diffusion_lr,
)

print(diffusion_model)


In [None]:
# Separate shuffled loader over the training images for the diffusion stage.
train_loader_diffusion = DataLoader(
    train_dataset,
    batch_size=diffusion_batch_size,
    shuffle=True,
)


In [None]:
# Train diffusion model

diffusion_model.train()

for epoch in range(1, diffusion_epochs + 1):
    total_loss = 0.0
    n_samples = 0

    for x, _ in train_loader_diffusion:
        # Clean latents are the frozen VAE's posterior means. encode_to_latent
        # moves x to the device, flattens it and runs under no_grad, so no
        # gradient reaches the VAE.
        z0 = encode_to_latent(x)
        batch_size = z0.size(0)

        # One uniformly random timestep per sample.
        t = torch.randint(0, T, (batch_size,), device=device, dtype=torch.long)

        # Noise the latents to step t and ask the model to recover the noise.
        noise = torch.randn_like(z0)
        z_t = q_sample(z0, t, noise=noise)

        eps_pred = diffusion_model(z_t, t)

        loss = F.mse_loss(eps_pred, noise)

        diffusion_optimizer.zero_grad()
        loss.backward()
        diffusion_optimizer.step()

        # mse_loss is a batch mean; weight it by the batch size so that the
        # shorter final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    avg_loss = total_loss / n_samples
    print(f"[Diffusion] Epoch {epoch:02d} | MSE loss: {avg_loss:.6f}")


In [None]:
# Sampling from diffusion + decode with VAE

@torch.no_grad()
def p_sample_step(z_t, t):
    """Run one reverse step, turning the latents at step `t` into those at `t - 1`.

    The model's noise prediction gives an estimate of the clean latent. For
    `t > 0` the result is a draw from the Gaussian posterior
    q(z_{t-1} | z_t, z0_est); otherwise the clean estimate is returned as is.

    Args:
        z_t: current latents, batch first.
        t: Python integer timestep shared by the whole batch.
    """
    steps = torch.full((z_t.size(0),), t, device=device, dtype=torch.long)
    eps_pred = diffusion_model(z_t, steps)

    alpha_bar_t = alphas_cumprod[t]
    # Invert the forward-process formula to estimate the clean latent.
    z0_est = (z_t - torch.sqrt(1.0 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)

    if t <= 0:
        # Final step: return the estimate without adding noise.
        return z0_est

    beta_t = betas[t]
    alpha_t = alphas[t]
    alpha_bar_prev = alphas_cumprod[t - 1]
    one_minus_alpha_bar_t = 1.0 - alpha_bar_t

    # Posterior mean is a weighted blend of the clean estimate and z_t.
    weight_z0 = torch.sqrt(alpha_bar_prev) * beta_t / one_minus_alpha_bar_t
    weight_zt = torch.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / one_minus_alpha_bar_t
    posterior_mean = weight_z0 * z0_est + weight_zt * z_t
    posterior_var = beta_t * (1.0 - alpha_bar_prev) / one_minus_alpha_bar_t

    return posterior_mean + torch.sqrt(posterior_var) * torch.randn_like(z_t)


@torch.no_grad()
def sample_latent_and_decode(n_samples=16):
    """Generate images by reverse diffusion in latent space, then VAE decoding.

    Puts both models in eval mode, starts from standard-normal latents, applies
    `p_sample_step` for t = T-1 down to 0, and decodes the result.

    Returns:
        CPU tensor of images reshaped to (-1, 1, 28, 28).
    """
    diffusion_model.eval()
    vae.eval()

    latent = torch.randn(n_samples, vae_latent_dim, device=device)
    for step in range(T - 1, -1, -1):
        latent = p_sample_step(latent, step)

    images = vae.decode(latent).view(-1, 1, 28, 28)
    return images.cpu()


In [None]:
# Generate and visualize samples

samples = sample_latent_and_decode(n_samples=16)

# 4 x 4 grid, one generated image per axis.
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i, 0].numpy(), cmap="gray")
    ax.axis("off")
fig.suptitle("VAE Latent Diffusion Samples", y=0.92)
fig.tight_layout()
plt.show()
