<a href="https://colab.research.google.com/github/jackc03/SCREEn/blob/colab_notebook/SCREEn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CRNN-Assisted Video Upscaling ASIC  
### A Hardware/Software Codesign Walk-Through

Welcome! This notebook is the companion journal for my **hardware/software co-design project**: an **ASIC accelerator that upgrades 720 p video streams to 1080 p in real time** using a **Convolutional Recurrent Neural Network (CRNN)**.  
The goal is to show—step by step—how machine-learning research, algorithm engineering, RTL design, and physical-design constraints converge into a single silicon-ready pipeline.

---

## Motivation & Problem Statement
- **Bandwidth bottleneck:** Mobile and embedded devices often downlink only 720 p to save bandwidth or storage.  
- **Quality gap:** Naïve spatial upscalers yield blurry edges (bilinear) or blocky, jagged edges (nearest-neighbor).  
- **Opportunity:** A compact CRNN can *learn* spatio-temporal correlations to hallucinate sharper textures, delivering near-native 1080 p quality at a fraction of the bitrate.  
- **Challenge:** Deep models are compute-hungry. Achieving **⩾ 30 fps at 1080 p** within a **< 2 W power envelope** and **2 mm² core area** (SKY130 process, 180 MHz clock) demands *co-optimized* hardware and software.
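To make the real-time target concrete, here is a back-of-the-envelope cycle budget. It is a sketch that only uses the figures quoted above (1920×1080 output, 30 fps, 180 MHz) and assumes every MAC unit does useful work on every cycle; the helper names are illustrative, not part of the project code.

```python
def cycle_budget(width: int, height: int, fps: float, clock_hz: float) -> float:
    """Clock cycles available per output pixel at a given resolution and frame rate."""
    pixels_per_second = width * height * fps
    return clock_hz / pixels_per_second


def mac_units_needed(macs_per_pixel: float, cycles_per_pixel: float) -> float:
    """Parallel MAC units that must fire every cycle to keep up (100 % utilization)."""
    return macs_per_pixel / cycles_per_pixel


budget = cycle_budget(1920, 1080, 30, 180e6)
print(f"{1920 * 1080 * 30 / 1e6:.1f} Mpixel/s -> {budget:.2f} cycles per output pixel")
# prints: 62.2 Mpixel/s -> 2.89 cycles per output pixel
```

With fewer than 3 cycles per output pixel, a network costing *k* MACs per pixel needs about *k* / 2.89 MAC units in parallel. For example, the ≈ 2.1 GMAC/frame estimate in Section 3.3 is ≈ 1013 MACs per 1080 p pixel, i.e. ≈ 350 MAC units running every cycle.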

---

## High-Level Architecture
| Stage | Function | Runs on |
|-------|----------|---------|
| **Pre-Upscale** | Bilinear 720 p → 1080 p (seed image) | On-chip DMA + line buffer |
| **CRNN Core** | 5-layer Conv + gated recurrent loops | **Custom ASIC macro** |
| **Post-Process** | Skip connection + tone mapping | **ASIC** |
| **Runtime Driver** | Frame-DMA orchestration, quantized inference kernel, metrics | **RISC-V firmware** |

A full RTL block diagram appears later in the notebook.

---

## Dataset & Training Recipe
- **Dataset:** DAVIS-2017 Unsupervised (Train/Val, Full-Resolution); only the raw RGB frames are used.  
  *LR frames* are generated on the fly by bicubic downscaling in the `Dataset` class.  
- **Loss mix:** Charbonnier (pixel) + temporal warping + GAN adversarial (`PatchGAN` with a spectral-normalized discriminator).  
- **Compression:** 4-bit weight quantization (PACT) + 8-bit activations, validated with < 0.2 dB PSNR drop.
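Of the three loss terms, the Charbonnier pixel loss is the simplest to pin down. Below is a minimal PyTorch sketch; the `eps = 1e-3` smoothing constant is an assumed (common) choice rather than a value taken from this project, and the temporal-warping and adversarial terms are not shown.

```python
import torch


def charbonnier_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Charbonnier pixel loss: mean(sqrt((pred - target)^2 + eps^2)).

    Behaves like L1 for large errors but stays smooth (differentiable) at zero error.
    """
    diff = pred - target
    return torch.sqrt(diff * diff + eps * eps).mean()
```

For identical tensors the loss bottoms out at `eps` instead of zero, and for errors much larger than `eps` it approaches the mean absolute error.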

A reusable PyTorch pipeline is provided to replicate every experiment.

---

## Notebook Roadmap
1. **Introduction** → this section!
2. **Dataset setup** → download & prepare DAVIS-2017 for training
3. **Model definition** → CRNN layers, quantization stubs  
4. **Training loop** → adversarial curriculum, PSNR logger  
5. **Hardware profiling** → MAC counts, SRAM fits, throughput model  
6. **RTL generation** → Verilog modules, clock gating, synthesis (OpenROAD-Sky130)  
7. **HW/SW integration** → RISC-V firmware, AXI-4 stream driver  
8. **Results & discussion** → quality metrics, power/timing closure, future work

---

<!-- ## 5. How to Run
```bash
git clone <this-repo>
cd notebook/
pip install -r requirements.txt -->


## 2 Dataset Generation & Processing

This section sets up everything we need to feed the **CRNN upscaler** with clean, memory-friendly training data.

### 2.1 Source Material – DAVIS-2017 (Unsupervised, Full-Resolution)
* **60 train** + **30 val** video sequences, delivered as raw RGB frames  
* Stored under `datasets/DAVIS_4K/⟨seq⟩/*.jpg`.

### 2.2 On-the-Fly LR/HR Pair Creation
1. **Bicubic resize** of each full-resolution frame to 480 p (LR input) and 1440 p (HR target).  
2. Assemble a **(prev, curr, next, hr) tuple** for temporal context.

> *Why dynamic downscaling instead of stored LR copies?*  
> Saves ~4 GB of disk, lets us experiment with different scale factors, and guarantees perfect alignment.
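A minimal sketch of on-the-fly pair creation with Pillow, assuming a 480 p short edge and a ×3 scale (480 p → 1440 p). The function name is hypothetical. One design choice worth noting: the HR size is derived from the rounded LR size rather than resized independently, so the HR target is always exactly 3× the LR frame.

```python
from PIL import Image


def make_lr_hr_pair(frame: Image.Image, lr_short: int = 480, scale: int = 3):
    """Bicubic-resize one full-resolution frame into an aligned (LR, HR) pair."""
    w, h = frame.size
    if h <= w:   # landscape / square: the short edge is the height
        lr_h, lr_w = lr_short, round(w * lr_short / h)
    else:        # portrait: the short edge is the width
        lr_w, lr_h = lr_short, round(h * lr_short / w)
    lr_w += lr_w % 2            # round odd sizes up to even
    lr_h += lr_h % 2
    lr = frame.resize((lr_w, lr_h), Image.BICUBIC)
    # HR = scale x LR exactly, so network output and target shapes always match
    hr = frame.resize((scale * lr_w, scale * lr_h), Image.BICUBIC)
    return lr, hr
```

For a 3840×2160 frame this yields an 854×480 LR frame and a 2562×1440 HR frame.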

### 2.3 DataLoader Blueprint
| Split | # Sequences | # Triplets* | Purpose |
|-------|-------------|------------:|---------|
| **Train** | 60 | ≈ 25 k | Back-prop & augmentation |
| **Val**   | 30 | ≈ 12 k | PSNR / SSIM checkpoints |

\* Triplet count = Σ (N − 2) over all sequences in the split, where N is a sequence's frame count; the first and last frames are dropped because each lacks a neighbor on one side.  
*(The test split is loaded later for final metrics.)*
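The triplet bookkeeping (one sample per frame that has both a previous and a next neighbor) can be sketched as below. The sequence names and lengths in the usage example are hypothetical placeholders, not DAVIS statistics.

```python
def build_triplet_index(seq_lengths: dict[str, int]) -> list[tuple[str, int]]:
    """List (sequence, centre_frame) pairs; centre frame i uses frames i-1, i, i+1."""
    index = []
    for seq, n in seq_lengths.items():
        # drop the first and last frame: each lacks one temporal neighbor
        index.extend((seq, i) for i in range(1, n - 1))
    return index


idx = build_triplet_index({"seq_a": 82, "seq_b": 50})
print(len(idx))   # (82 - 2) + (50 - 2) = 128
```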

### 2.4 Sanity Checks
* **Shape assert:** `(B, 3, H, W)` for each LR frame, `(B, 3, 3H, 3W)` for HR (480 p → 1440 p is a 3× scale).  
* **Quick PSNR** between bicubic LR↑ and HR to catch corrupted images.  
* Visual spot-checks (overlay montage) stored in `/logs/sanity/`.
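The first two checks can be sketched in a few lines of PyTorch. This assumes batched tensors in [0, 1] and a ×3 scale; the function names and the 20 dB "suspicious frame" threshold are illustrative assumptions, not values from the project.

```python
import torch
import torch.nn.functional as F


def psnr(x: torch.Tensor, y: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """Per-image PSNR in dB for (B, C, H, W) tensors in [0, max_val]."""
    mse = ((x - y) ** 2).flatten(1).mean(dim=1)
    return 10 * torch.log10(max_val ** 2 / mse.clamp_min(1e-12))


def sanity_check(lr: torch.Tensor, hr: torch.Tensor, scale: int = 3, min_db: float = 20.0):
    """Shape assert + bicubic-baseline PSNR; returns the per-image PSNR values."""
    b, c, h, w = lr.shape
    assert c == 3 and hr.shape == (b, 3, scale * h, scale * w), f"bad shapes {lr.shape} / {hr.shape}"
    up = F.interpolate(lr, scale_factor=scale, mode="bicubic", align_corners=False).clamp(0, 1)
    scores = psnr(up, hr)
    bad = (scores < min_db).nonzero().flatten().tolist()
    if bad:   # an unusually poor bicubic baseline often means a corrupt or misaligned frame
        print(f"Low bicubic PSNR at batch indices {bad}")
    return scores
```

Typical usage is `sanity_check(curr_lr, hr)` on the first mini-batch of each DataLoader.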

---

Run the next code cells to download DAVIS-2017, resize every frame to 480 p and 1440 p, and define the `TripletDataset` class from which the **train/val DataLoaders** are built.


In [1]:
%cd /content/
!rm -rf screen/
# Create working directory
!mkdir -p screen
%cd screen


# ─── DAVIS-2017 UNSUPERVISED Train+Val (Full-Res) ─────────────────────────
!mkdir -p datasets
FILE_URL="https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-Unsupervised-trainval-Full-Resolution.zip"
!wget -O datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip "$FILE_URL"
!unzip -q datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip -d datasets/
!rm datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip
!rm -rf datasets/DAVIS/Annotations_unsupervised/
!rm -rf datasets/DAVIS/ImageSets/
!rm -rf datasets/DAVIS/README.md
!rm -rf datasets/DAVIS/SOURCES.md
!mv datasets/DAVIS/JPEGImages/Full-Resolution/* datasets/DAVIS
!rm -rf datasets/DAVIS/JPEGImages/
!mv datasets/DAVIS datasets/DAVIS_4K

/content
/content/screen
--2025-05-03 18:41:52--  https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-Unsupervised-trainval-Full-Resolution.zip
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:36c2::178
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2957815900 (2.8G) [application/zip]
Saving to: ‘datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip’


2025-05-03 18:44:07 (21.0 MB/s) - ‘datasets/DAVIS2017_Unsupervised_TrainVal_FR.zip’ saved [2957815900/2957815900]



In [None]:
import sys
from pathlib import Path

import cv2
from tqdm import tqdm

SRC_ROOT = Path("datasets/DAVIS_4K")
DST_480  = Path("datasets/DAVIS_480")    # LR frames: short edge = 480 px (e.g. 854×480 if 16:9)
DST_1440 = Path("datasets/DAVIS_1440")   # HR frames: exactly SCALE × the LR size (e.g. 2562×1440)

LR_SHORT_EDGE = 480
SCALE = 3        # must match the ×3 up-sampler of the CRNN defined in Section 3


def lr_size(h: int, w: int, short_edge: int = LR_SHORT_EDGE) -> tuple[int, int]:
    """(width, height) with the short edge = `short_edge`, aspect kept, both dims even."""
    if h < w:    # landscape
        new_h, new_w = short_edge, int(round(w * short_edge / h))
    else:        # portrait / square
        new_w, new_h = short_edge, int(round(h * short_edge / w))
    # round odd sizes up to even to keep 4:2:0-friendly dimensions
    return new_w + (new_w % 2), new_h + (new_h % 2)


def already_done(path: Path, size: tuple[int, int]) -> bool:
    """True if `path` exists and already has the requested (width, height)."""
    if not path.exists():
        return False
    img = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    return img is not None and (img.shape[1], img.shape[0]) == size


# ─── iterate sequences ────────────────────────────────────────────────
seq_dirs = sorted(p for p in SRC_ROOT.iterdir() if p.is_dir())
print(f"Found {len(seq_dirs)} sequences in {SRC_ROOT}")

for seq in tqdm(seq_dirs, desc="Sequences"):
    for dst_root in (DST_480, DST_1440):
        (dst_root / seq.name).mkdir(parents=True, exist_ok=True)

    for img_path in sorted(seq.glob("*.jpg")):
        img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if img is None:
            print(f"Could not read {img_path}", file=sys.stderr)
            continue

        lw, lh = lr_size(*img.shape[:2])
        # Derive the HR size from the LR size so HR == SCALE × LR exactly;
        # resizing both independently from 4K gives e.g. 854×480 vs 2560×1440 (not ×3).
        for dst_root, size in ((DST_480, (lw, lh)), (DST_1440, (SCALE * lw, SCALE * lh))):
            out_path = dst_root / seq.name / img_path.name
            if already_done(out_path, size):     # fast-path: skip finished files
                continue
            resized = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)   # dsize = (w, h)
            cv2.imwrite(str(out_path), resized, [cv2.IMWRITE_JPEG_QUALITY, 95])

Found 90 sequences in datasets/DAVIS_4K


Sequences:  52%|█████▏    | 47/90 [02:52<02:40,  3.73s/it]

In [None]:
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import random

SPLIT_RATIO = (0.70, 0.20, 0.10)
_RANDOM_SEED = 42


class TripletDataset(Dataset):
    """
    Returns four tensors in [0,1], shape (C, H, W):
        prev480, curr480, next480, hr1440
    """
    def __init__(self,
                 lr_root: str | Path,
                 hr_root: str | Path,
                 split: str = "train"):
        assert split in {"train", "val", "test"}
        lr_root, hr_root = Path(lr_root), Path(hr_root)
        if not lr_root.exists() or not hr_root.exists():
            raise FileNotFoundError("LR or HR root folder not found.")

        seq_names = sorted([d.name for d in lr_root.iterdir() if d.is_dir()])
        random.Random(_RANDOM_SEED).shuffle(seq_names)

        n = len(seq_names)
        n_train = int(SPLIT_RATIO[0] * n)
        n_val   = int(SPLIT_RATIO[1] * n)

        if   split == "train": seq_names = seq_names[:n_train]
        elif split == "val"  : seq_names = seq_names[n_train:n_train+n_val]
        else                : seq_names = seq_names[n_train+n_val:]

        self.samples: List[Tuple[Path, Path, Path, Path]] = []
        self._to_tensor = T.ToTensor()

        for seq in seq_names:
            lr_frames = sorted((lr_root / seq).glob("*.jpg"))
            hr_frames = sorted((hr_root / seq).glob("*.jpg"))
            assert len(lr_frames) == len(hr_frames), f"Mismatch in {seq}"

            for i in range(1, len(lr_frames) - 1):
                self.samples.append(
                    (lr_frames[i-1], lr_frames[i], lr_frames[i+1],
                     hr_frames[i])
                )

    def __len__(self):  return len(self.samples)

    def __getitem__(self, idx):
        p_prev, p_curr, p_next, p_hr = self.samples[idx]
        return tuple(self._to_tensor(Image.open(p).convert("RGB"))
                     for p in (p_prev, p_curr, p_next, p_hr))


## 3 Model Architecture & Definition

This section translates our upscaling concept into **executable PyTorch code**, ready for
quantization and hardware mapping.

### 3.1 High-Level Snapshot
* **Input block** → initial **3×3 conv** + ReLU to lift 3-channel RGB into the feature space.  
* **Temporal core** → a stack of **Gated Convolutional Recurrent (GCR) layers** that carry the hidden
  state $h_{t-1} \rightarrow h_t$ and fuse motion cues frame by frame.  
* **Upsampling head** → either  
  1. **Pixel-shuffle** (sub-pixel) + **1×1 conv** to cut channels, or  
  2. **Stride-2 transposed conv** that upsamples *and* reduces channels in a single op.  
* **Skip connection** → adds the bilinearly upscaled seed to sharpen fine edges and stabilize
  training.  
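For reference, the gated recurrence written out to match the `ConvGRUCell` code below (standard ConvGRU form; `reset` → $W_r$, `update` → $W_z$, `out` → $W_n$), with $*$ a 3×3 convolution, $[\cdot,\cdot]$ channel concatenation, $\sigma$ the sigmoid and $\odot$ element-wise multiplication:

$$
\begin{aligned}
r_t &= \sigma\left(W_r * [x_t, h_{t-1}] + b_r\right) \\
z_t &= \sigma\left(W_z * [x_t, h_{t-1}] + b_z\right) \\
\tilde{h}_t &= \tanh\left(W_n * [x_t, r_t \odot h_{t-1}] + b_n\right) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
$$

When no state is passed in (`h is None`), $h_{t-1}$ is initialized to zeros.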

### 3.2 Modules We Will Implement
| Module | Purpose | Notes |
|--------|---------|-------|
| `ConvGRUCell2D` | Spatial GRU with 3×3 kernels | Hidden state kept in on-chip SRAM |
| `CRNNBlock` | Stack of *N* GRU cells | Residual skip every 2 layers |
| `UpsampleHead` | 3× upscale, 480 p → 1440 p | Choice: pixel-shuffle **or** deconv |
| `CRNNUpscaler` | Full end-to-end network |  ~0.9 M params @ 8-bit weights |
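The two `UpsampleHead` options can be compared with a small factory. This is a sketch: the function name is hypothetical, the mode strings mirror `UPSAMPLE_MODE` from Section 3.4, the pixel-shuffle branch sets the output width in the pre-shuffle 3×3 conv (as `SubPixelBlock` below does) instead of adding a separate 1×1 conv, and the transposed conv uses kernel = stride = scale (an assumption that generalizes the stride-2 case and avoids overlapping, checkerboard-prone kernels).

```python
import torch
import torch.nn as nn


def make_upsample_head(mode: str, in_c: int, out_c: int, scale: int) -> nn.Module:
    """Hypothetical UpsampleHead factory: sub-pixel conv or transposed conv."""
    if mode == "pixelshuffle":
        # conv to out_c * scale^2 channels, then rearrange them into a scale x scale grid
        return nn.Sequential(
            nn.Conv2d(in_c, out_c * scale * scale, 3, 1, 1),
            nn.PixelShuffle(scale),
        )
    if mode == "deconv":
        # kernel == stride: each input pixel writes one non-overlapping scale x scale patch
        return nn.ConvTranspose2d(in_c, out_c, kernel_size=scale, stride=scale)
    raise ValueError(f"unknown UPSAMPLE_MODE {mode!r}")


x = torch.rand(1, 24, 8, 8)
for m in ("pixelshuffle", "deconv"):
    head = make_upsample_head(m, 24, 16, 3)
    n_params = sum(p.numel() for p in head.parameters())
    print(m, tuple(head(x).shape), n_params)
```

Both heads map 24 channels to 16 channels at 3× resolution; the pixel-shuffle head spends roughly 9× more weights (3×3 kernels over 144 output channels) than the 3×3-stride transposed conv.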

### 3.3 Parameter & Hardware Budget
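* **Total MACs / 1080 p frame:** ≈ 2.1 G. At 30 fps this is ≈ 63 GMAC/s, above the ≈ 46 GMAC/s peak of a 256-MAC systolic array at 180 MHz (≈ 22 fps at 100 % utilization); 30 fps needs ≥ 350 MACs at that clock.  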
* **SRAM footprint:**  
  * Weights: **≈ 900 kB** (8-bit).  
  * Hidden state: 64 × H/4 × W/4 bytes at 8 bits, i.e. **≈ 8.3 MB** for a 1920 × 1080 frame (64 × 270 × 480).  
  * Line buffers: **≈ 1.6 MB** for 1-frame look-ahead (optional).  
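Instead of estimating MACs by hand, they can be measured with forward hooks. The sketch below (function name hypothetical) counts only `nn.Conv2d` calls, ignores bias adds and activations, and counts a module once per call, so a GRU cell unrolled over three frames is counted three times.

```python
import torch
import torch.nn as nn


def count_conv_macs(model: nn.Module, *inputs: torch.Tensor) -> int:
    """Count multiply-accumulates of every Conv2d call during one forward pass.

    A Conv2d with output (N, C_out, H_out, W_out) costs
    N * C_out * H_out * W_out * (C_in / groups) * kH * kW MACs (bias ignored).
    """
    total = 0

    def hook(mod: nn.Conv2d, _inp, out):
        nonlocal total
        kh, kw = mod.kernel_size
        total += out.numel() * (mod.in_channels // mod.groups) * kh * kw

    handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, nn.Conv2d)]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:   # always detach the hooks, even if forward fails
            h.remove()
    return total


net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 4, 1))
print(count_conv_macs(net, torch.rand(1, 3, 5, 5)))   # 8*25*3*9 + 4*25*8 = 6200
```

Because the upscaler is fully convolutional, its MAC count scales linearly with pixel count, so it can be measured on a reduced-size input (e.g. `TripletCRNNx3()` on 1×3×48×80 frames) and scaled up to the full frame size.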

### 3.4 Config Knobs Exposed in Code
* `N_GCR`: number of recurrent layers (depth vs. latency).  
* `HIDDEN_C`: channel width of hidden state (quality vs. SRAM).  
* `UPSAMPLE_MODE`: `"pixelshuffle"` or `"deconv"`.  
* `QUANT_BITS`: {8, 6, 4} for exploration of power vs. PSNR trade-offs.
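A minimal way to explore `QUANT_BITS` during training is fake quantization with a straight-through estimator (STE). Note that this sketch is **not PACT**: it uses a simple symmetric, per-tensor max-abs scale, whereas PACT learns a clipping threshold for activations. Class and function names are hypothetical.

```python
import torch


class FakeQuantSTE(torch.autograd.Function):
    """Symmetric per-tensor fake quantization with a straight-through estimator."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, bits: int) -> torch.Tensor:
        qmax = 2 ** (bits - 1) - 1                         # 127 for 8 bits, 7 for 4 bits
        scale = x.detach().abs().max().clamp_min(1e-8) / qmax
        # snap to the integer grid, then map back to float ("fake" quantization)
        return torch.round(x / scale).clamp(-qmax, qmax) * scale

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        return grad_out, None                               # gradient passes straight through


def fake_quant(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    return FakeQuantSTE.apply(x, bits)


w = torch.randn(64, 64)
for bits in (8, 6, 4):
    err = (fake_quant(w, bits) - w).abs().mean().item()
    print(f"{bits}-bit: mean abs error {err:.4f}")
```

Wrapping a layer's weights as `fake_quant(conv.weight, QUANT_BITS)` inside its forward pass lets the PSNR cost of each bit-width be measured before committing to a hardware datapath width.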

---

> **Next code cell:** implements `ConvGRUCell`, stacks it in `RecConvStack`, and builds the `TripletCRNNx3`
> upscaler (480 p → 1440 p, ×3), followed by a quick smoke test that checks the output shape.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# ────────────────────────────────────────────────────────────────
# 1. Conv-GRU cell
# ────────────────────────────────────────────────────────────────
class ConvGRUCell(nn.Module):
    def __init__(self, in_c: int, hid_c: int, ks: int = 3):
        super().__init__()
        p = ks // 2
        self.hid_c = hid_c
        self.reset  = nn.Conv2d(in_c + hid_c, hid_c, ks, 1, p)
        self.update = nn.Conv2d(in_c + hid_c, hid_c, ks, 1, p)
        self.out    = nn.Conv2d(in_c + hid_c, hid_c, ks, 1, p)

    def forward(self, x, h):
        if h is None:
            h = x.new_zeros(x.size(0), self.hid_c, x.size(2), x.size(3))
        xc = torch.cat([x, h], 1)
        r = torch.sigmoid(self.reset(xc))
        z = torch.sigmoid(self.update(xc))
        n = torch.tanh(self.out(torch.cat([x, r * h], 1)))
        return (1 - z) * h + z * n


# ────────────────────────────────────────────────────────────────
# 2. Helper: a sequential stack of Conv-GRU cells
# ────────────────────────────────────────────────────────────────
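class RecConvStack(nn.Module):
    """Applies a list of Conv-GRU cells in sequence (no skip).

    Each cell owns its hidden state, whose width equals that cell's `hid`;
    states are never passed between cells of different widths.
    """
    def __init__(self, in_c: int, hid_list):
        super().__init__()
        cells, prev = [], in_c
        for hid in hid_list:
            cells.append(ConvGRUCell(prev, hid))
            prev = hid
        self.cells = nn.ModuleList(cells)

    def forward(self, x, hs=None):
        # hs: optional list with one hidden state per cell (None → zeros)
        if hs is None:
            hs = [None] * len(self.cells)
        for cell, h in zip(self.cells, hs):
            x = cell(x, h)                    # this cell's output feeds the next
        return x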


# ────────────────────────────────────────────────────────────────
# 3. Sub-pixel (pixel-shuffle) up-convolution
#    Keeps spatial info while letting us choose output channels
# ────────────────────────────────────────────────────────────────
class SubPixelBlock(nn.Sequential):
    def __init__(self, in_c: int, out_c: int, scale: int):
        super().__init__(
            nn.Conv2d(in_c, out_c * scale * scale, 3, 1, 1),
            nn.PixelShuffle(scale),
            nn.ReLU(True)
        )


# ────────────────────────────────────────────────────────────────
# 4. CRNN ×3 Upscaler
# ────────────────────────────────────────────────────────────────
class TripletCRNNx3(nn.Module):
    """
    Triplet-based CRNN:
        • Stem conv  → 16 ch @ 480 p
        • Two LR Conv-GRU layers         : 16 → 24 ch @ 480 p
        • Sub-pixel upsample (×3)        : 24 → 16 ch @ 1440 p
        • Four HR Conv-GRU layers        : 32 → 16 → 8 → 3 ch @ 1440 p
        • Bilinear skip + clamp
    """
    def __init__(self):
        super().__init__()

        # 0) stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(True)
        )

        # 1) low-res recurrent trunk (2 layers: 16→24)
        self.lr_rec = RecConvStack(16, [16, 24])

        # 2) up-convolution: 480 p → 1440 p, 24 → 16 ch
        self.up = SubPixelBlock(24, 16, scale=3)

        # 3) high-res recurrent trunk (4 layers: 32,16,8,3)
        self.hr_rec = RecConvStack(16, [32, 16, 8, 3])

    # helper: run the LR recurrent trunk over (prev, curr, next)
    def _run_lr_branch(self, frames):
        h1 = h2 = None
        for f in frames:                       # iterate prev → curr → next
            f = self.stem(f)                   # (B,16,480,W)
            h1 = self.lr_rec.cells[0](f, h1)   # 1st GRU (16 ch)
            h2 = self.lr_rec.cells[1](h1, h2)  # 2nd GRU (24 ch)
        return h2                              # (B,24,480,W)

    def forward(self, prev, curr, nxt):
        # 1) low-res CRNN
        feat_lr = self._run_lr_branch([prev, curr, nxt])   # (B,24,480,W)

        # 2) upsample ×3  →  (B,16,1440,3W)
        feat_hr = self.up(feat_lr)

        # 3) high-res CRNN stack
        sr_feat = self.hr_rec(feat_hr)                     # (B,3,1440,3W)

        # 4) skip-connection (bilinear upsample of curr frame)
        up_ref = F.interpolate(curr, scale_factor=3, mode='bilinear',
                               align_corners=False)
        return (sr_feat + up_ref).clamp(0, 1)


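# ────────────────── quick smoke test ────────────────────────────
if __name__ == "__main__":
    B, W = 1, 854                         # 854×480 LR frames (16:9)
    lr = torch.rand(B, 3, 480, W)         # random frames in [0, 1]
    model = TripletCRNNx3().eval()
    with torch.no_grad():                 # no autograd graph: keeps RAM low at 1440 p
        out = model(lr, lr, lr)
    print("Output shape:", out.shape)     # (B, 3, 1440, 3*W)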