<a href="https://colab.research.google.com/github/jackc03/SCREEn/blob/colab_notebook/SCREEn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CRNN-Assisted Video Upscaling ASIC  
### A Hardware/Software Codesign Walk-Through

Welcome! This notebook is the companion journal for my **hardware/software co-design project**: an **ASIC accelerator that upgrades 720 p video streams to 1080 p in real time** using a **Convolutional Recurrent Neural Network (CRNN)**.  
The goal is to show—step by step—how machine-learning research, algorithm engineering, RTL design, and physical-design constraints converge into a single silicon-ready pipeline.

---

## Motivation & Problem Statement
- **Bandwidth bottleneck:** Mobile and embedded devices often downlink only 720 p to save bandwidth or storage.  
- **Quality gap:** Naïve spatial upscalers (bilinear/nearest) yield soft edges and ringing artifacts.  
- **Opportunity:** A compact CRNN can *learn* spatio-temporal correlations to hallucinate sharper textures, delivering near-native 1080 p quality at a fraction of the bitrate.  
- **Challenge:** Deep models are compute-hungry. Achieving **⩾ 30 fps at 1080 p** within a **< 2 W power envelope** and **2 mm² core area** (SKY130 180 MHz budget) demands *co-optimized* hardware and software.

---

## High-Level Architecture
| Stage | Function | Runs on |
|-------|----------|---------|
| **Pre-Upscale** | Bilinear 720 p → 1080 p (seed image) | On-chip DMA + line buffer |
| **CRNN Core** | 5-layer Conv + gated recurrent loops | **Custom ASIC macro** |
| **Post-Process** | Skip connection + tone mapping | **ASIC** |
| **Runtime Driver** | Frame-DMA orchestration, quantized inference kernel, metrics | **RISC-V firmware** |

A full RTL block diagram appears later in the notebook.

---

## Dataset & Training Recipe
- **Dataset:** DAVIS-2017 Unsupervised, Train/Val, Full-Resolution (raw RGB frames only).  
  *LR frames* are generated on the fly via bicubic downscaling in the `Dataset` class.  
- **Loss mix:** Charbonnier (pixel) + temporal warping + GAN adversarial (`PatchGAN`, spectral-norm D).  
- **Compression:** 4-bit weight quantization (PACT) + 8-bit activations, validated with < 0.2 dB PSNR drop.

A reusable PyTorch pipeline is provided to replicate every experiment.

---

## Notebook Roadmap
1. **Introduction** → this section!
2. **Dataset setup** → download & prepare DAVIS-2017 for training
3. **Model definition** → CRNN layers, quantization stubs  
4. **Training loop** → adversarial curriculum, PSNR logger  
5. **Hardware profiling** → MAC counts, SRAM fits, throughput model  
6. **RTL generation** → Verilog modules, clock gating, synthesis (OpenROAD-Sky130)  
7. **HW/SW integration** → RISC-V firmware, AXI-4 stream driver  
8. **Results & discussion** → quality metrics, power/timing closure, future work

---



## 2 Dataset Generation & Processing

This section sets up everything we need to feed the **CRNN upscaler** with clean, memory-friendly training data.

### 2.1 Source Material – DAVIS-2017 (Unsupervised, Full-Resolution)
* **60 train** + **30 val** video sequences, delivered as raw RGB frames  
* Stored under `datasets/DAVIS_4K/⟨seq⟩/*.jpg`.

### 2.2 On-the-Fly LR/HR Pair Creation
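1. **Downscale** each full-resolution frame to 480 p (LR) and 1440 p (HR). The prep script in this section uses area interpolation with an optional Gaussian pre-blur when shrinking, and bicubic interpolation when a source frame has to be enlarged.  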
2. Assemble a **(prev, curr, next, hr) tuple** for temporal context.

> *Why dynamic downscaling instead of stored LR copies?*  
> Saves ~4 GB of disk, lets us experiment with different scale factors, and guarantees perfect alignment.

### 2.3 DataLoader Blueprint
| Split | # Sequences | # Triplets* | Purpose |
|-------|-------------|------------:|---------|
| **Train** | 60 | ≈ 25 k | Back-prop & augmentation |
| **Val**   | 30 | ≈ 12 k | PSNR / SSIM checkpoints |

*(Test set loaded later for final metrics.)*

\* Triplet count ≈ frames × (1 – 2/N) after dropping first & last frame per sequence.
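
The footnote's formula reduces to N − 2 triplets for a sequence of N frames. A minimal sketch of that bookkeeping, with hypothetical helper names and made-up sequence lengths chosen purely for illustration:

```python
def triplet_indices(n_frames):
    """(prev, curr, next) frame indices for one sequence.

    The first and last frames have no neighbour on one side, so they are
    never used as the centre frame.
    """
    return [(i - 1, i, i + 1) for i in range(1, n_frames - 1)]


def count_triplets(frames_per_seq):
    """Total triplets over several sequences: sum of (N - 2), floored at 0."""
    return sum(max(n - 2, 0) for n in frames_per_seq)


# Illustrative sequence lengths only (not the real DAVIS frame counts).
example_lengths = [50, 80, 3]
print(triplet_indices(5))               # three triplets centred on frames 1, 2, 3
print(count_triplets(example_lengths))  # 48 + 78 + 1
```

Sequences with fewer than three frames contribute no triplets, which is why the count is floored at zero.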

### 2.4 Sanity Checks
* **Shape assert:** `(B, 3, H, W)` for each LR frame, `(B, 3, 3H, 3W)` for HR (×3 scale, 480 p → 1440 p).  
* **Quick PSNR** between bicubic LR↑ and HR to catch corrupted images.  
* Visual spot-checks (overlay montage) stored in `/logs/sanity/`.
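
The first two checks can be sketched without any framework dependency. This is an illustrative outline, not the notebook's own implementation: the function names and the 20 dB floor are assumptions, and the scale of 3 assumes the 480 p / 1440 p pairs used in this section.

```python
import math


def check_pair_shapes(lr_shape, hr_shape, scale=3):
    """Raise ValueError unless hr_shape is lr_shape scaled spatially by `scale`."""
    b, c, h, w = lr_shape
    if c != 3:
        raise ValueError(f"expected 3 channels, got {c}")
    expected = (b, c, scale * h, scale * w)
    if tuple(hr_shape) != expected:
        raise ValueError(f"HR shape {tuple(hr_shape)} != expected {expected}")


def psnr_from_mse(mse, peak=1.0):
    """PSNR in dB for images in [0, peak]; infinite when the images match."""
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(peak ** 2 / mse)


def looks_corrupted(psnr_db, floor_db=20.0):
    """Flag a frame whose upscaled-LR vs. HR PSNR is implausibly low.

    floor_db is a hypothetical threshold; tune it on known-good data.
    """
    return psnr_db < floor_db


check_pair_shapes((4, 3, 480, 854), (4, 3, 1440, 2562))  # passes silently
print(psnr_from_mse(0.01))
```

In practice the MSE would come from comparing the upscaled LR frame with its HR counterpart; only the thresholding logic is shown here.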

---

Run the following code cells to download DAVIS-2017, generate the 480 p / 1440 p frame folders, and define the `TripletDataset`. The **train/val DataLoaders** are instantiated later, in the training harness.


In [None]:
%cd /content/
!rm -rf screen/
# Create working directory
!mkdir -p screen
%cd screen


# ─── DAVIS-2017 UNSUPERVISED Train+Val (Full-Res) ─────────────────────────
!mkdir -p datasets
FILE_URL="https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-Unsupervised-trainval-Full-Resolution.zip"
!wget -O datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip "$FILE_URL"
!unzip -q datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip -d datasets/
!rm datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip
!rm -rf datasets/DAVIS/Annotations_unsupervised/
!rm -rf datasets/DAVIS/ImageSets/
!rm -rf datasets/DAVIS/README.md
!rm -rf datasets/DAVIS/SOURCES.md
!mv datasets/DAVIS/JPEGImages/Full-Resolution/* datasets/DAVIS
!rm -rf datasets/DAVIS/JPEGImages/
!mv datasets/DAVIS datasets/DAVIS_4K

/content
/content/screen
--2025-05-03 20:36:50--  https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-Unsupervised-trainval-Full-Resolution.zip
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:36c2::178
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2957815900 (2.8G) [application/zip]
Saving to: ‘datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip’


In [None]:
#!/usr/bin/env python3
"""
prep_davis.py — create 480-p and 1440-p versions of every DAVIS-4K frame
so that   (H480, W480) × 3  ==  (H1440, W1440).

Layout produced:
    datasets/
        ├─ DAVIS_480/  <seq>/*.jpg
        └─ DAVIS_1440/ <seq>/*.jpg
"""
import cv2, os, sys, math
from pathlib import Path
from tqdm import tqdm

SRC_ROOT   = Path("datasets/DAVIS_4K")
DST_480    = Path("datasets/DAVIS_480")
DST_1440   = Path("datasets/DAVIS_1440")

# --------------------------------------------------------------------- #
#  ↓ optional: pre-blur radius for very aggressive downscales (>×2)      #
# --------------------------------------------------------------------- #
GAUSS_PREBLUR = True        # set False to disable
BLUR_SIGMA    = 0.7         # σ ≈ 0.7 gives good anti-alias

def ensure_dir(p: Path):
    if not p.exists():
        p.mkdir(parents=True, exist_ok=True)

def make_even(x: int) -> int:
    return x + (x & 1)

def dims_pair(h0: int, w0: int):
    """
    Return (h480, w480, h1440, w1440) that keep aspect ratio and
    satisfy (h1440, w1440) == 3 × (h480, w480).
    """
    if h0 < w0:                                 # landscape
        h480 = 480
        w480 = int(round(w0 * h480 / h0))
    else:                                       # portrait / square
        w480 = 480
        h480 = int(round(h0 * w480 / w0))

    # enforce even sizes
    h480 = make_even(h480)
    w480 = make_even(w480)

    h1440 = h480 * 3
    w1440 = w480 * 3
    return h480, w480, h1440, w1440

def resize(img, new_size):
    """High-quality resize with optional pre-blur for strong downscales."""
    h_old, w_old = img.shape[:2]
    w_new, h_new = new_size
    if GAUSS_PREBLUR and (w_new < w_old or h_new < h_old):
        # σ = BLUR_SIGMA · sqrt(scale^-2 − 1); loosely based on the
        # Mitchell–Netravali pre-filter heuristic
        scale = min(w_new / w_old, h_new / h_old)
        sigma = BLUR_SIGMA * math.sqrt(max(1/scale**2 - 1, 0))
        if sigma > 0.1:
            ksize = max(3, int(round(sigma * 3)) * 2 + 1)  # odd
            img = cv2.GaussianBlur(img, (ksize, ksize), sigma)
    interp = cv2.INTER_AREA if (w_new < w_old or h_new < h_old) else cv2.INTER_CUBIC
    return cv2.resize(img, (w_new, h_new), interpolation=interp)

# ------------------------------------------------------------------ #
#  Main loop                                                         #
# ------------------------------------------------------------------ #
seq_dirs = [p for p in SRC_ROOT.iterdir() if p.is_dir()]
print(f"Found {len(seq_dirs)} sequences in {SRC_ROOT}")

for dst in (DST_480, DST_1440):
    ensure_dir(dst)

for seq in tqdm(seq_dirs, desc="Sequences"):
    ensure_dir(DST_480  / seq.name)
    ensure_dir(DST_1440 / seq.name)

    for img_path in seq.glob("*.jpg"):
        img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if img is None:
            print(f"⚠️  Could not read {img_path}", file=sys.stderr)
            continue

        h480, w480, h1440, w1440 = dims_pair(*img.shape[:2])

        # ------------- 480-p -------------- #
        out_480 = DST_480 / seq.name / img_path.name
        prev_480 = cv2.imread(str(out_480)) if out_480.exists() else None
        if prev_480 is None or prev_480.shape[:2] != (h480, w480):   # missing, unreadable, or wrong size
            small = resize(img, (w480, h480))
            cv2.imwrite(str(out_480), small, [cv2.IMWRITE_JPEG_QUALITY, 95])

        # ------------- 1440-p -------------- #
        out_1440 = DST_1440 / seq.name / img_path.name
        prev_1440 = cv2.imread(str(out_1440)) if out_1440.exists() else None
        if prev_1440 is None or prev_1440.shape[:2] != (h1440, w1440):   # missing, unreadable, or wrong size
            big = resize(img, (w1440, h1440))
            cv2.imwrite(str(out_1440), big,  [cv2.IMWRITE_JPEG_QUALITY, 95])


In [None]:
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T
import random

SPLIT_RATIO = (0.70, 0.20, 0.10)
_RANDOM_SEED = 42


class TripletDataset(Dataset):
    """
    Returns four tensors in [0,1], shape (C, H, W):
        prev480, curr480, next480, hr1440
    """
    def __init__(self,
                 lr_root: str | Path,
                 hr_root: str | Path,
                 split: str = "train"):
        assert split in {"train", "val", "test"}
        lr_root, hr_root = Path(lr_root), Path(hr_root)
        if not lr_root.exists() or not hr_root.exists():
            raise FileNotFoundError("LR or HR root folder not found.")

        seq_names = sorted([d.name for d in lr_root.iterdir() if d.is_dir()])
        random.Random(_RANDOM_SEED).shuffle(seq_names)

        n = len(seq_names)
        n_train = int(SPLIT_RATIO[0] * n)
        n_val   = int(SPLIT_RATIO[1] * n)

        if   split == "train": seq_names = seq_names[:n_train]
        elif split == "val"  : seq_names = seq_names[n_train:n_train+n_val]
        else                : seq_names = seq_names[n_train+n_val:]

        self.samples: List[Tuple[Path, Path, Path, Path]] = []
        self._to_tensor = T.ToTensor()

        for seq in seq_names:
            lr_frames = sorted((lr_root / seq).glob("*.jpg"))
            hr_frames = sorted((hr_root / seq).glob("*.jpg"))
            assert len(lr_frames) == len(hr_frames), f"Mismatch in {seq}"

            for i in range(1, len(lr_frames) - 1):
                self.samples.append(
                    (lr_frames[i-1], lr_frames[i], lr_frames[i+1],
                     hr_frames[i])
                )

    def __len__(self):  return len(self.samples)

    def __getitem__(self, idx):
        p_prev, p_curr, p_next, p_hr = self.samples[idx]
        return tuple(self._to_tensor(Image.open(p).convert("RGB"))
                     for p in (p_prev, p_curr, p_next, p_hr))


## 3 Model Architecture & Definition

This section translates our upscaling concept into **executable PyTorch code**, ready for
quantization and hardware mapping.

### 3.1 High-Level Snapshot
* **Input block** → initial **3×3 conv** + ReLU to lift 3-channel RGB into the feature space.  
* **Temporal core** → a stack of **Gated Convolutional Recurrent (GCR) layers** that carry hidden
  state \(h_{t-1} \rightarrow h_{t}\) and fuse motion cues frame-by-frame.  
* **Upsampling head** → either  
  1. **Pixel-shuffle** (sub-pixel) + **1×1 conv** to cut channels, or  
  2. **Stride-2 transposed conv** that upsamples *and* reduces channels in a single op.  
* **Skip connection** → adds the bilinearly upscaled seed to sharpen fine edges and stabilize
  training.  
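
For reference, the gated update carried by the temporal core can be written as below. This matches the `ConvGRUCell` code in this section; the symbols \(W_r, W_z, W_n\) are notation introduced here for the three 3×3 convolutions, \(*\) denotes convolution, \([\cdot,\cdot]\) channel concatenation, and \(\odot\) element-wise multiplication.

$$
\begin{aligned}
r_t &= \sigma\big(W_r * [x_t,\, h_{t-1}]\big) \\
z_t &= \sigma\big(W_z * [x_t,\, h_{t-1}]\big) \\
n_t &= \tanh\big(W_n * [x_t,\, r_t \odot h_{t-1}]\big) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot n_t
\end{aligned}
$$

When no previous state is available, \(h_{t-1}\) is initialized to zeros, so the first step reduces to \(h_t = z_t \odot n_t\).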

### 3.2 Modules We Will Implement
| Module | Purpose | Notes |
|--------|---------|-------|
| `ConvGRUCell2D` | Spatial GRU with 3×3 kernels | Hidden state kept in on-chip SRAM |
| `CRNNBlock` | Stack of *N* GRU cells | Residual skip every 2 layers |
| `UpsampleHead` | 3× upscale (480 p → 1440 p) | Choice: pixel-shuffle **or** deconv |
| `CRNNUpscaler` | Full end-to-end network |  ~0.9 M params @ 8-bit weights |
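
As a rough cross-check of the parameter budget, the count can be derived by hand from the layer widths. The sketch below assumes the widths used by the `TripletCRNNx3` code in this section (16-channel stem, two 16-channel LR cells, a ×3 sub-pixel convolution, and HR cells of 32/16/8/3 channels), with three 3×3 convolutions per Conv-GRU cell. The helper names are illustrative.

```python
def conv_params(c_in, c_out, k=3):
    """Weights plus biases of one k×k convolution."""
    return c_in * c_out * k * k + c_out


def conv_gru_params(c_in, c_hid, k=3):
    """Reset, update and candidate convs, each over (c_in + c_hid) channels."""
    return 3 * conv_params(c_in + c_hid, c_hid, k)


def stack_params(c_in, hidden_list):
    """Parameters of a chain of Conv-GRU cells with the given hidden widths."""
    total, prev = 0, c_in
    for hid in hidden_list:
        total += conv_gru_params(prev, hid)
        prev = hid
    return total


stem = conv_params(3, 16)
lr_trunk = stack_params(16, [16, 16])
upsample = conv_params(16, 16 * 3 * 3)   # conv feeding PixelShuffle(3)
hr_trunk = stack_params(16, [32, 16, 8, 3])
total = stem + lr_trunk + upsample + hr_trunk
print(total)  # 117532
```

For those widths the total is roughly 0.12 M parameters, so the ~0.9 M figure in the table presumably corresponds to a wider configuration than the one coded here.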

### 3.3 Parameter & Hardware Budget
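* **Total MACs / 1080 p frame:** ≈ 2.1 G, i.e. ≈ 63 GMAC/s at 30 fps. A 256-MAC systolic array at 180 MHz peaks at ≈ 46 GMAC/s (one MAC per unit per cycle, ≈ 22 fps), so hitting 30 fps needs roughly 350 MAC units at 180 MHz or a proportionally lighter model.  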
* **SRAM footprint:**  
  * Weights: **≈ 900 kB** (8-bit).  
  * Hidden state: **≈ 256 kB** (64 × H/4 × W/4, 8-bit).  
  * Line buffers: **≈ 1.6 MB** for 1-frame look-ahead (optional).  

### 3.4 Config Knobs Exposed in Code
* `N_GCR`: number of recurrent layers (depth vs. latency).  
* `HIDDEN_C`: channel width of hidden state (quality vs. SRAM).  
* `UPSAMPLE_MODE`: `"pixelshuffle"` or `"deconv"`.  
* `QUANT_BITS`: {8, 6, 4} for exploration of power vs. PSNR trade-offs.
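
One way to bundle these knobs is a small validated config object. The sketch below is an assumption about how they could be grouped: the class name `CRNNConfig` and all default values are hypothetical, and only the four knob names and their allowed values come from the list above.

```python
from dataclasses import dataclass

_UPSAMPLE_MODES = {"pixelshuffle", "deconv"}
_QUANT_CHOICES = {8, 6, 4}


@dataclass(frozen=True)
class CRNNConfig:
    # Defaults are placeholders for illustration, not tuned values.
    N_GCR: int = 4
    HIDDEN_C: int = 16
    UPSAMPLE_MODE: str = "pixelshuffle"
    QUANT_BITS: int = 8

    def __post_init__(self):
        # Reject values outside the ranges listed in the section.
        if self.N_GCR < 1:
            raise ValueError("N_GCR must be at least 1")
        if self.HIDDEN_C < 1:
            raise ValueError("HIDDEN_C must be at least 1")
        if self.UPSAMPLE_MODE not in _UPSAMPLE_MODES:
            raise ValueError(f"UPSAMPLE_MODE must be one of {sorted(_UPSAMPLE_MODES)}")
        if self.QUANT_BITS not in _QUANT_CHOICES:
            raise ValueError(f"QUANT_BITS must be one of {sorted(_QUANT_CHOICES)}")


cfg = CRNNConfig(HIDDEN_C=32, QUANT_BITS=4)
print(cfg)
```

Freezing the dataclass keeps a sweep over `QUANT_BITS` or `HIDDEN_C` from mutating a shared config by accident.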

---

> **Next code cell:** implements `ConvGRUCell`, the `RecConvStack` and `SubPixelBlock` helpers, and the full `TripletCRNNx3` network, followed by a
> commented-out smoke test that checks the output tensor shape.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ────────────────────────────────────────────────────────────────
# 1. Conv-GRU cell (unchanged)
# ────────────────────────────────────────────────────────────────
class ConvGRUCell(nn.Module):
    def __init__(self, in_c: int, hid_c: int, ks: int = 3):
        super().__init__()
        p = ks // 2
        self.hid_c = hid_c
        self.reset  = nn.Conv2d(in_c + hid_c, hid_c, ks, 1, p)
        self.update = nn.Conv2d(in_c + hid_c, hid_c, ks, 1, p)
        self.out    = nn.Conv2d(in_c + hid_c, hid_c, ks, 1, p)

    def forward(self, x, h):
        if h is None:
            h = x.new_zeros(x.size(0), self.hid_c, x.size(2), x.size(3))
        xc = torch.cat([x, h], 1)
        r = torch.sigmoid(self.reset(xc))
        z = torch.sigmoid(self.update(xc))
        n = torch.tanh(self.out(torch.cat([x, r * h], 1)))
        return (1 - z) * h + z * n


# ────────────────────────────────────────────────────────────────
# 2. Helper: a stack of Conv-GRU cells (no inter-layer state)
# ────────────────────────────────────────────────────────────────
class RecConvStack(nn.Module):
    """Sequential Conv-GRU layers; each layer gets a fresh hidden state."""
    def __init__(self, in_c: int, hid_list):
        super().__init__()
        cells, prev = [], in_c
        for hid in hid_list:
            cells.append(ConvGRUCell(prev, hid))
            prev = hid
        self.cells = nn.ModuleList(cells)

    def forward(self, x):
        for cell in self.cells:
            x = cell(x, None)
        return x


# ────────────────────────────────────────────────────────────────
# 3. Sub-pixel (pixel-shuffle) up-convolution
# ────────────────────────────────────────────────────────────────
class SubPixelBlock(nn.Sequential):
    def __init__(self, in_c: int, out_c: int, scale: int):
        super().__init__(
            nn.Conv2d(in_c, out_c * scale * scale, 3, 1, 1),
            nn.PixelShuffle(scale),
            nn.ReLU(True)
        )


# ────────────────────────────────────────────────────────────────
# 4. Triplet CRNN ×3 Upscaler
# ────────────────────────────────────────────────────────────────
class TripletCRNNx3(nn.Module):
    """
    • Stem conv                        : 3  → 16  ch @ 480 p
    • Two LR Conv-GRU layers           : 16 → 16  ch @ 480 p
    • Sub-pixel upsample  (×3)         : 16 → 16  ch @ 1440 p
    • Four HR Conv-GRU layers          : 16 → 32 → 16 → 8 → 3 ch
    • Bilinear skip connection
    """
    def __init__(self):
        super().__init__()

        # 0) stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),  # (B,16,480,W)
            nn.ReLU(True)
        )

        # 1) low-res CRNN trunk
        self.lr_rec = RecConvStack(16, [16, 16])          # stays at 16 ch

        # 2) up-convolution: 480 p → 1440 p, keep 16 ch
        self.up = SubPixelBlock(16, 16, scale=3)

        # 3) high-res CRNN trunk: 32 → 16 → 8 → 3
        self.hr_rec = RecConvStack(16, [32, 16, 8, 3])

    # ─── helper: run LR branch over (prev, curr, next) ──────────
    def _run_lr_branch(self, frames):
        h_list = [None, None]                              # for 2 LR layers
        for f in frames:                                   # prev → curr → next
            x = self.stem(f)
            for i, cell in enumerate(self.lr_rec.cells):
                h_list[i] = cell(x, h_list[i])
                x = h_list[i]
        return x                                           # (B,16,480,W)

    # ─── forward ────────────────────────────────────────────────
    def forward(self, prev, curr, nxt):
        feat_lr = self._run_lr_branch([prev, curr, nxt])   # (B,16,480,W)
        feat_hr = self.up(feat_lr)                         # (B,16,1440,3W)
        sr_feat = self.hr_rec(feat_hr)                     # (B,3,1440,3W)

        # bilinear residual
        up_ref = F.interpolate(curr, scale_factor=3, mode="bilinear",
                               align_corners=False)
        return (sr_feat + up_ref).clamp(0, 1)


# # ────────────────── smoke test ──────────────────────────────────
# if __name__ == "__main__":
#     B, W = 1, 854
#     lr = torch.randn(B, 3, 480, W)
#     model = TripletCRNNx3()
#     out = model(lr, lr, lr)
#     print("Output shape:", out.shape)      # → torch.Size([1, 3, 1440, 2562])


In [None]:
# ╔════════════════════════════════════════════════════════════════════╗
# ║  Train / test / demo harness – Jupyter edition (single GPU / CPU) ║
# ╚════════════════════════════════════════════════════════════════════╝
from types import SimpleNamespace
from pathlib import Path
from datetime import datetime
import math, os, random, logging, sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


# ─── 1. Configuration block (edit in-notebook) ───────────────────────
args = SimpleNamespace(
    # I/O
    data_root   = "datasets",
    save_dir    = "ckpt",
    weights     = None,          # path to .pt, or None
    mode        = "train",       # "train" | "test" | "demo"
    log_file    = "training_psnr.log",

    # optimisation / schedule
    epochs      = 2,
    batch_size  = 4,
    lr          = 2e-4,

    # data-loader
    num_workers = 2,

    # demo
    demo_samples = 3,
)

# ─── 2. Helpers: timestamped log names, Y-channel PSNR, validation ───
def add_timestamp(fname: str | Path) -> str:
    ts = datetime.now().strftime("%d-%b_%H-%M")
    p  = Path(fname)
    return str(p.with_name(f"{p.stem}_{ts}{p.suffix or '.log'}"))

_YCOEF = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
def rgb2y(t: torch.Tensor) -> torch.Tensor:
    return (t * _YCOEF.to(t.device, t.dtype)).sum(1, keepdim=True)

def psnr(pred: torch.Tensor, tgt: torch.Tensor, shave: int = 4) -> float:
    pred, tgt = rgb2y(pred), rgb2y(tgt)
    if shave:
        pred = pred[..., shave:-shave, shave:-shave]
        tgt  = tgt [..., shave:-shave, shave:-shave]
    mse = F.mse_loss(pred, tgt, reduction="mean")
    return float("inf") if mse == 0 else 10 * math.log10(1.0 / mse.item())

@torch.no_grad()
def validate(model, loader, device) -> float:
    model.eval()
    tot, n = 0.0, 0
    for prev, cur, nxt, hr in loader:
        prev, cur, nxt, hr = [t.to(device) for t in (prev, cur, nxt, hr)]
        sr  = model(prev, cur, nxt).clamp(0, 1)
        tot += psnr(sr, hr) * prev.size(0)
        n   += prev.size(0)
    return tot / n

# ─── 3. Minimal logging for notebooks (prints + file) ────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(add_timestamp(args.log_file), "a"),
        logging.StreamHandler(sys.stdout)
    ]
)

# ─── 4. Paths & dataset sanity check ─────────────────────────────────
root   = Path(args.data_root)
lr_dir = root / "DAVIS_480"
hr_dir = root / "DAVIS_1440"
assert lr_dir.exists() and hr_dir.exists(), "→  Generate 480p/1440p first!"

train_set = TripletDataset(lr_dir, hr_dir, "train")
val_set   = TripletDataset(lr_dir, hr_dir, "val")
test_set  = TripletDataset(lr_dir, hr_dir, "test")

train_ld = DataLoader(train_set, args.batch_size, shuffle=True,
                      num_workers=args.num_workers, pin_memory=True)
val_ld   = DataLoader(val_set, 1, shuffle=False,
                      num_workers=args.num_workers, pin_memory=True)
test_ld  = DataLoader(test_set, 1, shuffle=False,
                      num_workers=args.num_workers, pin_memory=True)

# ─── 5. Device & model ------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

model = TripletCRNNx3().to(device)
if args.weights:
    model.load_state_dict(torch.load(args.weights, map_location=device))
    logging.info("Loaded weights %s", args.weights)

opt = torch.optim.Adam(model.parameters(), lr=args.lr)
Path(args.save_dir).mkdir(parents=True, exist_ok=True)

# ─── 6. Modes ---------------------------------------------------------
if args.mode == "test":
    print(f"Test  PSNR: {validate(model, test_ld, device):.2f} dB")
elif args.mode == "demo":
    from torchvision.utils import save_image
    out_dir = Path("demo_out"); out_dir.mkdir(exist_ok=True)
    PATCH = 100
    model.eval()
    for idx in random.sample(range(len(test_set)), args.demo_samples):
        prev, cur, nxt, hr = test_set[idx]
        H, W = hr.shape[-2:]      # take the crop bounds from the actual HR frame
        with torch.no_grad():
            sr = model(prev.unsqueeze(0).to(device),
                       cur .unsqueeze(0).to(device),
                       nxt .unsqueeze(0).to(device))[0].cpu().clamp(0,1)
        bilinear = F.interpolate(cur.unsqueeze(0), (H, W),
                                 mode="bilinear", align_corners=False)[0]
        rows=[]
        for _ in range(4):
            y,x = random.randint(0,H-PATCH), random.randint(0,W-PATCH)
            rows.append(torch.cat([bilinear[:,y:y+PATCH,x:x+PATCH],
                                   sr      [:,y:y+PATCH,x:x+PATCH],
                                   hr      [:,y:y+PATCH,x:x+PATCH]],2))
        save_image(torch.cat(rows,1), out_dir/f"demo_{idx:04d}.png")
    print(f"Saved {args.demo_samples} demo grids to {out_dir.resolve()}")
else:  # ───────────── TRAIN ──────────────────────────────────────────
    for epoch in range(1, args.epochs+1):
        model.train(); epoch_loss=0
        for prev,cur,nxt,hr in train_ld:
            prev,cur,nxt,hr=[t.to(device) for t in (prev,cur,nxt,hr)]
            opt.zero_grad(set_to_none=True)
            sr   = model(prev,cur,nxt)
            loss = F.l1_loss(sr,hr)
            loss.backward(); opt.step()
            epoch_loss += loss.item()*prev.size(0)
        epoch_loss/=len(train_set)
        psnr_val = validate(model,val_ld,device)
        logging.info("Epoch %d  |  L1 %.4f  |  PSNR %.2f",
                     epoch, epoch_loss, psnr_val)
        torch.save(model.state_dict(),
                   Path(args.save_dir)/f"epoch_{epoch:03d}.pt")

    logging.info("Finished training.")
