# IIR2D Explainer Notebook: Fast Filters, Weird Art, Real Utility

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/fyremael/iir2d/blob/main/docs/notebooks/IIR2D_Explainer_Colab.ipynb)

This notebook teaches **what IIR2D is**, **how to use all 8 filters**, and **why it matters** for image/video pipelines.

We'll use playful exemplars (cosmic portrait, mountain scroll, microbe swarm), then map them to production intuition.

## What You'll Do

1. Boot the environment (Colab-friendly).
2. Auto-bootstrap GPU execution for IIR2D if a CUDA GPU is available.
3. Generate entertaining synthetic scenes.
4. Run filters `1..8` on the GPU backend when available, otherwise on the canonical IIR2D CPU reference.
5. Compare border modes and precisions.
6. Run a mini temporal/video-style demo.
7. Run an actual JAX+IIR2D GPU demo when available.

In [None]:
# Colab + local setup
import os
import sys
import subprocess
from pathlib import Path

# Colab is detected by the presence of its runtime module.
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    # Work from a checkout under /content, cloning it on first run.
    repo = Path("/content/iir2d")
    if not repo.exists():
        subprocess.check_call(["git", "clone", "https://github.com/fyremael/iir2d.git", str(repo)])
    os.chdir(repo)

# Locate the directory that contains "scripts": the current directory if it
# qualifies, otherwise the first matching candidate (parent, grandparent,
# or an "iir2d_op" child). If none matches, the current directory is kept.
root = Path.cwd().resolve()
if not (root / "scripts").exists():
    root = next(
        (
            cand
            for cand in (root.parent, root.parent.parent, root / "iir2d_op")
            if cand.exists() and (cand / "scripts").exists()
        ),
        root,
    )

os.chdir(root)
# Each entry is inserted at the front, so "<root>/python" ends up ahead of "<root>".
for entry in (str(root), str(root / "python")):
    if entry not in sys.path:
        sys.path.insert(0, entry)

# Install the plotting/array dependencies only when they cannot be imported.
for pkg in ("numpy", "matplotlib"):
    try:
        __import__(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

print(f"ROOT={root}")

## GPU Bootstrap (Automatic)

This cell detects CUDA GPU availability and, when present, prepares the JAX + `iir2d_jax` custom-call path so the demo runs on GPU by default.

In [None]:
import shutil
import importlib
from pathlib import Path


def run_cmd(cmd):
    """Run *cmd* and return ``(returncode, stdout, stderr)`` with stripped text.

    Args:
        cmd: Argument list passed to ``subprocess.run`` (no shell involved).

    Returns:
        A 3-tuple of the exit status and the stripped captured stdout/stderr.
        If the command cannot be launched at all (for example the executable
        is not installed), a non-zero status with empty stdout and the OS
        error text as stderr is returned instead of raising, so callers can
        treat a missing tool like any other failed command.
    """
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # 127 mirrors the conventional shell status for "command not found".
        return 127, "", str(exc)
    return proc.returncode, proc.stdout.strip(), proc.stderr.strip()


def has_cuda_gpu():
    """Report whether ``nvidia-smi -L`` lists a GPU.

    Returns:
        ``(True, listing)`` when the command exits with status 0 and prints
        something; otherwise ``(False, message)`` where the message is the
        command's stderr, else its stdout, else a fixed fallback text.
    """
    status, stdout, stderr = run_cmd(["nvidia-smi", "-L"])
    detected = status == 0 and bool(stdout)
    if not detected:
        return False, stderr or stdout or "nvidia-smi not available"
    return True, stdout


# Probe for a CUDA GPU once; everything below branches on the result.
GPU_AVAILABLE, GPU_INFO = has_cuda_gpu()
# Set to True only after the custom-call library has been built and copied.
IIR2D_GPU_READY = False
# Human-readable explanation of the final bootstrap state.
IIR2D_GPU_REASON = ""

print("GPU_AVAILABLE:", GPU_AVAILABLE)
if GPU_INFO:
    print(GPU_INFO.splitlines()[0])

if GPU_AVAILABLE:
    # Ensure CUDA-capable JAX stack is present.
    try:
        import jax
        devs = jax.devices()
        has_gpu_backend = any(getattr(d, "platform", "") in ("gpu", "cuda") for d in devs)
    except Exception:
        # Any failure here (jax missing, backend init error) means "no GPU backend yet".
        has_gpu_backend = False

    if not has_gpu_backend:
        # NOTE(review): jax/jaxlib are pinned while the CUDA 13 plugin packages
        # are not — confirm these versions are mutually compatible.
        subprocess.check_call([
            sys.executable,
            "-m",
            "pip",
            "install",
            "-q",
            "--upgrade",
            "jax==0.4.38",
            "jaxlib==0.4.38",
            "jax-cuda13-plugin",
            "jax-cuda13-pjrt",
        ])
        # Packages installed while the interpreter is running may be hidden by
        # the import system's finder caches; drop them so the import below
        # can see the new files.
        importlib.invalidate_caches()
        # NOTE(review): if a jax module was already imported above, this
        # process keeps using it; a runtime restart is presumably needed for an
        # upgraded install to take effect — confirm.

    # Fresh import check after potential install.
    import jax
    devs = jax.devices()
    has_gpu_backend = any(getattr(d, "platform", "") in ("gpu", "cuda") for d in devs)
    print("JAX devices:", devs)

    if has_gpu_backend:
        nvcc_path = shutil.which("nvcc")
        cmake_path = shutil.which("cmake")

        if nvcc_path and cmake_path:
            # Configure and build the project in a dedicated build directory.
            build_dir = Path("build_colab_gpu")
            subprocess.check_call(["cmake", "-S", ".", "-B", str(build_dir), "-DCMAKE_BUILD_TYPE=Release"])
            subprocess.check_call(["cmake", "--build", str(build_dir), "-j"])

            candidates = list(build_dir.rglob("libiir2d_jax.so"))
            if not candidates:
                raise RuntimeError("Built successfully but libiir2d_jax.so was not found.")

            # Copy the first library found into the python/iir2d_jax directory.
            lib_src = candidates[0]
            lib_dst = Path("python") / "iir2d_jax" / "libiir2d_jax.so"
            lib_dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(lib_src, lib_dst)
            IIR2D_GPU_READY = True
            IIR2D_GPU_REASON = f"ready ({lib_dst})"
        else:
            IIR2D_GPU_REASON = "CUDA GPU found but nvcc/cmake missing; cannot build iir2d_jax custom-call library."
    else:
        IIR2D_GPU_REASON = "CUDA GPU found, but JAX GPU backend is unavailable in this runtime."
else:
    IIR2D_GPU_REASON = "No CUDA GPU detected; notebook will run CPU reference path."

print("IIR2D_GPU_READY:", IIR2D_GPU_READY)
print("IIR2D_GPU_REASON:", IIR2D_GPU_REASON)

In [None]:
import math
import time
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from iir2d_jax import iir2d
from scripts.iir2d_cpu_reference import iir2d_cpu_reference

# Seed NumPy's global RNG and switch matplotlib to a dark theme.
np.random.seed(42)
plt.style.use("dark_background")

# Display label for each filter id, numbered from 1 in the order listed.
FILTER_LABELS = dict(
    enumerate(
        (
            "F1 EMA",
            "F2 SOS",
            "F3 Biquad",
            "F4 SOS",
            "F5 FB First",
            "F6 Deriche-ish",
            "F7 Sharper EMA",
            "F8 State",
        ),
        start=1,
    )
)

DEFAULT_SHOWCASE_FILTERS = [1, 2, 5, 6, 7]  # artifact-safe defaults for visual tour
BLOCKSCAN_FILTERS = [3, 4, 8]               # shown in dedicated checkerboard section


def _has_gpu_device():
    """Return True if any JAX device reports platform "gpu" or "cuda"."""
    for device in jax.devices():
        # Devices without a ``platform`` attribute are treated as non-GPU.
        if getattr(device, "platform", "") in ("gpu", "cuda"):
            return True
    return False


# The demo counts as GPU-backed only when the bootstrap finished AND JAX sees a GPU.
GPU_DEMO_ACTIVE = _has_gpu_device() if IIR2D_GPU_READY else False

# Fail loudly instead of quietly running the CPU path on a GPU machine.
if GPU_AVAILABLE and not GPU_DEMO_ACTIVE:
    raise RuntimeError(
        "GPU was detected, but demo backend is not on GPU. "
        f"bootstrap={IIR2D_GPU_READY}, devices={jax.devices()}, reason={IIR2D_GPU_REASON}"
    )

demo_backend_name = "GPU" if GPU_DEMO_ACTIVE else "CPU reference"
print("Demo backend:", demo_backend_name)
print("JAX devices:", jax.devices())


def normalize01(x):
    """Rescale *x* by its own minimum and range.

    The input is converted to a float32 array, shifted by its minimum and
    divided by ``max - min + 1e-8``; the small constant keeps the division
    defined when all values are equal (the result is then all zeros).
    """
    arr = np.asarray(x, dtype=np.float32)
    floor = float(arr.min())
    span = float(arr.max()) - floor + 1e-8
    return (arr - floor) / span


def show(img, title="", ax=None):
    """Draw *img* (clipped to [0, 1]) on *ax* with a title and no axes.

    When *ax* is None a new 6x4 figure is created and its axes are used.
    """
    target = ax if ax is not None else plt.subplots(figsize=(6, 4))[1]
    target.imshow(np.clip(img, 0.0, 1.0))
    target.set_title(title)
    target.axis("off")


def _jax_dtype_for_precision(precision):
    """Map a precision name to a JAX dtype: "f64" -> float64, anything else -> float32."""
    return jnp.float64 if precision == "f64" else jnp.float32


def apply_iir2d_rgb_reference(img, filter_id=4, border_mode="mirror", precision="f32", border_const=0.0):
    """Filter every channel of an ``(H, W, C)`` image with the CPU reference.

    Each channel slice ``img[..., c]`` is passed to ``iir2d_cpu_reference``
    with the given filter id, border mode, border constant and precision.

    Args:
        img: Array with the channel axis last.
        filter_id: Filter identifier forwarded to the reference.
        border_mode: Border mode name forwarded to the reference.
        precision: Precision name; the output is float64 for "f64" and
            float32 otherwise.
        border_const: Border constant, forwarded as a float.

    Returns:
        Array of the same shape as *img*, clipped to [0, 1].
    """
    out_dtype = np.float64 if precision == "f64" else np.float32
    out = np.empty(img.shape, dtype=out_dtype)
    # Cover every channel actually present: ``out`` starts uninitialised, so a
    # fixed count of 3 would leave any additional channel filled with garbage.
    for c in range(img.shape[-1]):
        out[..., c] = iir2d_cpu_reference(
            img[..., c],
            filter_id=filter_id,
            border_mode=border_mode,
            border_const=float(border_const),
            precision=precision,
        )
    return np.clip(out, 0.0, 1.0)


def apply_iir2d_rgb_gpu(img, filter_id=4, border_mode="mirror", precision="f32", border_const=0.0):
    """Filter every channel of an ``(H, W, C)`` image through ``iir2d`` (JAX).

    Each channel is converted to a JAX array, filtered, waited on with
    ``block_until_ready`` and copied back into a NumPy output array.

    Args:
        img: Array with the channel axis last.
        filter_id: Filter identifier, forwarded as an int.
        border_mode: Border mode name, forwarded as ``border``.
        precision: Precision name; the output is float64 for "f64" and
            float32 otherwise.
        border_const: Border constant, forwarded as a float.

    Returns:
        Array of the same shape as *img*, clipped to [0, 1].

    Raises:
        RuntimeError: If ``GPU_DEMO_ACTIVE`` is false.
    """
    if not GPU_DEMO_ACTIVE:
        raise RuntimeError("GPU demo backend is not active in this runtime.")

    out_dtype = np.float64 if precision == "f64" else np.float32
    # NOTE(review): a float64 JAX array presumably requires JAX's 64-bit mode to
    # be enabled, otherwise it is downcast — confirm for the "f64" path.
    jdtype = _jax_dtype_for_precision(precision)
    out = np.empty(img.shape, dtype=out_dtype)

    # Cover every channel actually present: ``out`` starts uninitialised, so a
    # fixed count of 3 would leave any additional channel filled with garbage.
    for c in range(img.shape[-1]):
        xc = jnp.asarray(np.asarray(img[..., c], dtype=out_dtype), dtype=jdtype)
        yc = iir2d(
            xc,
            filter_id=int(filter_id),
            border=border_mode,
            border_const=float(border_const),
            precision=precision,
        )
        # Wait for the asynchronous computation before copying to host memory.
        yc.block_until_ready()
        out[..., c] = np.asarray(yc)

    return np.clip(out, 0.0, 1.0)


def apply_iir2d_rgb(img, filter_id=4, border_mode="mirror", precision="f32", border_const=0.0, backend="auto"):
    """Filter an image on the selected backend.

    Args:
        img: Image passed through to the backend function.
        filter_id, border_mode, precision, border_const: Forwarded unchanged.
        backend: "gpu", "reference" (alias "cpu"), or "auto". "auto" picks
            "gpu" when ``GPU_DEMO_ACTIVE`` is true and "reference" otherwise.

    Returns:
        Whatever the chosen backend function returns.

    Raises:
        ValueError: If *backend* is not one of the accepted names.
    """
    if backend == "auto":
        backend = "gpu" if GPU_DEMO_ACTIVE else "reference"
    if backend == "gpu":
        return apply_iir2d_rgb_gpu(img, filter_id, border_mode, precision, border_const)
    if backend in ("reference", "cpu"):
        return apply_iir2d_rgb_reference(img, filter_id, border_mode, precision, border_const)
    # List every accepted spelling, including the "cpu" alias handled above.
    raise ValueError(f"Invalid backend {backend!r}; expected gpu/reference/cpu/auto")


def make_cosmic_portrait(h=384, w=608):
    """Build the synthetic "cosmic portrait" scene as an ``(h, w, 3)`` array in [0, 1].

    A swirling normalised background is combined with Gaussian-shaped blobs
    for a head, beard, hair and two eyes, each tinted with fixed RGB weights.
    """
    yy, xx = np.mgrid[-1:1:complex(0, h), -1:1:complex(0, w)]
    dist = np.sqrt(xx * xx + yy * yy)
    vortex = np.sin(8 * dist - 3 * np.arctan2(yy, xx))

    # Background: one normalised pattern per colour channel.
    backdrop = np.stack(
        [
            normalize01(0.2 + 0.1 * np.cos(3 * xx) + 0.2 * vortex),
            normalize01(0.1 + 0.2 * np.sin(2 * yy + 3 * xx) + 0.15 * vortex),
            normalize01(0.35 + 0.4 * np.cos(2 * dist) + 0.25 * vortex),
        ],
        axis=-1,
    )

    def tint(mask, rgb):
        """Scale a single-channel mask by three per-channel weights."""
        return np.stack([weight * mask for weight in rgb], axis=-1)

    # Soft elliptical masks for the facial features.
    head = np.exp(-((xx / 0.52) ** 2 + ((yy + 0.02) / 0.72) ** 2) * 2.5)
    beard = np.exp(-((xx / 0.45) ** 2 + ((yy - 0.43) / 0.30) ** 2) * 5.0)
    hair = np.exp(-((xx / 0.70) ** 2 + ((yy + 0.25) / 0.55) ** 2) * 2.6) * (0.6 + 0.4 * np.sin(18 * xx + 8 * yy))
    left_eye = np.exp(-(((xx + 0.18) / 0.07) ** 2 + ((yy + 0.06) / 0.05) ** 2) * 8)
    right_eye = np.exp(-(((xx - 0.18) / 0.07) ** 2 + ((yy + 0.06) / 0.05) ** 2) * 8)

    skin_rgb = tint(head, (0.95, 0.78, 0.76))
    beard_rgb = tint(beard, (0.25, 0.28, 0.22))
    hair_rgb = tint(hair, (0.45, 0.30, 0.18))
    eyes_rgb = tint(left_eye + right_eye, (0.45, 0.25, 0.70))

    composite = 0.55 * backdrop + skin_rgb + beard_rgb + hair_rgb + eyes_rgb
    return np.clip(composite, 0.0, 1.0)


def make_mountain_scroll(h=360, w=620):
    """Build the synthetic "mountain scroll" scene as an ``(h, w, 3)`` array in [0, 1].

    Layered sinusoids form a normalised base that is tinted per channel; three
    narrow vertical strips and a tapered cap above each are then darkened.
    """
    yy, xx = np.mgrid[0:1:complex(0, h), 0:1:complex(0, w)]
    ridge = 0.45 * np.sin(10 * xx + 4 * np.sin(5 * xx)) + 0.25 * np.sin(24 * xx + 6 * yy)
    clouds = 0.4 * np.cos(7 * yy + 3 * np.sin(8 * xx))
    texture = 0.15 * np.sin(80 * xx * yy) + 0.12 * np.cos(60 * (xx - yy))
    base = normalize01(ridge + clouds + texture)

    canvas = np.stack(
        [
            normalize01(base * 0.9 + 0.2 * np.sin(8 * yy)),
            normalize01(base * 0.8 + 0.25 * np.cos(10 * xx)),
            normalize01(base * 1.05 + 0.3 * np.cos(6 * yy)),
        ],
        axis=-1,
    )

    strip_tint = np.array([0.25, 0.23, 0.22])
    cap_tint = np.array([0.35, 0.2, 0.2])
    for cx in (0.28, 0.5, 0.72):
        offset = np.abs(xx - cx)
        # Vertical strip first, then the cap whose width shrinks with height.
        strip = (offset < 0.012) & (yy > 0.38) & (yy < 0.70)
        canvas[strip] *= strip_tint
        cap = (offset < 0.02) & (yy > 0.34) & (yy < 0.40) & (offset < (0.02 - (yy - 0.34) * 0.3))
        canvas[cap] *= cap_tint

    return np.clip(canvas, 0.0, 1.0)


def make_microbe_swarm(h=390, w=610, n=180):
    """Build the synthetic "microbe swarm" scene as an ``(h, w, 3)`` float32 array in [0, 1].

    A radial glow in the green/blue channels is overlaid with *n* randomly
    placed, randomly coloured Gaussian blobs (fixed seed 7) and a faint
    32-pixel checker pattern, then normalised.
    """
    rng = np.random.default_rng(7)
    rows, cols = np.mgrid[0:h, 0:w]
    cols = cols.astype(np.float32)
    rows = rows.astype(np.float32)
    canvas = np.zeros((h, w, 3), dtype=np.float32)

    # Radial glow centred on the image, in normalised [-1, 1) coordinates.
    col_n = (cols / w) * 2 - 1
    row_n = (rows / h) * 2 - 1
    rad = np.sqrt(col_n * col_n + row_n * row_n)
    canvas[..., 1] = normalize01(np.exp(-(rad ** 2) * 1.8) * 0.8)
    canvas[..., 2] = normalize01(np.exp(-(rad ** 2) * 2.6) * 0.5)

    for _ in range(n):
        # Draw order matters for reproducibility: centre, radii, then colour.
        cx = rng.uniform(0, w)
        cy = rng.uniform(0, h)
        rx = rng.uniform(5, 16)
        ry = rng.uniform(5, 16)
        blob = np.exp(-(((cols - cx) / rx) ** 2 + ((rows - cy) / ry) ** 2))
        red = rng.uniform(0.3, 1.0)
        green = rng.uniform(0.2, 0.95)
        blue = rng.uniform(0.1, 0.7)
        color = np.array([red, green, blue], dtype=np.float32)
        canvas += blob[..., None] * color[None, None, :] * 0.42

    # Faint checkerboard with 32-pixel cells.
    chips = (((cols // 32 + rows // 32) % 2) == 0).astype(np.float32)
    canvas += np.stack([gain * chips for gain in (0.1, 0.07, 0.02)], axis=-1)

    return np.clip(normalize01(canvas), 0.0, 1.0)

In [None]:
# Generate each demo scene once; later cells look them up by title.
scenes = {
    "Cosmic Portrait": make_cosmic_portrait(),
    "Mountain Scroll": make_mountain_scroll(),
    "Microbe Swarm": make_microbe_swarm(),
}

# One panel per scene, side by side.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for panel, title in zip(axes, scenes):
    show(scenes[title], title, ax=panel)
plt.tight_layout()

## Default Filter Tour (GPU)

This gallery runs on the JAX + `iir2d_jax` GPU backend when CUDA is available and falls back to the CPU reference otherwise.
Apart from that fallback, CPU is used only for explicit reference/comparison sections.

In [None]:
scene_name = "Cosmic Portrait"
src_img = scenes[scene_name]
showcase_filters = DEFAULT_SHOWCASE_FILTERS

# Grid with one panel for the original plus one per showcased filter.
n_panels = len(showcase_filters) + 1
cols = 3
rows = -(-n_panels // cols)  # ceiling division
fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4.4 * rows))
flat_axes = np.asarray(axes).ravel()

show(src_img, "Original", ax=flat_axes[0])
for panel, fid in zip(flat_axes[1:], showcase_filters):
    filtered = apply_iir2d_rgb(src_img, filter_id=fid, border_mode="mirror", precision="f32")
    show(filtered, FILTER_LABELS[fid], ax=panel)

# Blank out any grid cells left over after the last filter.
for spare in flat_axes[n_panels:]:
    spare.axis("off")

plt.suptitle(f"IIR2D Default Visual Tour — {scene_name}", fontsize=16)
plt.tight_layout()

## Checkerboard Lab (Interactive)

This section intentionally reproduces patch/checkerboard behavior for block-scan filters (`F3/F4/F8`).
Use the sliders to adjust image size and block width so you can see how boundary alignment drives the pattern.

In [None]:
import scripts.iir2d_cpu_reference as cpu_ref


def _get_precision_cfg(precision="f32"):
    """Return ``(io_dtype, accumulator_dtype)`` for a precision name.

    "f64" uses float64 for both, "mixed" uses float32 I/O with a float64
    accumulator, and any other value uses float32 for both.
    """
    wide_io = precision == "f64"
    io_dtype = np.float64 if wide_io else np.float32
    acc_dtype = np.float64 if (wide_io or precision == "mixed") else np.float32
    return io_dtype, acc_dtype


def _apply_biquad_rows(image_2d, coeffs, block_width, border_mode_i, bconst, acc, io_dtype):
    """Run ``cpu_ref._row_biquad_scan_contract`` over every row of *image_2d*.

    Args:
        image_2d: 2-D array; each row is processed independently.
        coeffs: Five coefficients ``(b0, b1, b2, a1, a2)``; each is converted
            with *acc* before being passed on.
        block_width: Forwarded (as an int) as the ``block_width`` keyword.
        border_mode_i, bconst, acc, io_dtype: Forwarded unchanged.

    Returns:
        Array shaped like *image_2d* with dtype *io_dtype*.
    """
    b0, b1, b2, a1, a2 = (acc(value) for value in coeffs)
    width = int(block_width)
    result = np.empty_like(image_2d, dtype=io_dtype)
    for idx, line in enumerate(image_2d):
        result[idx] = cpu_ref._row_biquad_scan_contract(
            line,
            b0,
            b1,
            b2,
            a1,
            a2,
            border_mode_i,
            bconst,
            acc,
            io_dtype,
            block_width=width,
        )
    return result


def iir2d_blockscan_filter_2d(channel, filter_id=8, block_width=256, border_mode="mirror", precision="f32", border_const=0.0):
    """Filter one 2-D channel with an adjustable block width for filters 3, 4 and 8.

    Any other filter id is delegated to ``iir2d_cpu_reference``. For the
    block-scan ids, a row-wise biquad pass is applied along rows and then,
    via a transpose, along columns. Filters 3 and 8 use one coefficient set;
    filter 4 chains a second set after the first.
    """
    if filter_id not in (3, 4, 8):
        return iir2d_cpu_reference(channel, filter_id=filter_id, border_mode=border_mode, precision=precision, border_const=border_const)

    mode = cpu_ref._normalize_border_mode(border_mode)
    io_dtype, acc = _get_precision_cfg(precision)
    plane = np.asarray(channel, dtype=io_dtype, order="C")
    fill = acc(border_const)

    # Coefficient sets applied in sequence for one directional pass.
    stages = [(0.2, 0.2, 0.2, 0.3, -0.1)]
    if filter_id == 4:
        stages.append((0.3, 0.1, 0.1, 0.2, -0.05))

    def run_rows(data):
        """Apply every stage, in order, along the rows of *data*."""
        for coeffs in stages:
            data = _apply_biquad_rows(data, coeffs, block_width, mode, fill, acc, io_dtype)
        return data

    horizontal = run_rows(plane)
    # Transposing turns the column pass into another row pass.
    vertical_t = run_rows(np.ascontiguousarray(horizontal.T))
    return np.ascontiguousarray(vertical_t.T)


def apply_blockscan_rgb(img, filter_id=8, block_width=256, border_mode="mirror", precision="f32", border_const=0.0):
    """Apply ``iir2d_blockscan_filter_2d`` to every channel of an ``(H, W, C)`` image.

    Args:
        img: Array with the channel axis last.
        filter_id, block_width, border_mode, precision, border_const:
            Forwarded unchanged to the per-channel filter.

    Returns:
        Array of the same shape as *img* (float64 for precision "f64",
        float32 otherwise), clipped to [0, 1].
    """
    out_dtype = np.float64 if precision == "f64" else np.float32
    out = np.empty(img.shape, dtype=out_dtype)
    # Cover every channel actually present: ``out`` starts uninitialised, so a
    # fixed count of 3 would leave any additional channel filled with garbage.
    for c in range(img.shape[-1]):
        out[..., c] = iir2d_blockscan_filter_2d(
            img[..., c],
            filter_id=filter_id,
            block_width=block_width,
            border_mode=border_mode,
            precision=precision,
            border_const=border_const,
        )
    return np.clip(out, 0.0, 1.0)


def make_boundary_stress_scene(height=512, width=512):
    """Build an ``(height, width, 3)`` test image in [0, 1] for the checkerboard lab.

    Each channel is a normalised mix of angular and radial sinusoids, with a
    low-amplitude checker pattern added to all three channels.
    """
    yy, xx = np.mgrid[-1:1:complex(0, height), -1:1:complex(0, width)]
    dist = np.sqrt(xx * xx + yy * yy)
    theta = np.arctan2(yy, xx)

    red = normalize01(0.55 * np.cos(6 * theta) + 0.45 * np.sin(4 * dist + 8 * xx))
    green = normalize01(0.50 * np.sin(7 * theta + 5 * dist) + 0.35 * np.cos(10 * yy))
    blue = normalize01(0.40 * np.cos(9 * dist) + 0.50 * np.sin(5 * xx - 3 * yy))

    # Alternating tiles contribute either 0 or 0.08.
    tile_x = np.floor((xx + 1.0) * width / 32)
    tile_y = np.floor((yy + 1.0) * height / 32)
    checker = (((tile_x + tile_y) % 2) * 0.08).astype(np.float32)

    layered = np.stack([plane + checker for plane in (red, green, blue)], axis=-1)
    return np.clip(layered, 0.0, 1.0)


def boundary_discontinuity(img, block_width):
    """Average absolute jump across block boundaries of an ``(H, W, C)`` image.

    For every column and row index that is a positive multiple of
    *block_width* inside the image, the mean absolute difference between
    that line and the one before it is measured. The result is the mean of
    those measurements, or 0.0 when no boundary falls inside the image.
    """
    height, width = img.shape[:2]
    column_jumps = [
        np.mean(np.abs(img[:, edge, :] - img[:, edge - 1, :]))
        for edge in range(block_width, width, block_width)
    ]
    row_jumps = [
        np.mean(np.abs(img[edge, :, :] - img[edge - 1, :, :]))
        for edge in range(block_width, height, block_width)
    ]
    jumps = column_jumps + row_jumps
    if not jumps:
        return 0.0
    return float(np.mean(jumps))


def render_checkerboard_lab(filter_id=8, block_width=256, width=512, height=512, border_mode="mirror"):
    """Plot the stress scene, its block-scan result and an F2 comparison.

    The middle panel's title includes the measured boundary jump, and faint
    white lines mark every multiple of *block_width* on that panel.
    """
    scene = make_boundary_stress_scene(height=height, width=width)
    blockscan = apply_blockscan_rgb(scene, filter_id=filter_id, block_width=block_width, border_mode=border_mode, precision="f32")
    stable = apply_iir2d_rgb(scene, filter_id=2, border_mode=border_mode, precision="f32")

    disc = boundary_discontinuity(blockscan, block_width)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = [
        (scene, f"Original {width}x{height}"),
        (blockscan, f"{FILTER_LABELS[filter_id]} block={block_width} | boundary jump={disc:.5f}"),
        (stable, "Stable reference (F2)"),
    ]
    for ax, (image, label) in zip(axes, panels):
        show(image, label, ax=ax)

    # Lines sit on pixel edges (index - 0.5) between adjacent blocks.
    overlay = axes[1]
    guide = {"color": "white", "alpha": 0.25, "linewidth": 0.8}
    for bx in range(block_width, width, block_width):
        overlay.axvline(bx - 0.5, **guide)
    for by in range(block_width, height, block_width):
        overlay.axhline(by - 0.5, **guide)

    plt.tight_layout()
    plt.show()


try:
    import ipywidgets as widgets

    # One control per render_checkerboard_lab keyword argument.
    controls = {
        "filter_id": widgets.Dropdown(options=[3, 4, 8], value=8, description="filter"),
        "block_width": widgets.IntSlider(value=256, min=32, max=512, step=32, description="block"),
        "width": widgets.IntSlider(value=512, min=256, max=896, step=64, description="width"),
        "height": widgets.IntSlider(value=512, min=256, max=896, step=64, description="height"),
        "border_mode": widgets.Dropdown(options=["mirror", "clamp", "wrap", "constant"], value="mirror", description="border"),
    }
    widgets.interact(render_checkerboard_lab, **controls)
except Exception:
    # Best-effort fallback: render a few fixed block widths without widgets.
    for bw in (128, 256, 384):
        render_checkerboard_lab(filter_id=8, block_width=bw, width=512, height=512, border_mode="mirror")

In [None]:
src_img = scenes["Mountain Scroll"]
border_modes = ["clamp", "mirror", "wrap", "constant"]

# One panel per border mode, all using filter 2 and the same border constant.
fig, axes = plt.subplots(1, len(border_modes), figsize=(20, 5))
for panel, border in zip(axes, border_modes):
    out = apply_iir2d_rgb(src_img, filter_id=2, border_mode=border, precision="f32", border_const=0.08)
    show(out, f"Border: {border}", ax=panel)

plt.suptitle("Border Semantics Matter (Filter 2)", fontsize=15)
plt.tight_layout()

In [None]:
src_img = scenes["Microbe Swarm"]

# Same filter and border at three precision settings.
out_f32 = apply_iir2d_rgb(src_img, filter_id=2, border_mode="mirror", precision="f32")
out_mixed = apply_iir2d_rgb(src_img, filter_id=2, border_mode="mirror", precision="mixed")
out_f64 = apply_iir2d_rgb(src_img.astype(np.float64), filter_id=2, border_mode="mirror", precision="f64")

# Mean absolute deviation from the f32 result.
delta_mixed = np.mean(np.abs(out_f32 - out_mixed))
delta_f64 = np.mean(np.abs(out_f32 - out_f64.astype(np.float32)))

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
panels = [
    (out_f32, "f32"),
    (out_mixed, f"mixed (mean |d|={delta_mixed:.6f})"),
    (out_f64, f"f64 (mean |d|={delta_f64:.6f})"),
]
for panel, (image, label) in zip(axes, panels):
    show(image, label, ax=panel)
plt.suptitle("Precision Modes: Visual + Numeric Drift", fontsize=15)
plt.tight_layout()

In [None]:
def make_videoish_sequence(base, n=12):
    """Turn one image into *n* jittered, noisy frames.

    Frame ``t`` is *base* rolled along both image axes by integer offsets that
    follow a circle over the sequence, scaled by 0.9, with Gaussian noise
    (sigma 0.03, fixed seed 13) added, and clipped to [0, 1].

    Returns:
        A list of *n* arrays shaped like *base*.
    """
    rng = np.random.default_rng(13)
    sequence = []
    for step in range(n):
        shift_x = int(5 * math.sin(2 * math.pi * step / n))
        shift_y = int(4 * math.cos(2 * math.pi * step / n))
        # Roll rows (axis 0) and columns (axis 1) in a single call.
        moved = np.roll(base, (shift_y, shift_x), axis=(0, 1))
        grain = rng.normal(0.0, 0.03, size=base.shape).astype(np.float32)
        sequence.append(np.clip(0.9 * moved + grain, 0.0, 1.0))
    return sequence


def temporal_energy(frames):
    """Return the average frame-to-frame change of a sequence.

    For each pair of consecutive frames the mean absolute difference is
    computed; the result is the mean of those values.

    Args:
        frames: Sequence of equally shaped arrays.

    Returns:
        The mean change as a float, or 0.0 when there are fewer than two
        frames (no consecutive pair exists to compare).
    """
    diffs = [np.mean(np.abs(frames[i + 1] - frames[i])) for i in range(len(frames) - 1)]
    # Guard the empty case: the mean of an empty list would be NaN with a warning.
    return float(np.mean(diffs)) if diffs else 0.0


# Raw jittered/noisy frames and their filtered counterparts.
seq_in = make_videoish_sequence(scenes["Cosmic Portrait"], n=12)
seq_out = []
for frame in seq_in:
    seq_out.append(apply_iir2d_rgb(frame, filter_id=2, border_mode="mirror", precision="f32"))

e_in = temporal_energy(seq_in)
e_out = temporal_energy(seq_out)

# Show the first six frames: raw on the top row, filtered on the bottom row.
fig, axes = plt.subplots(2, 6, figsize=(22, 8))
for i, (raw, filtered) in enumerate(zip(seq_in[:6], seq_out[:6])):
    show(raw, f"In t={i}", ax=axes[0, i])
    show(filtered, f"Out t={i}", ax=axes[1, i])
plt.suptitle(
    f"Video-ish sequence (top=raw, bottom=filtered) | mean temporal energy: {e_in:.4f} -> {e_out:.4f}",
    fontsize=14,
)
plt.tight_layout()

In [None]:
# Fixed random float32 test image so timings are comparable between runs.
x = np.random.default_rng(0).random((256, 256, 3), dtype=np.float32)

# Backend used for the primary timing loop.
bench_backend = "gpu" if GPU_DEMO_ACTIVE else "reference"
n_repeats = 3

results_gpu = []
for fid in range(1, 9):
    if GPU_DEMO_ACTIVE:
        # Untimed warm-up call so that first-call one-time costs on the GPU
        # path are not averaged into the per-call figure below.
        _ = apply_iir2d_rgb(x, filter_id=fid, border_mode="mirror", precision="f32", backend=bench_backend)
    t0 = time.perf_counter()
    for _ in range(n_repeats):
        _ = apply_iir2d_rgb(x, filter_id=fid, border_mode="mirror", precision="f32", backend=bench_backend)
    dt = (time.perf_counter() - t0) / n_repeats
    results_gpu.append((fid, dt * 1000.0))

print("IIR2D demo timing (primary path):")
print("  backend:", "GPU" if GPU_DEMO_ACTIVE else "CPU reference (no GPU runtime)")
for fid, ms in results_gpu:
    print(f"  Filter {fid}: {ms:8.2f} ms")

# Explicit CPU reference sample for parity context.
ref_t0 = time.perf_counter()
_ = apply_iir2d_rgb_reference(x, filter_id=2, border_mode="mirror", precision="f32")
ref_ms = (time.perf_counter() - ref_t0) * 1000.0
print(f"Reference-only sample (CPU, Filter 2): {ref_ms:.2f} ms")

In [None]:
# GPU execution sanity check + reference parity spot-check
if GPU_AVAILABLE and not GPU_DEMO_ACTIVE:
    raise RuntimeError(
        "GPU detected but notebook demos are not running on GPU. "
        f"reason={IIR2D_GPU_REASON}"
    )

# Single channel used for the parity comparison further down.
x_np = scenes["Cosmic Portrait"][..., 0].astype(np.float32)

y_demo = apply_iir2d_rgb(
    scenes["Cosmic Portrait"],
    filter_id=2,
    border_mode="mirror",
    precision="f32",
    backend="gpu" if GPU_DEMO_ACTIVE else "reference",
)

print("Primary demo backend:", "GPU" if GPU_DEMO_ACTIVE else "CPU reference")
print("JAX devices:", jax.devices())
print("Demo output stats:", float(y_demo.min()), float(y_demo.max()), float(y_demo.mean()))

# CPU is reference-only: parity sample against one channel.
# apply_iir2d_rgb clips its result to [0, 1], so the raw reference is clipped
# the same way; otherwise clipping alone would be reported as a mismatch.
y_ref = np.clip(
    iir2d_cpu_reference(x_np, filter_id=2, border_mode="mirror", precision="f32"),
    0.0,
    1.0,
)
y_gpu_or_ref = apply_iir2d_rgb(
    np.stack([x_np, x_np, x_np], axis=-1),
    filter_id=2,
    border_mode="mirror",
    precision="f32",
    backend="gpu" if GPU_DEMO_ACTIVE else "reference",
)[..., 0]
mean_abs = float(np.mean(np.abs(y_gpu_or_ref - y_ref)))
print(f"Reference parity spot-check (mean abs diff): {mean_abs:.6e}")

## Takeaways

- This notebook is GPU-first: visual demos run on JAX + `iir2d_jax` when CUDA is available.
- CPU is used only as an explicit reference/parity path.
- Border semantics are not cosmetic; they materially change output behavior.
- For video-like sequences, temporal smoothness effects are easy to observe frame-to-frame.
- Checkerboard-like patches in `F3/F4/F8` come from block-scan boundary composition and are isolated to the interactive lab.

If you want, the next iteration can add:
1. live side-by-side video scrubber with per-frame metrics,
2. direct export of selected demo cells as README-ready PNG/MP4,
3. auto-generated benchmark claims packet from notebook runs.