# VAE Video Compression (Jupyter Project)
Train a Variational Autoencoder (VAE) for video compression and compare against H.264/HEVC.

**What this notebook does:**
1. Installs/validates dependencies
2. Loads your real dataset from `data/train` and `data/val`
3. Defines a 3D-Conv VAE (temporal model)
4. Trains multiple RD points (λ sweep)
5. Evaluates VAE rate (KL→bpp) and distortion (PSNR/SSIM)
6. Encodes baselines with ffmpeg (x264/x265), evaluates true bits
7. Plots RD curves and computes category-wise BD-Rate

## 0) Optional: Install/verify dependencies

In [1]:
# If needed, uncomment to install/upgrade.
# !python -m pip install --upgrade pip
# CPU PyTorch (macOS typical)
# !pip install torch torchvision torchaudio
# Core deps
# !pip install numpy opencv-python matplotlib tqdm scikit-image pyyaml scipy
# Verify ffmpeg (install via brew/apt if missing)
# !ffmpeg -version


## 1) Imports & helpers

In [4]:
import os, glob, math, json, sys, time, shutil
from typing import List, Tuple, Dict
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim_metric

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)


Device: cpu


## 2) Configure dataset paths

In [33]:
import os
import urllib.request

# Create folders
os.makedirs("data/train", exist_ok=True)
os.makedirs("data/val", exist_ok=True)

# Tiny MP4 clips for quick training/testing
videos = {
    "data/val": [
        "https://sample-videos.com/video321/mp4/240/big_buck_bunny_240p_1mb.mp4",
        "https://sample-videos.com/video321/mp4/240/big_buck_bunny_240p_2mb.mp4"
    ],
    "data/train": [
        "https://sample-videos.com/video321/mp4/240/big_buck_bunny_240p_5mb.mp4",
        "https://sample-videos.com/video321/mp4/240/big_buck_bunny_240p_10mb.mp4"
       
    ]
}

# Download function
def download_file(url, dest_folder):
    filename = os.path.join(dest_folder, os.path.basename(url.split("?")[0]))
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        urllib.request.urlretrieve(url, filename)
    else:
        print(f"Already exists: {filename}")

# Fetch all videos
for folder, urls in videos.items():
    for url in urls:
        download_file(url, folder)

# Summary
train_files = os.listdir("data/train")
val_files = os.listdir("data/val")
print(f"\nTrain videos: {len(train_files)} → {train_files}")
print(f"Val videos  : {len(val_files)} → {val_files}")


Downloading data/val/big_buck_bunny_240p_1mb.mp4...
Downloading data/val/big_buck_bunny_240p_2mb.mp4...
Downloading data/train/big_buck_bunny_240p_5mb.mp4...
Downloading data/train/big_buck_bunny_240p_10mb.mp4...

Train videos: 3 → ['.DS_Store', 'big_buck_bunny_240p_5mb.mp4', 'big_buck_bunny_240p_10mb.mp4']
Val videos  : 3 → ['.DS_Store', 'big_buck_bunny_240p_2mb.mp4', 'big_buck_bunny_240p_1mb.mp4']


In [41]:
# Fix A: helper to make time dims match (center-crop longer tensor)
import torch.nn.functional as F
import math, numpy as np

def align_time(a, b):
    """
    a,b: (B,C,T,H,W). Center-crop longer one along T so both match.
    Returns tensors with identical T.
    """
    Ta, Tb = a.shape[2], b.shape[2]
    if Ta == Tb:
        return a, b
    T = min(Ta, Tb)
    def crop_t(x, T):
        start = (x.shape[2] - T) // 2
        return x[:, :, start:start+T]
    if Ta != T: a = crop_t(a, T)
    if Tb != T: b = crop_t(b, T)
    return a, b


In [49]:
# --- STABILITY PATCH for macOS/Jupyter DataLoader worker crashes ---

import os, cv2, torch, torch.multiprocessing as mp

# OpenCV + multiprocessing can crash on macOS; keep things single-threaded.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
cv2.setNumThreads(0)

# Make sure PyTorch uses a safe start method in notebooks
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass  # already set

# Helper flags so we don't forget
DATALOADER_KW = dict(num_workers=0, persistent_workers=False, pin_memory=False)
print("DataLoader workers disabled:", DATALOADER_KW)


DataLoader workers disabled: {'num_workers': 0, 'persistent_workers': False, 'pin_memory': False}


In [50]:
from torch.utils.data import DataLoader

train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **DATALOADER_KW)
val_ld   = DataLoader(val_ds,   batch_size=1, shuffle=False, **DATALOADER_KW)


NameError: name 'train_ds' is not defined

## 3) Dataset loader

In [42]:
# Fix B: robust __getitem__ that guarantees exactly `self.frames` frames
import cv2, torch, numpy as np

class VideoFolderDataset(VideoFolderDataset):  # subclass/override
    def __getitem__(self, idx):
        vi, s = self.index[idx]
        path = self.items[vi]
        cap = cv2.VideoCapture(path)
        frames = []
        cap.set(cv2.CAP_PROP_POS_FRAMES, s)
        for _ in range(self.frames):
            ok, f = cap.read()
            if not ok:
                break
            f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
            h, w = f.shape[:2]
            Ht, Wt = self.size
            scale = max(Ht/h, Wt/w)
            nh, nw = int(round(h*scale)), int(round(w*scale))
            f = cv2.resize(f, (nw, nh), interpolation=cv2.INTER_AREA)
            y0 = (nh - Ht)//2; x0 = (nw - Wt)//2
            f = f[y0:y0+Ht, x0:x0+Wt]
            frames.append(f)
        cap.release()

        # Ensure at least 1 frame
        if len(frames) == 0:
            cap = cv2.VideoCapture(path)
            ok, f = cap.read()
            cap.release()
            if not ok:
                raise RuntimeError(f"Could not read any frame from {path}")
            f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
            h, w = f.shape[:2]
            Ht, Wt = self.size
            scale = max(Ht/h, Wt/w)
            nh, nw = int(round(h*scale)), int(round(w*scale))
            f = cv2.resize(f, (nw, nh), interpolation=cv2.INTER_AREA)
            y0 = (nh - Ht)//2; x0 = (nw - Wt)//2
            f = f[y0:y0+Ht, x0:x0+Wt]
            frames = [f]

        # Pad or crop to exactly self.frames
        while len(frames) < self.frames:
            frames.append(frames[-1])
        if len(frames) > self.frames:
            start = (len(frames) - self.frames) // 2
            frames = frames[start:start+self.frames]

        arr = np.stack(frames, axis=0)  # T,H,W,3
        arr = torch.from_numpy(arr).permute(3,0,1,2).float()/255.0  # C,T,H,W
        return arr, os.path.basename(path)


## 4) VAE model

In [43]:
# Fix C: patched train/eval that call align_time() before loss/metrics
import torch, os
from torch.utils.data import DataLoader

def _psnr(x, y, eps=1e-8):
    mse = F.mse_loss(x, y).item()
    return 10.0 * math.log10(1.0 / (mse + eps))

def train_vae(lambda_mse=0.01, beta_kl=1.0, frames=4, size=(128,128), window_step=4,
              epochs=3, batch_size=1, lr=1e-4, save_dir='runs/vae_quickfix',
              device=('cuda' if torch.cuda.is_available() else 'cpu')):
    # reuse your already-defined classes
    model = VAEVideo().to(device)
    os.makedirs(save_dir, exist_ok=True)

    train_ds = VideoFolderDataset(DATA_TRAIN, frames=frames, size=size, window_step=window_step)
    val_ds   = VideoFolderDataset(DATA_VAL,   frames=frames, size=size, window_step=window_step)
    train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ld   = DataLoader(val_ds, batch_size=1, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    best_psnr = -1e9

    for epoch in range(1, epochs+1):
        model.train()
        for x, _ in train_ld:
            x = x.to(device, non_blocking=True)
            xrec, kl = model(x)
            x_adj, xrec_adj = align_time(x, xrec)              # <<< fix
            loss_mse = F.mse_loss(x_adj, xrec_adj)
            loss = lambda_mse * loss_mse + beta_kl * kl
            opt.zero_grad(); loss.backward(); opt.step()

        # quick val
        model.eval(); vs=[]
        with torch.no_grad():
            for x, _ in val_ld:
                x = x.to(device)
                xrec, _ = model(x)
                x_adj, xrec_adj = align_time(x, xrec)          # <<< fix
                vs.append(_psnr(x_adj, xrec_adj))
        cur = float(np.mean(vs)) if vs else -1e9
        torch.save(model.state_dict(), os.path.join(save_dir, 'last.pt'))
        if cur > best_psnr:
            best_psnr = cur
            torch.save(model.state_dict(), os.path.join(save_dir, 'best.pt'))
        print(f"Epoch {epoch}/{epochs}  Val PSNR: {cur:.2f} (best {best_psnr:.2f})")

    return os.path.join(save_dir, 'best.pt')

def eval_vae(ckpt_path, frames=4, size=(128,128), window_step=4,
             device=('cuda' if torch.cuda.is_available() else 'cpu')):
    # reuse your utils from the notebook
    ds = VideoFolderDataset(DATA_VAL, frames=frames, size=size, window_step=window_step)
    ld = DataLoader(ds, batch_size=1, shuffle=False)

    model = VAEVideo().to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    points = []
    with torch.no_grad():
        for x, name in ld:
            x = x.to(device)
            xrec, kl = model(x)
            x_adj, xrec_adj = align_time(x, xrec)              # <<< fix
            m = {
                "name": name[0].rsplit('.',1)[0],
                "psnr": float(_psnr(x_adj, xrec_adj)),
                "ssim": float(ssim_torch(x_adj, xrec_adj)),
                "rate_bpp": float(kl_bits_per_pixel(kl, x_adj.shape)),  # use adjusted T
            }
            points.append(m)
    return points


## 5) Metrics & utilities

In [36]:
def mse_loss(x, y): return F.mse_loss(x, y)

def psnr(x, y, eps=1e-8):
    mse = mse_loss(x, y).item()
    return 10.0 * math.log10(1.0 / (mse + eps))

def ssim_torch(x, y):
    w = torch.tensor([0.299, 0.587, 0.114], device=x.device).view(1,3,1,1,1)
    xg = (x*w).sum(1, keepdim=True)
    yg = (y*w).sum(1, keepdim=True)
    C1 = 0.01**2; C2 = 0.03**2
    mu_x = F.avg_pool3d(xg, kernel_size=(x.shape[2],11,11), stride=1, padding=(0,5,5))
    mu_y = F.avg_pool3d(yg, kernel_size=(y.shape[2],11,11), stride=1, padding=(0,5,5))
    sigma_x  = F.avg_pool3d(xg*xg, kernel_size=(x.shape[2],11,11), stride=1, padding=(0,5,5)) - mu_x**2
    sigma_y  = F.avg_pool3d(yg*yg, kernel_size=(y.shape[2],11,11), stride=1, padding=(0,5,5)) - mu_y**2
    sigma_xy = F.avg_pool3d(xg*yg, kernel_size=(x.shape[2],11,11), stride=1, padding=(0,5,5)) - mu_x*mu_y
    ssim_map = ((2*mu_x*mu_y + C1)*(2*sigma_xy + C2))/((mu_x**2 + mu_y**2 + C1)*(sigma_x + sigma_y + C2))
    return ssim_map.mean()

def kl_bits_per_pixel(kl_scalar, x_shape):
    B,C,T,H,W = x_shape
    n_pix = B*T*H*W
    bits = float(kl_scalar) / math.log(2)
    return bits / (n_pix + 1e-12)


## 6) Train (single λ)

In [37]:
def train_vae(lambda_mse=0.01, beta_kl=1.0, frames=8, size=(256,256), window_step=8,
             epochs=5, batch_size=2, lr=1e-4, save_dir='runs/vae', device=device):
    os.makedirs(save_dir, exist_ok=True)
    train_ds = VideoFolderDataset(DATA_TRAIN, frames=frames, size=size, window_step=window_step)
    val_ds   = VideoFolderDataset(DATA_VAL,   frames=frames, size=size, window_step=window_step)
    train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ld   = DataLoader(val_ds, batch_size=1, shuffle=False)
    model = VAEVideo().to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    best_psnr = -1e9
    for epoch in range(1, epochs+1):
        model.train(); pbar = tqdm(train_ld, desc=f'Epoch {epoch}/{epochs}')
        for x,_ in pbar:
            x=x.to(device); xrec, kl = model(x)
            loss_mse = mse_loss(x,xrec); loss = lambda_mse*loss_mse + beta_kl*kl
            opt.zero_grad(); loss.backward(); opt.step()
            pbar.set_postfix(loss=float(loss.item()), mse=float(loss_mse.item()), kl=float(kl.item()))
        model.eval(); psnrs=[]
        with torch.no_grad():
            for x,_ in DataLoader(val_ds, batch_size=1, shuffle=False):
                x=x.to(device); xrec,_=model(x); psnrs.append(psnr(x,xrec))
        cur=float(np.mean(psnrs));
        torch.save(model.state_dict(), os.path.join(save_dir,'last.pt'))
        if cur>best_psnr:
            best_psnr=cur; torch.save(model.state_dict(), os.path.join(save_dir,'best.pt'))
        print(f'Val PSNR: {cur:.2f} (best {best_psnr:.2f})')
    return os.path.join(save_dir,'best.pt')


In [51]:
# ---- SPEED MODE PATCH ----
import os, math, numpy as np, torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Pick the fastest device available
device = (
    "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
print("Using device:", device)

# Reuse your helpers if already defined: align_time, ssim_torch, kl_bits_per_pixel, VideoFolderDataset, VAEVideo

def _psnr(x, y, eps=1e-8):
    mse = F.mse_loss(x, y).item()
    return 10.0 * math.log10(1.0 / (mse + eps))

def train_vae_fast(
    lambda_mse=0.01,
    beta_kl=1.0,
    frames=4,
    size=(96,96),
    window_step=12,
    epochs=3,
    batch_size=1,
    lr=1e-4,
    channels=32,        # smaller model width
    latent_dim=96,      # smaller latent
    val_max_batches=4,  # validate on only a few batches
    num_workers=2,
    save_dir="runs/fast"
):
    os.makedirs(save_dir, exist_ok=True)

    # datasets / loaders
    train_ds = VideoFolderDataset(DATA_TRAIN, frames=frames, size=size, window_step=window_step)
    val_ds   = VideoFolderDataset(DATA_VAL,   frames=frames, size=size, window_step=window_step)
    train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device!="cpu"))
    val_ld   = DataLoader(val_ds,   batch_size=1, shuffle=False, num_workers=max(1,num_workers-1), pin_memory=(device!="cpu"))

    # smaller model
    model = VAEVideo(in_ch=3, latent_dim=latent_dim, base=channels, groups=8).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)

    best_psnr = -1e9
    for epoch in range(1, epochs+1):
        model.train()
        for x, _ in train_ld:
            x = x.to(device, non_blocking=True)
            xrec, kl = model(x)
            # align_time() should already be defined in your notebook; if not, just remove this line
            x_adj, xrec_adj = align_time(x, xrec) if 'align_time' in globals() else (x, xrec)
            loss_mse = F.mse_loss(x_adj, xrec_adj)
            loss = lambda_mse * loss_mse + beta_kl * kl
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

        # quick/cheap validation
        model.eval()
        vs = []
        with torch.no_grad():
            for i, (x, _) in enumerate(val_ld):
                if i >= val_max_batches: break
                x = x.to(device)
                xrec, _ = model(x)
                x_adj, xrec_adj = align_time(x, xrec) if 'align_time' in globals() else (x, xrec)
                vs.append(_psnr(x_adj, xrec_adj))
        cur = float(np.mean(vs)) if vs else -1e9
        torch.save(model.state_dict(), os.path.join(save_dir, "last.pt"))
        if cur > best_psnr:
            best_psnr = cur
            torch.save(model.state_dict(), os.path.join(save_dir, "best.pt"))
        print(f"[FAST] Epoch {epoch}/{epochs}  Val PSNR (subset): {cur:.2f}  (best {best_psnr:.2f})")

    return os.path.join(save_dir, "best.pt")

def eval_vae_fast(
    ckpt_path,
    frames=4,
    size=(96,96),
    window_step=12,
    num_workers=1
):
    ds = VideoFolderDataset(DATA_VAL, frames=frames, size=size, window_step=window_step)
    ld = DataLoader(ds, batch_size=1, shuffle=False, num_workers=num_workers)

    model = VAEVideo(in_ch=3, latent_dim=96, base=32, groups=8).to(device)  # must match train_vae_fast
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    points = []
    with torch.no_grad():
        for x, name in ld:
            x = x.to(device)
            xrec, kl = model(x)
            x_adj, xrec_adj = align_time(x, xrec) if 'align_time' in globals() else (x, xrec)
            points.append({
                "name": name[0].rsplit('.',1)[0],
                "psnr": float(_psnr(x_adj, xrec_adj)),
                "ssim": float(ssim_torch(x_adj, xrec_adj)) if 'ssim_torch' in globals() else None,
                "rate_bpp": float(kl_bits_per_pixel(kl, x_adj.shape)) if 'kl_bits_per_pixel' in globals() else None,
            })
    return points


Using device: mps


## 7) Train an RD sweep (λ values)

In [52]:
# λ sweep in speed mode (much faster)
lambdas = [0.003, 0.01, 0.03]
checkpoints = {}
for lam in lambdas:
    print(f"\n=== FAST mode: λ={lam} ===")
    ckpt = train_vae_fast(
        lambda_mse=lam,
        epochs=3,
        frames=4,
        size=(96,96),
        window_step=12,
        channels=32,
        latent_dim=96,
        save_dir=f"runs/fast_lam{lam}"
    )
    checkpoints[lam] = ckpt

# quick eval
vae_points = []
for lam, ckpt in checkpoints.items():
    vae_points.extend(eval_vae_fast(ckpt, frames=4, size=(96,96), window_step=12))
print("VAE points:", len(vae_points))



=== FAST mode: λ=0.003 ===


Traceback (most recent call last):
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "<string>", line 1, in <module>
  File "/Users/mawahid/Documents/vae_video_compression/.conda/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
  File "/Users/mawahid/Documents/vae_video_compression/.conda/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
    exitcode = _main(fd, parent_sentinel)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/mawahid/Documents/vae_video_compression/.conda/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
  File "/Users/mawahid/Documents/vae_video_compression/.conda/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
        self = reduction.pickle.load(from_parent)self = reduction.pickle.load(from_parent)

                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

RuntimeError: DataLoader worker (pid(s) 62179, 62180) exited unexpectedly

## 8) Evaluate VAE models

In [None]:
def eval_vae(ckpt_path, frames=8, size=(256,256), window_step=8, device=device):
    ds = VideoFolderDataset(DATA_VAL, frames=frames, size=size, window_step=window_step)
    ld = DataLoader(ds, batch_size=1, shuffle=False)
    model = VAEVideo().to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval(); points=[]
    with torch.no_grad():
        for x,name in tqdm(ld, desc='Eval VAE'):
            x=x.to(device); xrec,kl=model(x)
            points.append({'name': name[0].rsplit('.',1)[0], 'psnr': float(psnr(x,xrec)), 'ssim': float(ssim_torch(x,xrec)), 'rate_bpp': float(kl_bits_per_pixel(kl, x.shape))})
    return points

vae_points=[]
for lam,ckpt in checkpoints.items():
    vae_points.extend(eval_vae(ckpt, frames=4, size=(128,128), window_step=4))
len(vae_points)


## 9) Baselines — ffmpeg encodes

In [None]:
import subprocess

def encode_dir(codec_label, crfs, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    videos = sorted(glob.glob(os.path.join(DATA_VAL, '*')))
    for v in videos:
        base = os.path.splitext(os.path.basename(v))[0]
        for crf in crfs:
            if codec_label=='x264':
                cmd=['ffmpeg','-y','-i',v,'-c:v','libx264','-preset','medium','-crf',str(crf),'-pix_fmt','yuv420p', os.path.join(out_dir,f'{base}_crf{crf}.mp4')]
            else:
                cmd=['ffmpeg','-y','-i',v,'-c:v','libx265','-preset','medium','-crf',str(crf),'-pix_fmt','yuv420p', os.path.join(out_dir,f'{base}_crf{crf}.mp4')]
            print('Running:', ' '.join(cmd)); subprocess.run(cmd, check=True)

encode_dir('x264', [18,22,26,30,34], 'outputs/x264')
encode_dir('x265', [20,24,28,32,36], 'outputs/x265')


## 10) Evaluate baselines

In [None]:
def decode_frames(path):
    cap=cv2.VideoCapture(path); frames=[]; ok=True
    while ok:
        ok,f=cap.read();
        if ok: frames.append(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
    cap.release(); return frames

def eval_ffmpeg(enc_root):
    points=[]
    ref_videos = sorted(glob.glob(os.path.join(DATA_VAL, '*')))
    for ref in ref_videos:
        base=os.path.splitext(os.path.basename(ref))[0]
        cands=sorted(glob.glob(os.path.join(enc_root, f'{base}*.*')))
        if not cands: continue
        ref_fs=decode_frames(ref); H,W=ref_fs[0].shape[:2]
        for enc in cands:
            enc_fs=decode_frames(enc); n=min(len(ref_fs), len(enc_fs))
            ps=[]
            for i in range(n):
                x=ref_fs[i].astype(np.float32)/255.0; y=enc_fs[i].astype(np.float32)/255.0
                mse=np.mean((x-y)**2)+1e-8; ps.append(10*np.log10(1.0/mse))
            bits=os.path.getsize(enc)*8; bpp=bits/(H*W*n)
            points.append({'name': base, 'file': os.path.basename(enc), 'psnr': float(np.mean(ps)), 'rate_bpp': float(bpp)})
    return points

x264_points=eval_ffmpeg('outputs/x264'); x265_points=eval_ffmpeg('outputs/x265')
len(x264_points), len(x265_points)


## 11) Plot RD curves

In [None]:
def plot_rd(vae_pts, x264_pts, x265_pts):
    def _xy(points):
        xs=[p['rate_bpp'] for p in points]; ys=[p['psnr'] for p in points]
        order=np.argsort(xs); return np.array(xs)[order], np.array(ys)[order]
    plt.figure()
    vx,vy=_xy(vae_pts); plt.scatter(vx,vy,label='VAE (KL→bpp)', marker='o')
    x4,y4=_xy(x264_pts); plt.plot(x4,y4,'-o',label='H.264 (x264)')
    x5,y5=_xy(x265_pts); plt.plot(x5,y5,'-o',label='H.265 (x265)')
    plt.xlabel('Rate (bits/pixel)'); plt.ylabel('PSNR (dB)'); plt.title('Rate–Distortion')
    plt.grid(True); plt.legend(); plt.show()

plot_rd(vae_points, x264_points, x265_points)


## 12) Category-wise BD-Rate (requires scipy)

In [None]:
# !pip install scipy
import numpy as np
from collections import defaultdict
from scipy.interpolate import interp1d

def bdrate(rate1, psnr1, rate2, psnr2):
    r1, p1 = np.array(rate1, float), np.array(psnr1, float)
    r2, p2 = np.array(rate2, float), np.array(psnr2, float)
    m1=(r1>0)&np.isfinite(r1)&np.isfinite(p1); m2=(r2>0)&np.isfinite(r2)&np.isfinite(p2)
    r1,p1=r1[m1],p1[m1]; r2,p2=r2[m2],p2[m2]
    if len(r1)<2 or len(r2)<2: return None
    pmin,pmax=max(p1.min(),p2.min()), min(p1.max(),p2.max())
    if not np.isfinite(pmin) or not np.isfinite(pmax) or pmax<=pmin: return None
    f1=interp1d(p1, np.log(r1), kind='linear', fill_value='extrapolate')
    f2=interp1d(p2, np.log(r2), kind='linear', fill_value='extrapolate')
    xs=np.linspace(pmin,pmax,200); r1i=np.exp(f1(xs)); r2i=np.exp(f2(xs))
    dr=(np.trapz(r1i,xs)/(pmax-pmin))/(np.trapz(r2i,xs)/(pmax-pmin)) - 1.0
    return float(dr*100.0)

# Edit your categories mapping below
categories={}

vae_by=defaultdict(list)
for p in vae_points: vae_by[p['name']].append(p)

x264_by=defaultdict(list)
for p in x264_points: x264_by[p['name']].append(p)

x265_by=defaultdict(list)
for p in x265_points: x265_by[p['name']].append(p)

report={}
for name, vpts in vae_by.items():
    cat=categories.get(name,'unknown')
    if x264_by.get(name):
        bd1=bdrate([p['rate_bpp'] for p in vpts],[p['psnr'] for p in vpts], [p['rate_bpp'] for p in x264_by[name]], [p['psnr'] for p in x264_by[name]])
        report.setdefault(cat,{}).setdefault('BD-Rate vs H.264 (%)',[]).append(bd1)
    if x265_by.get(name):
        bd2=bdrate([p['rate_bpp'] for p in vpts],[p['psnr'] for p in vpts], [p['rate_bpp'] for p in x265_by[name]], [p['psnr'] for p in x265_by[name]])
        report.setdefault(cat,{}).setdefault('BD-Rate vs H.265 (%)',[]).append(bd2)

# average per category
for cat,d in report.items():
    for k,arr in list(d.items()):
        arr=[a for a in arr if a is not None]
        d[k]=float(np.mean(arr)) if arr else None

print(json.dumps(report, indent=2))


## 13) Save metrics

In [None]:
os.makedirs('runs/notebook', exist_ok=True)
import json
with open('runs/notebook/vae_points.json','w') as f: json.dump(vae_points,f,indent=2)
with open('runs/notebook/x264_points.json','w') as f: json.dump(x264_points,f,indent=2)
with open('runs/notebook/x265_points.json','w') as f: json.dump(x265_points,f,indent=2)
print('Saved to runs/notebook/*.json')
