
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](
https://colab.research.google.com/github/mhuertascompany/euclid-school-2025/blob/main/Y2/notebooks/MNIST_VAE.ipynb)

# A Minimal, Explicit Variational Autoencoder (VAE)

## Rodolphe Cledassou School 2025

> Marc Huertas-Company & Hubert Bretonnière 

**Goal:** teach the VAE objective and implementation with clear, *separable* loss terms (reconstruction and KL).  
We’ll train a simple VAE on MNIST with either a Bernoulli (BCE) or Gaussian (MSE) decoder and inspect each term.

**You will see:**
- The ELBO written explicitly and how each term maps to code
- The reparameterization trick
- A clean training loop logging `recon_loss`, `kl_loss`, and `elbo`
- Sampling from the prior and visualizing reconstructions / latent space

---
**ELBO (per-sample):**
$$\begin{equation}
\mathcal{L}(\theta,\phi;x)
= \mathbb{E}_{q_\phi(z\mid x)}\big[\log p_\theta(x\mid z)\big]
- \mathrm{KL}\!\big(q_\phi(z\mid x)\,\|\,p(z)\big).
\end{equation}$$

**VAE loss (minimize negative ELBO):**
$$\begin{equation}
\mathcal{L}_\text{VAE}(x)
= -\,\mathbb{E}_{q_\phi}\log p_\theta(x\mid z)
+ \mathrm{KL}\!\big(q_\phi(z\mid x)\,\|\,p(z)\big).
\end{equation}$$


In [None]:
# --- EUCLID SCHOOL: LIGHT BOOTSTRAP (no data) -------------------------------
# Detect Colab, (optionally) install minimal deps, (optionally) clone the repo,
# and print device info. It does NOT download any dataset.
# ----------------------------------------------------------------------------
import os, sys, subprocess
# Ask PyTorch to fall back to CPU for ops that are not implemented on Apple MPS.
# This is set before `import torch` runs later in this cell.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# --- Colab detection: IN_COLAB is True only if `google.colab` can be imported
try:
    import google.colab  # type: ignore
except Exception:
    IN_COLAB = False
else:
    IN_COLAB = True

# --- Settings (edit if needed)
INSTALL_DEPS = True                # set False if you want to skip pip installs on Colab
PIP_PKGS = [
    # Keep small. Colab already has torch + CUDA.
    "datasets==4.*",
    "transformers==4.*",
    "timm==1.*",
    "albumentations==2.*",
    "lightning==2.*",
    "pytorch-lightning==2.*",
    "einops==0.*",
    "pyarrow",
    "seaborn",
    "umap-learn",
    "nflows",
    "tensorboard",
    "tqdm",
    "safetensors",
    "opencv-python",
]

# Optional repo checkout, for notebooks that rely on repo-relative paths.
CLONE_REPO = False                 # set True only if needed
REPO_URL   = "https://github.com/mhuertascompany/euclid-school-2025.git"
REPO_DIR   = "/content/euclid-school-2025"
SUBDIR     = None                  # e.g., "Y1/notebooks" or "Y2/xyz"

def pip_install(pkgs):
    """Install or upgrade `pkgs` with pip, using the interpreter running this notebook.

    Does nothing when `pkgs` is empty or None. Because pip is run with
    `check=True`, a non-zero pip exit status raises
    `subprocess.CalledProcessError`.
    """
    if pkgs:
        # `python -m pip` targets the same environment as the running kernel.
        pip_cmd = [sys.executable, "-m", "pip", "install", "-q", "--upgrade", *pkgs]
        subprocess.run(pip_cmd, check=True)

if not IN_COLAB:
    print("Not running on Colab (no action).")
else:
    print("Running on Google Colab ✓")

    # Optional dependency install
    if INSTALL_DEPS:
        print("Installing minimal pip deps…")
        pip_install(PIP_PKGS)

    # Optional repo checkout (skipped if the directory already exists)
    if CLONE_REPO:
        _repo_present = os.path.isdir(REPO_DIR)
        if not _repo_present:
            print(f"Cloning {REPO_URL} …")
            subprocess.run(["git", "clone", "-q", REPO_URL, REPO_DIR], check=True)
        if SUBDIR:
            _workdir = os.path.join(REPO_DIR, SUBDIR)
            os.chdir(_workdir)
            print("Working directory:", os.getcwd())

    # Device info: best effort, any failure to launch nvidia-smi is ignored
    try:
        subprocess.run(["nvidia-smi"], check=False)
    except Exception:
        pass
# ----------------------------------------------------------------------------
import torch
print("\nPyTorch:", torch.__version__)
# Backend priority: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print("Using device:", device)

Not running on Colab (no action).

PyTorch: 2.5.1
Using device: mps


In [None]:

# !pip install torch torchvision  # uncomment if needed

import math, os, time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# Backend priority: Apple MPS, then CUDA, then CPU.
device = (
    torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cuda") if torch.cuda.is_available()
    else torch.device("cpu")
)
print("Using device:", device)

# Reproducibility: weight init, DataLoader shuffling and the reparameterization
# noise all draw from PyTorch's RNG, so seed it once before anything is built.
seed = 0
torch.manual_seed(seed)

# --- Hyperparameters
latent_dim = 2                   # size of z; the latent scatter plot needs 2
hidden_dim = 512                 # width of the encoder/decoder hidden layers
batch_size = 128
epochs = 5
lr = 1e-3                        # Adam learning rate
decoder_likelihood = "gaussian"  # 'bernoulli' or 'gaussian'
fixed_gaussian_sigma = 0.1       # fixed decoder std, used by the Gaussian reconstruction term
beta = 1.0                       # weight on the KL term (1.0 gives the plain ELBO)

# Image grids and plots are written here.
save_dir = "vae_outputs"
os.makedirs(save_dir, exist_ok=True)


Device: cpu


In [2]:

# ToTensor converts each image to a float tensor scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Page-locked (pinned) host memory only speeds up host-to-GPU copies on CUDA;
# requesting it on MPS/CPU has no benefit and makes PyTorch emit a warning.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_loader  = DataLoader(test_ds, shuffle=False, **loader_kwargs)
print("Train size:", len(train_ds), " Test size:", len(test_ds))


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:05<00:00, 1.82MB/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 250kB/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 1.75MB/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 1.27MB/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Train size: 60000  Test size: 10000






## Model
MLP encoder/decoder:
- Encoder: $x \mapsto (\mu_\phi(x), \log\sigma^2_\phi(x))$
- Reparameterization: $z=\mu+\sigma\odot\varepsilon,\ \varepsilon\sim\mathcal N(0,I)$
- Decoder: $z \mapsto \hat x$ (logits for Bernoulli; mean for Gaussian)


In [3]:

class Encoder(nn.Module):
    """MLP encoder producing the parameters of the diagonal Gaussian q(z|x).

    The input is flattened and fed to a linear layer expecting 28*28 features,
    followed by a second hidden layer. Two separate linear heads then output
    the mean and the log-variance, each of size `latent_dim`.
    """

    def __init__(self, latent_dim=2, hidden_dim=512):
        super().__init__()
        n_pixels = 28 * 28
        trunk_layers = [
            nn.Flatten(),
            nn.Linear(n_pixels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk_layers)
        # Separate heads for the two posterior parameters.
        self.to_mu = nn.Linear(hidden_dim, latent_dim)
        self.to_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return `(mu, logvar)` for the batch `x`."""
        features = self.net(x)
        return self.to_mu(features), self.to_logvar(features)

class Decoder(nn.Module):
    """MLP decoder mapping a latent code to a 1x28x28 image-shaped tensor.

    The output is left raw (no final activation). `likelihood` is only stored
    on the module; `forward` does not read it, so interpreting the output as
    Bernoulli logits or a Gaussian mean is up to the caller.
    """

    def __init__(self, latent_dim=2, hidden_dim=512, likelihood="bernoulli"):
        super().__init__()
        self.likelihood = likelihood
        hidden_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*hidden_layers)
        self.to_out = nn.Linear(hidden_dim, 28 * 28)

    def forward(self, z):
        """Decode `z` and reshape the 784 outputs to `(-1, 1, 28, 28)`."""
        pixels = self.to_out(self.net(z))
        return pixels.view(-1, 1, 28, 28)

class VAE(nn.Module):
    """Variational autoencoder combining an `Encoder` and a `Decoder`.

    Attributes:
        enc: encoder returning `(mu, logvar)`.
        dec: decoder returning raw image-shaped output.
        likelihood: decoder likelihood name, stored for reference.
    """

    def __init__(self, latent_dim=2, hidden_dim=512, likelihood="bernoulli"):
        super().__init__()
        self.enc = Encoder(latent_dim, hidden_dim)
        self.dec = Decoder(latent_dim, hidden_dim, likelihood)
        self.likelihood = likelihood

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I).

        Drawing the noise separately keeps the sample differentiable with
        respect to `mu` and `logvar`.
        """
        sigma = torch.exp(0.5 * logvar)  # log-variance -> standard deviation
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Encode, sample and decode; returns `(decoder_output, mu, logvar, z)`."""
        mu, logvar = self.enc(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.dec(z)
        return decoded, mu, logvar, z



## Loss terms (explicit)

**KL (diag Gaussian $q$ vs $\mathcal{N}(0,I)$)**
$$\mathrm{KL} = -\tfrac{1}{2}\sum_j \big(1+\log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big).$$

**Reconstruction**
- Bernoulli: `BCEWithLogitsLoss` gives $-\log p_\theta(x\mid z)$ summed over pixels.
- Gaussian (fixed $\sigma^2$): use $\frac{1}{2\sigma^2}\|x-\mu_\theta(z)\|^2$ (dropping constants).


In [5]:

def kl_divergence_diag_gaussian(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over dimension 1.

    Returns one value per row of the inputs.
    """
    per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_dim.sum(dim=1)

# Binary cross-entropy on raw logits, summed (not averaged) over every element.
bce = nn.BCEWithLogitsLoss(reduction='sum')

def recon_loss(x_logits_or_mean, x, likelihood="bernoulli", sigma=0.1):
    """Reconstruction term of the VAE loss, summed over the whole batch.

    Args:
        x_logits_or_mean: decoder output, read as logits for "bernoulli" and
            as the mean for "gaussian".
        x: target tensor with the same shape as the decoder output.
        likelihood: "bernoulli" or "gaussian".
        sigma: fixed standard deviation of the Gaussian decoder; ignored for
            "bernoulli".

    Returns:
        Scalar tensor: the summed binary cross-entropy, or the summed squared
        error scaled by 1 / (2 * sigma**2) (additive constants dropped).

    Raises:
        ValueError: if `likelihood` is not one of the two supported names.
    """
    if likelihood == "bernoulli":
        return bce(x_logits_or_mean, x)
    if likelihood == "gaussian":
        mse = F.mse_loss(x_logits_or_mean, x, reduction='sum')
        return (1.0 / (2 * sigma * sigma)) * mse
    # Name the offending value so a typo in the config is easy to spot.
    raise ValueError(
        f"Unknown likelihood {likelihood!r}; expected 'bernoulli' or 'gaussian'."
    )

def elbo_loss(x, x_logits_or_mean, mu, logvar, likelihood="bernoulli", beta=1.0, sigma=0.1):
    """Compute the separate loss terms and the (beta-weighted) ELBO for a batch.

    Returns `(recon, kl, elbo)`, each summed over the batch, where
    `elbo = -recon - beta * kl`. Minimize `-elbo` to train.
    """
    recon_term = recon_loss(x_logits_or_mean, x, likelihood, sigma)
    kl_per_sample = kl_divergence_diag_gaussian(mu, logvar)
    kl_term = kl_per_sample.sum()
    return recon_term, kl_term, -recon_term - beta * kl_term


In [6]:

# Build the model from the hyperparameters defined earlier and move it to `device`.
vae = VAE(latent_dim=latent_dim, hidden_dim=hidden_dim, likelihood=decoder_likelihood).to(device)
# A single Adam optimizer over all model parameters.
opt = torch.optim.Adam(vae.parameters(), lr=lr)

def train_one_epoch(epoch):
    """Run one optimization pass over `train_loader` and print per-sample averages.

    `epoch` is only used to label the printed line. The loss terms are summed
    over each batch, so the totals are divided by the dataset size.
    """
    vae.train()
    running = {"recon": 0.0, "kl": 0.0, "elbo": 0.0}
    for images, _labels in train_loader:
        images = images.to(device)
        decoded, mu, logvar, _z = vae(images)
        rl, kl, elbo = elbo_loss(
            images, decoded, mu, logvar,
            likelihood=decoder_likelihood, beta=beta, sigma=fixed_gaussian_sigma,
        )
        # Maximizing the ELBO == minimizing its negative.
        opt.zero_grad()
        (-elbo).backward()
        opt.step()
        running["recon"] += rl.item()
        running["kl"] += kl.item()
        running["elbo"] += elbo.item()
    n_samples = len(train_loader.dataset)
    avg_recon = running["recon"] / n_samples
    avg_kl = running["kl"] / n_samples
    avg_neg_elbo = -running["elbo"] / n_samples
    print(f"[Epoch {epoch}] recon={avg_recon:.3f}  kl={avg_kl:.3f}  -elbo={avg_neg_elbo:.3f}")

def _decoder_output_to_image(decoder_out):
    """Map raw decoder output to pixel intensities in [0, 1] for saving."""
    if decoder_likelihood == "bernoulli":
        return torch.sigmoid(decoder_out)  # logits -> probabilities
    return decoder_out.clamp(0, 1)         # other likelihoods: clip the raw output


@torch.no_grad()
def evaluate(split="test", save=True):
    """Report per-sample loss terms on one split and optionally save image grids.

    Args:
        split: "test" selects `test_loader`; any other value selects
            `train_loader`.
        save: when True, write `true_grid.png`, `reco_grid.png` and
            `samples_grid.png` into `save_dir`, overwriting earlier files.

    The forward pass samples z, so the printed numbers and the
    reconstructions vary slightly from call to call.
    """
    vae.eval()
    loader = test_loader if split=="test" else train_loader
    total_rl, total_kl, total_elbo = 0.0, 0.0, 0.0
    for x, _ in loader:
        x = x.to(device)
        x_logits_or_mean, mu, logvar, z = vae(x)
        rl, kl, elbo = elbo_loss(x, x_logits_or_mean, mu, logvar,
                                 likelihood=decoder_likelihood, beta=beta, sigma=fixed_gaussian_sigma)
        total_rl += rl.item(); total_kl += kl.item(); total_elbo += elbo.item()
    # Loss terms are batch sums, so divide by the dataset size for per-sample values.
    n = len(loader.dataset)
    print(f"[{split.upper()}] recon={total_rl/n:.3f}  kl={total_kl/n:.3f}  -elbo={-total_elbo/n:.3f}")

    if save:
        n_show, n_per_row = 64, 8

        # Reconstructions of the first images of one batch.
        x, _ = next(iter(loader))
        x = x[:n_show].to(device)
        x_logits_or_mean, _, _, _ = vae(x)
        x_recon = _decoder_output_to_image(x_logits_or_mean)
        grid_true = utils.make_grid(x.cpu(), nrow=n_per_row)
        grid_reco = utils.make_grid(x_recon.cpu(), nrow=n_per_row)
        utils.save_image(grid_true, os.path.join(save_dir, "true_grid.png"))
        utils.save_image(grid_reco, os.path.join(save_dir, "reco_grid.png"))

        # Unconditional samples: decode z drawn from the prior N(0, I).
        z = torch.randn(n_show, latent_dim, device=device)
        x_sample = _decoder_output_to_image(vae.dec(z))
        grid_samp = utils.make_grid(x_sample.cpu(), nrow=n_per_row)
        utils.save_image(grid_samp, os.path.join(save_dir, "samples_grid.png"))
        print("Saved true_grid.png, reco_grid.png, samples_grid.png")


In [7]:

# Train for `epochs` passes; after each pass, report test-set metrics.
# evaluate() saves by default, so the image grids are overwritten every epoch.
for epoch in range(1, epochs+1):
    train_one_epoch(epoch)
    evaluate("test")


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[Epoch 1] recon=1982.295  kl=8.637  -elbo=1990.932


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[TEST] recon=1733.149  kl=9.423  -elbo=1742.572


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


Saved true_grid.png, reco_grid.png, samples_grid.png


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[Epoch 2] recon=1669.392  kl=9.768  -elbo=1679.160


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[TEST] recon=1611.246  kl=10.680  -elbo=1621.926


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


Saved true_grid.png, reco_grid.png, samples_grid.png


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[Epoch 3] recon=1583.833  kl=10.018  -elbo=1593.852


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[TEST] recon=1561.069  kl=10.411  -elbo=1571.479


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


Saved true_grid.png, reco_grid.png, samples_grid.png


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[Epoch 4] recon=1541.128  kl=10.032  -elbo=1551.161


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[TEST] recon=1523.125  kl=10.364  -elbo=1533.489


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


Saved true_grid.png, reco_grid.png, samples_grid.png


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[Epoch 5] recon=1509.375  kl=10.080  -elbo=1519.454


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


[TEST] recon=1500.245  kl=10.070  -elbo=1510.314


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


Saved true_grid.png, reco_grid.png, samples_grid.png



## Latent space plot (for $d_z=2$)


In [8]:

import matplotlib.pyplot as plt

@torch.no_grad()
def plot_latent_space(split="test", max_points=5000):
    """Scatter the encoder means of one split, colored by digit label, and save it.

    Only the first two latent coordinates are plotted, so the figure is only
    meaningful for a 2-D latent space.

    Args:
        split: "test" selects `test_loader`; any other value selects
            `train_loader`.
        max_points: upper bound on the number of plotted points.

    The figure is written to `latent_scatter.png` in `save_dir` and then
    closed, so it is not displayed inline.
    """
    vae.eval()
    loader = test_loader if split=="test" else train_loader
    mus, ys, count = [], [], 0
    for x, y in loader:
        mu, _ = vae.enc(x.to(device))  # the log-variance is not needed here
        mus.append(mu.cpu()); ys.append(y)
        count += x.size(0)
        if count >= max_points: break
    # Whole batches are collected, so trim the overshoot to honor max_points.
    limit = max(max_points, 0)
    mu_all = torch.cat(mus, dim=0)[:limit].numpy()
    y_all = torch.cat(ys, dim=0)[:limit].numpy()

    fig, ax = plt.subplots(figsize=(5, 5))
    sc = ax.scatter(mu_all[:, 0], mu_all[:, 1], c=y_all, s=5, cmap="tab10")
    fig.colorbar(sc, ax=ax, ticks=range(10))
    ax.set(title="Latent means μ(x)", xlabel="z1", ylabel="z2")
    out = os.path.join(save_dir, "latent_scatter.png")
    fig.tight_layout(); fig.savefig(out, dpi=150); plt.close(fig)
    print("Saved", out)

# Plot the test split with the default point budget; the figure is saved to disk.
plot_latent_space()


  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(
  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /Users/marchuertascompany/soft/miniforge3/envs/spender/lib/python3.10/site-packages/torchvision/image.so
  warn(


Saved vae_outputs/latent_scatter.png
