
# 🧪 Lab 1 — Generative Models Foundations: **GAN vs VAE** (PyTorch, MNIST)

**Course:** Generative AI (Day 1)  
**Lab Length:** ~3 hours  
**Goal:** Implement a **GAN** and a **VAE** on MNIST or Fashion-MNIST, compare their behavior, and reflect on training stability vs. latent structure.

> ✅ **What you will submit:**  
> - A short reflection with generated image grids and comments.

### 🚦 Rules
- Keep training epochs small if you’re on CPU; you can always re-run with more epochs later.
- Cells marked **(Provided)** can be run as-is; cells marked **(TODO)** require your edits.

### 🧰 Requirements
- PyTorch, Torchvision, Matplotlib, TQDM
- (Optional) SciPy for a simple FID-like metric



## 🎯 Learning Objectives
- Implement a **vanilla GAN**: Generator + Discriminator, adversarial loss, and training loop.
- Implement a **Variational Autoencoder (VAE)**: encoder/decoder, **reparameterization trick**, and **ELBO**.
- Produce **visualizations**: sample grids, reconstructions, and **latent interpolations**.
- Compare GAN vs VAE using a **proxy FID-like** feature distance.
- Reflect on stability, mode collapse, and smoothness of latent space.



## 🧭 Tips for Success
- Use **[-1, 1]** input range for GANs (Tanh output).  
- Start with **small networks**. You can always scale up.  
- Inspect **loss curves**, generated images, and reconstructions **frequently**.
- If your GAN collapses, try a smaller learning rate, label smoothing, or changing where BatchNorm is used (for example, removing it from the discriminator).
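
Label smoothing, mentioned in the last tip, replaces the discriminator's real-image target of 1.0 with a slightly smaller value. The sketch below is a plain-Python illustration (no PyTorch) that assumes the BCE-with-logits loss; the helper name `bce_with_logits` and the 0.9 target are illustrative assumptions, not part of the lab code.

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a single raw logit."""
    # Equals -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

confident_logit = 6.0  # D is very sure the image is real

hard = bce_with_logits(confident_logit, 1.0)    # standard target
smooth = bce_with_logits(confident_logit, 0.9)  # smoothed target (assumed value)

# With a 0.9 target the loss is smallest where sigmoid(logit) == 0.9,
# i.e. logit == log(9), so over-confident logits are penalized.
best_logit = math.log(9.0)
at_best = bce_with_logits(best_logit, 0.9)

print(f"hard target, logit 6:     {hard:.4f}")
print(f"smooth target, logit 6:   {smooth:.4f}")
print(f"smooth target, logit ln9: {at_best:.4f}")
```

If you use the BCE option in this lab, one way to apply the same idea would be to build the real targets with `torch.full_like(real_logits, 0.9)` instead of `torch.ones_like(real_logits)`. The hinge loss has no targets, so this tip applies only to BCE.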


In [None]:

# ===== (Provided) Setup & Installs =====
# If you're in Colab, uncomment the next line to ensure dependencies:
# !pip -q install torch torchvision torchaudio matplotlib tqdm scipy

import math, os, random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
from tqdm import tqdm

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


In [None]:

# ===== (Provided) Helper: visualization =====
def show_grid(tensor, title="", nrow=4, value_range=(-1,1)):
    grid = utils.make_grid(tensor, nrow=nrow, normalize=True, value_range=value_range)
    plt.figure(figsize=(4,4)); plt.axis('off'); plt.title(title)
    plt.imshow(grid.permute(1,2,0)); plt.show()


In [None]:

# ===== (TODO) Data: MNIST or Fashion-MNIST =====
# Hints:
# - Normalize to mean=0.5, std=0.5 to map inputs to [-1, 1] for GAN (Tanh output)
# - Use batch size around 128 if you have a GPU, smaller if on CPU

BATCH = 128  # TODO: adjust if needed
use_fashion = True  # TODO: set to False to use MNIST instead of Fashion-MNIST

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

if use_fashion:
    train_ds = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_ds  = datasets.FashionMNIST(root='./data',  train=False, download=True, transform=transform)
else:
    train_ds = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_ds  = datasets.MNIST(root='./data',  train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)

# Quick sanity-check visualization
xb, yb = next(iter(train_loader))
show_grid(xb[:16], title="Real samples (normalized to [-1,1])")
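
The `Normalize((0.5,), (0.5,))` transform above maps `ToTensor()` pixels from [0, 1] to [-1, 1], which matches the Tanh output range. A minimal plain-Python sketch of that mapping and its inverse follows; `normalize` and `denormalize` are hypothetical helper names used only for illustration.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """What Normalize((0.5,), (0.5,)) does to one pixel value in [0, 1]."""
    return (pixel - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse map, e.g. before plotting or before a loss that expects [0, 1]."""
    return value * std + mean

for p in (0.0, 0.25, 0.5, 1.0):
    n = normalize(p)
    print(f"pixel {p:.2f} -> normalized {n:+.2f} -> back {denormalize(n):.2f}")
```

The `show_grid` helper handles the inverse step for display through `normalize=True` and `value_range=(-1, 1)`.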



---

# Part 1 — **GAN** (Vanilla)  *(~90 minutes)*

We will implement a simple GAN (DCGAN-ish) with:
- **Generator**: maps `z ~ N(0, I)` to image `x̂`
- **Discriminator**: scores real vs fake
- **Loss**: Hinge (recommended) **or** BCE (your choice)

> **Milestones**
> 1) Implement **Generator** & **Discriminator**  
> 2) Choose **loss** (hinge recommended), set **optimizers**  
> 3) Implement **training loop** and generate sample grids every epoch


In [None]:

# ===== (TODO) GAN Architectures =====
# Hints:
# - Use Tanh output for G (inputs are normalized to [-1, 1])
# - Use LeakyReLU in D; consider BatchNorm in G (not always in D)
# - Start small: upsample from (z_dim) -> (128*7*7) -> ConvTranspose to 14x14 -> 28x28
# - Keep IMG_CH = 1 (MNIST and Fashion-MNIST are both grayscale)

Z_DIM  = 64   # TODO: try 32, 128 to see effect
IMG_CH = 1
IMG_H  = 28
IMG_W  = 28

class Generator(nn.Module):
    def __init__(self, z_dim=Z_DIM, img_ch=IMG_CH):
        super().__init__()
        # Suggested skeleton:
        # - Linear(z_dim -> 128*7*7) + BN + ReLU
        # - Unflatten to (128, 7, 7)
        # - ConvTranspose2d(128 -> 64, kernel=4, stride=2, padding=1) + BN + ReLU  # (64, 14, 14)
        # - ConvTranspose2d(64 -> 32, kernel=4, stride=2, padding=1) + BN + ReLU   # (32, 28, 28)
        # - Conv2d(32 -> 1, kernel=3, stride=1, padding=1) -> Tanh
        self.net = nn.Sequential(
            nn.Linear(z_dim, 128*7*7),
            nn.BatchNorm1d(128*7*7),
            nn.ReLU(True),
            nn.Unflatten(1, (128, 7, 7)),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, img_ch, 3, 1, 1),
            nn.Tanh()
        )
    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, img_ch=IMG_CH):
        super().__init__()
        # Suggested skeleton:
        # - Conv2d(1 -> 32, 4, 2, 1) + LeakyReLU
        # - Conv2d(32 -> 64, 4, 2, 1) + BN + LeakyReLU
        # - Conv2d(64 -> 128, 3, 2, 1) + LeakyReLU
        # - Flatten -> Linear(128*4*4 -> 1)
        self.features = nn.Sequential(
            nn.Conv2d(img_ch, 32, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128*4*4, 1)
        )
    def forward(self, x):
        f = self.features(x)
        logits = self.classifier(f).squeeze(1)
        return logits, f

G = Generator().to(device)
D = Discriminator().to(device)

# Quick shape tests
with torch.no_grad():
    z = torch.randn(2, Z_DIM, device=device)
    x_fake = G(z)
    logit, f = D(x_fake)
    assert x_fake.shape == (2, 1, 28, 28), f"Got {x_fake.shape}"
    assert logit.shape[0] == 2, f"Got {logit.shape}"
print("✓ GAN shapes look OK.")
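
The `128*4*4` input size of the discriminator's linear layer and the 7 -> 14 -> 28 upsampling path in the generator follow from the standard output-size formulas for `Conv2d` and `ConvTranspose2d`. The plain-Python sketch below reproduces that arithmetic; `conv_out` and `conv_transpose_out` are hypothetical helper names, and the formulas assume dilation 1 and no output padding.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of Conv2d along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding):
    """Output size of ConvTranspose2d (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel

# Generator: two stride-2 transposed convs, then a size-preserving 3x3 conv
g_sizes = [7]
g_sizes.append(conv_transpose_out(g_sizes[-1], kernel=4, stride=2, padding=1))
g_sizes.append(conv_transpose_out(g_sizes[-1], kernel=4, stride=2, padding=1))
g_sizes.append(conv_out(g_sizes[-1], kernel=3, stride=1, padding=1))

# Discriminator: three stride-2 convs
d_sizes = [28]
d_sizes.append(conv_out(d_sizes[-1], kernel=4, stride=2, padding=1))
d_sizes.append(conv_out(d_sizes[-1], kernel=4, stride=2, padding=1))
d_sizes.append(conv_out(d_sizes[-1], kernel=3, stride=2, padding=1))

flat_features = 128 * d_sizes[-1] * d_sizes[-1]  # input width of the final Linear
print("Generator sizes:    ", g_sizes)
print("Discriminator sizes:", d_sizes)
print("Flattened features: ", flat_features)
```

If you change a kernel size, stride, or padding in either network, rerunning this arithmetic tells you what the `Linear` and `Unflatten` dimensions need to become.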


In [None]:

# ===== (TODO) GAN Losses =====
# Option A: Hinge loss (recommended)
def d_loss_hinge(real_logits, fake_logits):
    # implement hinge: E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))]
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def g_loss_hinge(fake_logits):
    # implement generator hinge: -E[D(fake)]
    return -fake_logits.mean()

# Option B: BCE 
# TODO: To be tested
# bce = nn.BCEWithLogitsLoss()
# def d_loss_bce(real_logits, fake_logits):
#     real_t = torch.ones_like(real_logits)
#     fake_t = torch.zeros_like(fake_logits)
#     return bce(real_logits, real_t) + bce(fake_logits, fake_t)
# def g_loss_bce(fake_logits):
#     real_t = torch.ones_like(fake_logits)
#     return bce(fake_logits, real_t)
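
To see what the hinge losses above reward, it can help to evaluate them by hand on a few logits. The sketch below mirrors `d_loss_hinge` and `g_loss_hinge` in plain Python; the function names and the example logits are made up for illustration.

```python
def relu(v):
    return max(0.0, v)

def mean(values):
    return sum(values) / len(values)

def d_hinge(real_logits, fake_logits):
    """E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))]"""
    real_term = mean([relu(1.0 - r) for r in real_logits])
    fake_term = mean([relu(1.0 + f) for f in fake_logits])
    return real_term + fake_term

def g_hinge(fake_logits):
    """-E[D(fake)]"""
    return -mean(fake_logits)

real_logits = [2.0, 0.5]   # first real is past the +1 margin, second is not
fake_logits = [-1.5, 0.2]  # first fake is past the -1 margin, second is not

loss_d = d_hinge(real_logits, fake_logits)
loss_g = g_hinge(fake_logits)
print(f"D hinge loss: {loss_d:.2f}")
print(f"G hinge loss: {loss_g:.2f}")
```

Only the samples inside the margin (the real logit 0.5 and the fake logit 0.2) contribute to the discriminator loss, so the discriminator stops receiving gradient from examples it already classifies with margin.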


In [None]:

# ===== (TODO) GAN Training =====
# Hints:
# - Alternate D then G updates
# - Sample fresh z for each update
# - Visualize fixed z grid per epoch
LR = 2e-4
betas = (0.5, 0.999)
opt_G = torch.optim.Adam(G.parameters(), lr=LR, betas=betas)
opt_D = torch.optim.Adam(D.parameters(), lr=LR, betas=betas)

EPOCHS_GAN = 5  # Increase if you have GPU time
fixed_z = torch.randn(16, Z_DIM, device=device)

for epoch in range(1, EPOCHS_GAN+1):
    G.train(); D.train()
    pbar = tqdm(train_loader, desc=f"[GAN] Epoch {epoch}/{EPOCHS_GAN}")
    for x, _ in pbar:
        x = x.to(device)

        # (1) Update D
        z = torch.randn(x.size(0), Z_DIM, device=device)
        with torch.no_grad():
            x_fake = G(z)
        real_logits, _ = D(x)
        fake_logits, _ = D(x_fake)
        loss_D = d_loss_hinge(real_logits, fake_logits)  # or d_loss_bce(...)
        opt_D.zero_grad(set_to_none=True)
        loss_D.backward()
        opt_D.step()

        # (2) Update G
        z = torch.randn(x.size(0), Z_DIM, device=device)
        x_fake = G(z)
        fake_logits, _ = D(x_fake)
        loss_G = g_loss_hinge(fake_logits)  # or g_loss_bce(...)
        opt_G.zero_grad(set_to_none=True)
        loss_G.backward()
        opt_G.step()

        pbar.set_postfix({'D': f"{loss_D.item():.3f}", 'G': f"{loss_G.item():.3f}"})

    with torch.no_grad():
        samples = G(fixed_z).cpu()
    show_grid(samples, title=f"GAN samples (epoch {epoch})")



---

# Part 2 — **VAE**  *(~60 minutes)*

We will implement:
- Encoder that outputs **μ** and **log σ²**
- **Reparameterization trick**: `z = μ + σ ⊙ ε`
- Decoder that reconstructs `x̂`
- **Negative ELBO** loss = reconstruction term + KL divergence (minimizing it maximizes the ELBO)

> **Milestones**
> 1) Build **VAE module** (encode/reparameterize/decode)  
> 2) Implement **loss** (reconstruction + KL)  
> 3) Train and visualize reconstructions & random samples  
> 4) Do a **latent interpolation** between two test images
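
The reparameterization trick listed above writes a sample from N(μ, σ²) as a deterministic function of μ, log σ², and an independent noise draw ε ~ N(0, 1), which is what lets gradients reach the encoder. The one-dimensional plain-Python sketch below checks that `μ + σ·ε` has the intended mean and variance; the specific values of `mu`, `logvar`, the seed, and the sample count are arbitrary illustrative choices.

```python
import math
import random

def reparameterize(mu, logvar, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * logvar)."""
    std = math.exp(0.5 * logvar)
    eps = rng.gauss(0.0, 1.0)
    return mu + std * eps

rng = random.Random(0)            # fixed seed so the run is repeatable
mu, logvar = 1.5, math.log(0.25)  # i.e. variance 0.25, std 0.5

samples = [reparameterize(mu, logvar, rng) for _ in range(20000)]
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)

print(f"sample mean {sample_mean:.3f} (target {mu})")
print(f"sample var  {sample_var:.3f} (target {math.exp(logvar):.2f})")
```

Predicting `logvar` rather than σ keeps the network output unconstrained, since `exp(0.5 * logvar)` is positive for any real input.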


In [None]:

# ===== (TODO) VAE Architecture =====
LATENT = 16  # Try 8, 16, 32 to see effects

class VAE(nn.Module):
    def __init__(self, latent=LATENT):
        super().__init__()
        # Encoder
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),   # 14x14
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),  # 7x7
            nn.ReLU(True),
            nn.Flatten()
        )
        self.enc_fc_mu  = nn.Linear(64*7*7, latent)
        self.enc_fc_log = nn.Linear(64*7*7, latent)

        # Decoder
        self.dec_fc = nn.Linear(latent, 64*7*7)
        self.dec = nn.Sequential(
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # 14x14
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1),  # 28x28
            nn.ReLU(True),
            nn.Conv2d(16, 1, 3, 1, 1),
            nn.Tanh()
        )

    def encode(self, x):
        h = self.enc(x)
        mu = self.enc_fc_mu(h)
        logvar = self.enc_fc_log(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = (0.5*logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.dec_fc(z)
        x = self.dec(h)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        xhat = self.decode(z)
        return xhat, mu, logvar

vae = VAE().to(device)

# Quick shape tests
with torch.no_grad():
    x = xb[:2].to(device)
    xhat, mu, logvar = vae(x)
    assert xhat.shape == x.shape, f"{xhat.shape} vs {x.shape}"
    assert mu.shape[-1] == LATENT and logvar.shape[-1] == LATENT
print("✓ VAE shapes look OK.")


In [None]:

# ===== (TODO) VAE Loss & Training =====
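# Hint: loss = -ELBO ≈ recon_loss + KL(q(z|x) || p(z)), with p(z)=N(0,I)
# - Use L1 or BCE for reconstruction (L1 often looks nicer on MNIST)
# - BCE expects values in [0, 1]; xhat and x are in [-1, 1] here, so rescale both with (v + 1) / 2 first
# - KL term: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))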

def vae_loss(xhat, x, mu, logvar):
    recon = F.l1_loss(xhat, x, reduction='sum') / x.size(0)  # try also BCE
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon + kl, recon, kl

opt_vae = torch.optim.Adam(vae.parameters(), lr=2e-3)

EPOCHS_VAE = 5  # Increase if you have GPU time
for epoch in range(1, EPOCHS_VAE+1):
    vae.train()
    losses = []
    pbar = tqdm(train_loader, desc=f"[VAE] Epoch {epoch}/{EPOCHS_VAE}")
    for x, _ in pbar:
        x = x.to(device)
        xhat, mu, logvar = vae(x)
        loss, rec, kl = vae_loss(xhat, x, mu, logvar)
        opt_vae.zero_grad(set_to_none=True)
        loss.backward()
        opt_vae.step()
        losses.append(loss.item())
        pbar.set_postfix({'loss': f"{np.mean(losses):.2f}"})
    # visualize reconstructions
    vae.eval()
    with torch.no_grad():
        x = xb[:16].to(device)
        xhat, _, _ = vae(x)
    show_grid(x.cpu(), title="VAE inputs")
    show_grid(xhat.cpu(), title=f"VAE reconstructions (epoch {epoch})")
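
The KL term in `vae_loss` is the closed-form divergence between a diagonal Gaussian and the standard normal prior. The plain-Python sketch below evaluates the same expression for a single sample (with the minus sign folded into the sum) on a few hand-checkable cases; `kl_to_standard_normal` is a hypothetical name used only here.

```python
import math

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one latent vector."""
    # Same value as -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))

kl_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])           # posterior equals prior
kl_shift = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])           # mean moved by 1 in one dim
kl_wide = kl_to_standard_normal([0.0, 0.0], [math.log(4.0), 0.0])  # variance 4 in one dim

print(f"posterior == prior: {kl_prior:.4f}")
print(f"mean shifted by 1:  {kl_shift:.4f}")
print(f"variance 4:         {kl_wide:.4f}")
```

The term is zero only when the encoder outputs exactly the prior, which is why a heavily weighted KL term pulls reconstructions toward blurrier, less input-specific outputs.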


In [None]:

# ===== (TODO) Sampling & Latent Interpolation =====
vae.eval()
with torch.no_grad():
    z = torch.randn(16, LATENT, device=device)
    samples = vae.decode(z).cpu()
show_grid(samples, title="VAE random samples")

# Latent interpolation between two test images
def interpolate(a, b, steps=8):
    alphas = torch.linspace(0, 1, steps, device=a.device).view(-1,1)
    return (1-alphas)*a + alphas*b

with torch.no_grad():
    x, _ = next(iter(test_loader))
    x = x.to(device)[:2]
    mu, logvar = vae.encode(x)
    z1 = mu[0]; z2 = mu[1]
    z_traj = interpolate(z1, z2, steps=16)
    interp_imgs = vae.decode(z_traj).cpu()
show_grid(interp_imgs, title="VAE latent interpolation")
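
The `interpolate` helper above relies on broadcasting: `alphas` has shape `(steps, 1)` and each latent vector has shape `(LATENT,)`, so the result has shape `(steps, LATENT)`. A plain-Python sketch of the same linear blend, using short made-up vectors so the values can be checked by hand:

```python
def interpolate(a, b, steps=8):
    """Linearly blend two latent vectors (lists of floats) at `steps` points."""
    trajectory = []
    for i in range(steps):
        alpha = i / (steps - 1)  # runs from 0.0 to 1.0 inclusive
        trajectory.append([(1 - alpha) * x + alpha * y for x, y in zip(a, b)])
    return trajectory

z1 = [0.0, 2.0, -1.0]  # stand-in for mu[0]
z2 = [4.0, 0.0, 1.0]   # stand-in for mu[1]

traj = interpolate(z1, z2, steps=5)
for row in traj:
    print(row)
```

The first and last rows reproduce the two endpoints, so the first and last decoded images in the lab grid correspond to the two encoded test images.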



---

# Part 3 — **Comparison & Proxy FID-like Metric**  *(~30 minutes)*

We will compare samples from the GAN and VAE using a **feature Fréchet distance** proxy:
1) Extract features from the **Discriminator** (output of its final conv block, just before the linear classifier)
2) Fit Gaussians to **real** vs **fake** features
3) Compute **Fréchet distance**:  
   $\|\mu_r-\mu_f\|^2 + \mathrm{Tr}(\Sigma_r + \Sigma_f - 2(\Sigma_r \Sigma_f)^{1/2})$




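> This is not the official FID (which uses Inception-v3 features), but it gives a quick relative measure for lab work. Because the features come from the GAN's own discriminator, interpret GAN-vs-VAE differences with some caution.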
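
When both covariance matrices are diagonal, the matrix square root in the Fréchet distance reduces to an elementwise square root, and the formula can be checked by hand. The sketch below covers only that simplified case, in plain Python; `frechet_distance_diag` is a hypothetical helper and is not the function used in the lab cell, which handles full covariances.

```python
import math

def frechet_distance_diag(mu1, var1, mu2, var2):
    """Fréchet distance between two Gaussians with diagonal covariances."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    # For diagonal matrices, (Sigma1 Sigma2)^(1/2) is elementwise sqrt(v1 * v2)
    cov_term = sum(v1 + v2 - 2.0 * math.sqrt(v1 * v2) for v1, v2 in zip(var1, var2))
    return mean_term + cov_term

same = frechet_distance_diag([0.0, 0.0], [1.0, 4.0], [0.0, 0.0], [1.0, 4.0])
shifted = frechet_distance_diag([0.0, 0.0], [1.0, 1.0], [3.0, 4.0], [1.0, 1.0])
rescaled = frechet_distance_diag([0.0], [1.0], [0.0], [4.0])

print(f"identical Gaussians:   {same:.2f}")
print(f"means differ by (3,4): {shifted:.2f}")
print(f"std 1 vs std 2:        {rescaled:.2f}")
```

In one dimension the covariance term equals `(σ1 - σ2)²`, so the distance grows both when the feature means drift apart and when the spread of generated features differs from that of real features.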


In [None]:

# ===== (Optional TODO) Proxy FID-like Metric =====
# Requires scipy for sqrtm
try:
    from scipy import linalg
    SCIPY_OK = True
except Exception:
    SCIPY_OK = False
    print("SciPy not available — skipping proxy FID. You can !pip install scipy and re-run.")

def get_features(disc, loader, n_batches=50, use_fake=False, generator=None):
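    disc.eval()
    if generator is not None:
        generator.eval()  # sample with BatchNorm running statistics, not per-batch statistics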
    feats = []
    with torch.no_grad():
        for i, (x, _) in enumerate(loader):
            if i >= n_batches: break
            x = x.to(device)
            if use_fake:
                z = torch.randn(x.size(0), Z_DIM, device=device)
                x = generator(z)
            _, f = disc(x)
            f = F.adaptive_avg_pool2d(f, 1).flatten(1)  # (B, C)
            feats.append(f.cpu())
    return torch.cat(feats, dim=0).numpy()

def gaussian_stats(X):
    mu = X.mean(axis=0)
    sigma = np.cov(X, rowvar=False)
    return mu, sigma

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1.dot(sigma2))  # called without disp, so only the matrix is returned
    if not np.isfinite(covmean).all():
        covmean = linalg.sqrtm((sigma1 + np.eye(sigma1.shape[0])*eps).dot(sigma2 + np.eye(sigma2.shape[0])*eps))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2*covmean)
    return float(fid)

if SCIPY_OK:
    print("Computing real features...")
    real_feats = get_features(D, test_loader, n_batches=80, use_fake=False)
    print("Computing GAN fake features...")
    fake_feats_gan = get_features(D, test_loader, n_batches=80, use_fake=True, generator=G)
    mu_r, sig_r = gaussian_stats(real_feats)
    mu_g, sig_g = gaussian_stats(fake_feats_gan)
    fid_gan = frechet_distance(mu_r, sig_r, mu_g, sig_g)
    print(f"Proxy FID (GAN vs real): {fid_gan:.2f}")

    # VAE samples
    vae.eval()
    all_vae = []
    with torch.no_grad():
        for _ in range(80):
            z = torch.randn(BATCH, LATENT, device=device)
            all_vae.append(vae.decode(z).cpu())
    all_vae = torch.cat(all_vae, dim=0)[:len(real_feats)]

    fake_feats_vae = []
    print("Computing VAE fake features...")
    with torch.no_grad():
        for i in range(0, len(all_vae), BATCH):
            batch = all_vae[i:i+BATCH].to(device)
            _, f = D(batch)
            f = F.adaptive_avg_pool2d(f, 1).flatten(1)
            fake_feats_vae.append(f.cpu())
    fake_feats_vae = torch.cat(fake_feats_vae, dim=0).numpy()
    mu_v, sig_v = gaussian_stats(fake_feats_vae)
    fid_vae = frechet_distance(mu_r, sig_r, mu_v, sig_v)
    print(f"Proxy FID (VAE vs real): {fid_vae:.2f}")



---

## 📝 Final Reflection (to submit)

1. Copy all the generated outputs and label each one (e.g., Fashion-MNIST, GAN, Z_DIM=128, EPOCH=...).

2. Include image grids:
   - GAN samples (best epoch)
   - VAE reconstructions
   - VAE latent interpolation

3. Include your **proxy FID-like** numbers for GAN and VAE.

4. Answer briefly:
   - What hyperparameters most influenced **GAN stability** in your runs?
   - Evidence of **mode collapse** (if any)? What helped?
   - How did **latent dim** affect VAE reconstructions and samples?
   - One idea to combine benefits of both models (e.g., **VAE-GAN**).
