<a href="https://colab.research.google.com/github/cs-iuu/ocr-2025-fall-cv/blob/main/notebooks/14.1.generative_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VAE

## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Dataset: the MNIST training split, converted by ToTensor() to float
# tensors and served in shuffled mini-batches of 128 images.
transform = transforms.ToTensor()
trainset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(dataset=trainset, batch_size=128, shuffle=True)

# VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened 28x28 images.

    Layer sizes: a 784 -> 400 (ReLU) encoder trunk, two linear heads giving
    the 20-d posterior mean and log-variance, and a
    20 -> 400 (ReLU) -> 784 (Sigmoid) decoder.
    """

    def __init__(self):
        super().__init__()
        # Shared encoder trunk; Flatten collapses every dim after the batch dim.
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 400), nn.ReLU())
        # Separate heads for the diagonal Gaussian q(z|x): mean and log-variance.
        self.mu = nn.Linear(400, 20)
        self.logvar = nn.Linear(400, 20)
        # Decoder maps a latent code back to 784 values in (0, 1).
        self.decoder = nn.Sequential(
            nn.Linear(20, 400),
            nn.ReLU(),
            nn.Linear(400, 784),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, exp(logvar)) as mu + eps * sigma with eps ~ N(0, I)."""
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent, decode it; return (x_hat, mu, logvar)."""
        features = self.encoder(x)
        mean, log_variance = self.mu(features), self.logvar(features)
        latent = self.reparameterize(mean, log_variance)
        return self.decoder(latent), mean, log_variance

# Build the VAE and an Adam optimizer (learning rate 1e-3) over its parameters.
vae = VAE()
optimizer = optim.Adam(params=vae.parameters(), lr=1e-3)

def vae_loss(x, x_hat, mu, logvar):
    """Negative ELBO for a batch, summed (not averaged) over all elements.

    Args:
        x: input images; reshaped with ``view(-1, 784)`` to match ``x_hat``.
        x_hat: decoder output in (0, 1), one row of 784 values per image.
        mu, logvar: mean and log-variance of the diagonal Gaussian q(z|x).

    Returns:
        Scalar tensor: summed binary cross-entropy reconstruction term plus
        the closed-form KL divergence KL(q(z|x) || N(0, I)).
    """
    target = x.view(-1, 784)
    reconstruction = nn.functional.binary_cross_entropy(x_hat, target, reduction="sum")
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence

# Training
vae.train()
# vae_loss is sum-reduced over the batch, so normalise by the number of
# samples (not batches); the final batch may be smaller than batch_size.
num_samples = len(trainloader.dataset)
for epoch in range(5):
    total_loss = 0.0
    for x, _ in trainloader:
        x_hat, mu, logvar = vae(x)
        loss = vae_loss(x, x_hat, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss per sample: {total_loss / num_samples:.4f}")


## Generation

In [None]:
import matplotlib.pyplot as plt

# Decode 16 latent codes drawn from the N(0, I) prior and show them as a 4x4 grid.
vae.eval()
with torch.no_grad():
    z = torch.randn(16, 20)
    samples = vae.decoder(z).view(-1, 1, 28, 28)

# Tile (16, 1, 28, 28) -> (4, 4, 28, 28) -> (row, h, col, w) -> (112, 112)
# instead of concatenating into a 448x28 vertical strip.
grid = samples.view(4, 4, 28, 28).permute(0, 2, 1, 3).reshape(4 * 28, 4 * 28)
plt.imshow(grid, cmap="gray")
plt.title("VAE-generated samples")
plt.axis("off")
plt.show()


# GAN (fully connected, vanilla GAN on MNIST)

In [None]:
import torch.nn.functional as F

# Generator
class Generator(nn.Module):
    """MLP generator mapping a 100-d noise vector to 784 output values.

    The final Tanh bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise ``z`` (last dim 100) to generated rows of 784 values."""
        return self.net(z)

# Discriminator
class Discriminator(nn.Module):
    """MLP discriminator scoring rows of 784 values.

    The final Sigmoid turns the single output unit into a probability
    that the input is real.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x``; returns one value in (0, 1) per input row."""
        probabilities = self.net(x)
        return probabilities

G = Generator()
D = Discriminator()
opt_G = optim.Adam(G.parameters(), lr=2e-4)
opt_D = optim.Adam(D.parameters(), lr=2e-4)

criterion = nn.BCELoss()

# Training loop
G.train()
D.train()
for epoch in range(5):
    for x, _ in trainloader:
        bs = x.size(0)
        # ToTensor() yields pixels in [0, 1], but G ends in Tanh ([-1, 1]).
        # Rescale real images to [-1, 1] so D cannot separate real from fake
        # just by the presence of negative pixel values.
        x = x.view(bs, -1) * 2 - 1

        real_labels = torch.ones(bs, 1)
        fake_labels = torch.zeros(bs, 1)

        # --- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ---
        outs_real = D(x)
        loss_real = criterion(outs_real, real_labels)

        z = torch.randn(bs, 100)
        fake = G(z)
        # detach(): the discriminator update must not backpropagate into G.
        outs_fake = D(fake.detach())
        loss_fake = criterion(outs_fake, fake_labels)

        loss_D = loss_real + loss_fake
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # --- Generator step (non-saturating loss): push D(G(z)) -> 1 ---
        outs_fake = D(fake)
        loss_G = criterion(outs_fake, real_labels)

        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

    # Losses of the last mini-batch in the epoch.
    print(f"Epoch {epoch+1}: D_loss={loss_D.item():.4f}, G_loss={loss_G.item():.4f}")


## Generation

In [None]:
# Sample 16 images from G and show them as a 4x4 grid.
G.eval()
with torch.no_grad():
    z = torch.randn(16, 100)
    samples = G(z).view(-1, 1, 28, 28)

# G outputs lie in [-1, 1] (Tanh); map back to [0, 1] for display, then tile
# (16, 1, 28, 28) -> (4, 4, 28, 28) -> (row, h, col, w) -> (112, 112).
samples = (samples + 1) / 2
grid = samples.view(4, 4, 28, 28).permute(0, 2, 1, 3).reshape(4 * 28, 4 * 28)
plt.imshow(grid, cmap="gray", vmin=0.0, vmax=1.0)
plt.title("GAN-generated samples")
plt.axis("off")
plt.show()


# Diffusion Model (DDPM) on MNIST

In [None]:
import numpy as np

T = 300
# Linear beta schedule; alpha_hat[t] = prod_{s <= t} (1 - beta[s]).
beta = np.linspace(1e-4, 0.02, T)
alpha = 1 - beta
alpha_hat = np.cumprod(alpha)


def q_sample(x0, t, noise):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_hat[t]) * x0 + sqrt(1 - alpha_hat[t]) * noise

    Args:
        x0: clean inputs with the batch on dim 0, e.g. (B, C, H, W).
        t: integer timesteps in [0, T). Pass a tensor of shape (B,) for one
           step per sample, or an int / 0-d tensor for the whole batch.
        noise: Gaussian noise with the same shape as ``x0``.

    Returns:
        Noised inputs with the same shape, dtype and device as ``x0``.
    """
    t = torch.as_tensor(t, dtype=torch.long).cpu()
    a_hat = torch.as_tensor(alpha_hat[t.numpy()], dtype=x0.dtype, device=x0.device)
    if a_hat.dim() > 0:
        # (B,) -> (B, 1, ..., 1) so each sample's coefficient broadcasts over
        # its own C, H, W rather than over the last image axis.
        a_hat = a_hat.reshape(-1, *([1] * (x0.dim() - 1)))
    return a_hat.sqrt() * x0 + (1 - a_hat).sqrt() * noise


In [None]:
class UNetSmall(nn.Module):
    """Tiny convolutional noise predictor eps_theta(x_t, t).

    Despite the name, this is a plain 3-layer CNN (no down/up-sampling or skip
    connections). The timestep is fed in as an extra input channel filled with
    ``t / num_steps``, so the prediction can depend on the noise level.

    Args:
        num_steps: total number of diffusion steps used to normalise ``t``;
            should match the schedule length T used for training/sampling.
    """

    def __init__(self, num_steps=300):
        super().__init__()
        self.num_steps = num_steps
        self.net = nn.Sequential(
            # 2 input channels: the noisy image plus the timestep channel.
            nn.Conv2d(2, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 1, 3, padding=1)
        )

    def forward(self, x, t):
        """Predict the noise contained in ``x``.

        Args:
            x: noisy images, shape (B, 1, H, W).
            t: integer timesteps, shape (B,), or (1,) / scalar to share one
               step across the batch.

        Returns:
            Predicted noise, shape (B, 1, H, W).
        """
        b, _, h, w = x.shape
        t_frac = torch.as_tensor(t, device=x.device, dtype=x.dtype) / self.num_steps
        t_channel = t_frac.reshape(-1, 1, 1, 1).expand(b, 1, h, w)
        return self.net(torch.cat([x, t_channel], dim=1))

# Noise-prediction network and its Adam optimizer (learning rate 1e-3).
model = UNetSmall()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


## Training

In [None]:
def q_sample(x0, t, noise):
    """Forward-diffuse clean images ``x0`` to timesteps ``t`` in one shot.

    Returns sqrt(alpha_hat[t]) * x0 + sqrt(1 - alpha_hat[t]) * noise, where
    ``t`` indexes the module-level NumPy array ``alpha_hat`` with one timestep
    per image. The coefficients are cast to float32, moved to ``x0``'s device
    and reshaped to (batch, 1, 1, 1) so they broadcast over (C, H, W).
    """
    a_hat = torch.tensor(alpha_hat[t], dtype=torch.float32)
    coef_signal = torch.sqrt(a_hat).to(x0.device).view(-1, 1, 1, 1)
    coef_noise = torch.sqrt(1 - a_hat).to(x0.device).view(-1, 1, 1, 1)
    return coef_signal * x0 + coef_noise * noise

model.train()
for epoch in range(5):
    epoch_loss, num_batches = 0.0, 0
    for x, _ in trainloader:
        x = x.to(torch.float32)

        # One random timestep per image; the model learns to predict the
        # Gaussian noise that q_sample mixed in at that step.
        t = torch.randint(0, T, (x.size(0),))
        noise = torch.randn_like(x)

        x_noisy = q_sample(x, t, noise)
        noise_pred = model(x_noisy, t)

        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than the noisy last mini-batch.
    print(f"Epoch {epoch+1}, loss={epoch_loss / max(num_batches, 1):.4f}")

## Sampling (Reverse Diffusion)

In [None]:
# Ancestral DDPM sampling: start from x_T ~ N(0, I) and repeatedly apply
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps_theta(x_t, t))
#             + sqrt(beta_t) * z,   with z ~ N(0, I) for t > 0 and z = 0 at t = 0.
model.eval()
with torch.no_grad():  # inference only: avoid building a T-step autograd graph
    x = torch.randn(1, 1, 28, 28)

    for t in reversed(range(T)):
        noise_pred = model(x, torch.tensor([t]))
        beta_t = beta[t]
        alpha_t = alpha[t]
        alpha_hat_t = alpha_hat[t]

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = 0

        x = (1/np.sqrt(alpha_t)) * (x - (beta_t/np.sqrt(1 - alpha_hat_t)) * noise_pred) + np.sqrt(beta_t) * noise

plt.imshow(x.squeeze().detach(), cmap="gray")
plt.show()
