<a href="https://colab.research.google.com/github/jamessutton600613-png/GC/blob/main/Untitled198.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# === GQR TDSE — HCIBE sweep with A baseline (GPU-safe) ===
# Clean writer + combined multi-panel MP4
# - Fixes float->uint8 + macro block warnings
# - Adds grid movie of all variants at once (3x3)
# ----------------------------------------------------------

!pip -q install imageio-ffmpeg

import os, json, itertools, math, csv, time, glob
from dataclasses import dataclass

# --- backend select (GPU if available) ---
# Try CuPy first; any exception while importing it or allocating on it makes
# the notebook fall back to NumPy. `xp` is the array module used by the
# numerical code below, so the same functions run on either backend.
USE_GPU = False
try:
    import cupy as cp
    _ = cp.zeros((1,))  # force a real allocation so a broken CuPy setup fails here
    USE_GPU = True
except Exception:
    cp = None  # keep the name bound; it is only dereferenced when USE_GPU is True

import numpy as np
xp = cp if USE_GPU else np

import matplotlib
matplotlib.use("Agg")  # headless-safe
import matplotlib.pyplot as plt
import imageio.v2 as iio
from PIL import Image, ImageOps

# One-line banner so the captured output records which backend was used.
print(f"[backend] {'CuPy (GPU)' if USE_GPU else 'NumPy (CPU)'}")

# ---------- utilities ----------
def ensure_dir(p):
    """Create directory `p` (including missing parents); do nothing if it exists."""
    os.makedirs(p, exist_ok=True)

def to_cpu(a):
    """Hand back `a` as a NumPy array so it can be plotted or written to disk.

    CuPy arrays are copied to host memory when the GPU backend is active;
    anything else goes through `np.asarray`.
    """
    on_device = USE_GPU and isinstance(a, cp.ndarray)
    return cp.asnumpy(a) if on_device else np.asarray(a)

def _fig_rgb(fig):
    """Draw `fig` and return its canvas pixels as an independent H x W x 3 array.

    The alpha channel of the RGBA canvas buffer is dropped; the result is a
    copy, so it stays valid after the figure is closed.
    """
    canvas = fig.canvas
    canvas.draw()
    pixels = np.asarray(canvas.buffer_rgba())
    return pixels[..., :3].copy()

def _norm_to_uint8_rgb(a, vmax=None):
    """Scale an array into 0..255 and replicate it into three identical channels.

    `vmax` is the value mapped to white; when omitted, the 99.5th percentile
    of the data is used. The scale is floored at 1e-12 to avoid dividing by
    zero, and values outside [0, vmax] are clipped.
    """
    data = to_cpu(a).astype(np.float32)
    scale = np.percentile(data, 99.5) if vmax is None else vmax
    scale = max(scale, 1e-12)
    unit = np.clip(data / scale, 0.0, 1.0)
    gray = (255.0 * unit).astype(np.uint8)
    return np.dstack([gray, gray, gray])

def _resize_to(frame_u8, width, height):
    """Return `frame_u8` (uint8 RGB) resampled to exactly `width` x `height` pixels.

    Nearest-neighbour resampling is used, so no new colours are introduced.
    """
    source = Image.fromarray(frame_u8, mode="RGB")
    scaled = source.resize((width, height), resample=Image.NEAREST)
    return np.asarray(scaled)

def _write_mp4(frames, out_path, fps=12, macro=16):
    """Encode an iterable of uint8 RGB frames to `out_path` as H.264 / yuv420p.

    `macro` is forwarded as the writer's `macro_block_size`.
    """
    writer_opts = dict(fps=fps, codec="libx264",
                       macro_block_size=macro, pixelformat="yuv420p")
    with iio.get_writer(out_path, **writer_opts) as writer:
        for frame in frames:
            writer.append_data(frame)

# ---------- montage helper (optional stills) ----------
def montage(img_paths, ncols, out_path, dpi=120):
    """Tile the images in `img_paths` into one figure and save it to `out_path`.

    Images are laid out row-major in a grid `ncols` wide; each cell is titled
    with the file's basename and unused cells are blanked.

    Args:
        img_paths: sequence of image file paths (must be non-empty).
        ncols: number of columns in the grid.
        out_path: destination image file.
        dpi: figure resolution.

    Raises:
        ValueError: if `img_paths` is empty.
    """
    if not img_paths:
        # Fail with a clear message instead of an error from deep inside
        # matplotlib when asked for a zero-row subplot grid.
        raise ValueError("montage() needs at least one image path")

    nrows = math.ceil(len(img_paths) / ncols)
    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows),
                             dpi=dpi, squeeze=False)
    cells = axes.ravel()
    try:
        for ax, p in zip(cells, img_paths):
            ax.imshow(plt.imread(p))
            ax.axis('off')
            ax.set_title(os.path.basename(p), fontsize=8)
        for ax in cells[len(img_paths):]:
            ax.axis('off')
        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if an image failed to load or save.
        plt.close(fig)

# ---------- grid & physics ----------
@dataclass
class Grid:
    """Periodic 2-D grid description plus the physical constants used with it."""
    nx: int = 256      # samples along x
    ny: int = 256      # samples along y
    Lx: float = 20.0   # domain length in x; the grid spans [-Lx/2, Lx/2)
    Ly: float = 20.0   # domain length in y; the grid spans [-Ly/2, Ly/2)
    hbar: float = 1.0
    m: float = 1.0

    def build(self):
        """Return `(X, Y, K2, dx, dy)` on the active backend.

        X and Y are coordinate meshes built with `indexing='xy'`, K2 is
        kx**2 + ky**2 on the matching FFT frequency mesh, and dx, dy are the
        sample spacings.
        """
        xs = xp.linspace(-self.Lx/2, self.Lx/2, self.nx, endpoint=False)
        ys = xp.linspace(-self.Ly/2, self.Ly/2, self.ny, endpoint=False)
        dx, dy = xs[1] - xs[0], ys[1] - ys[0]

        # Angular wavenumbers in FFT ordering.
        kx = 2*xp.pi*(xp.fft.fftfreq(self.nx, d=float(dx)))
        ky = 2*xp.pi*(xp.fft.fftfreq(self.ny, d=float(dy)))

        X, Y = xp.meshgrid(xs, ys, indexing='xy')
        KX, KY = xp.meshgrid(kx, ky, indexing='xy')
        return X, Y, KX**2 + KY**2, dx, dy

# ---------- potentials ----------
def potential_A(X, Y, A_height=8.0, A_width=1.2):
    """Baseline potential: a Gaussian ridge in x centred on x=0 plus a weak
    quadratic term in y (coefficient 0.02)."""
    ridge = A_height * xp.exp(-(X/A_width)**2)
    trough = 0.02 * Y**2
    return ridge + trough

def gate_H_funnel(X, Y, H_depth=4.0, H_width=3.0, aperture_y=0.0, aperture_x0=-1.0):
    """Attractive Gaussian channel in y around `aperture_y` that narrows and
    deepens as x increases past `aperture_x0`.

    The channel width goes from 1.5 to 0.7 and its depth factor from 0.2 to
    1.0 along a tanh ramp in x.
    NOTE(review): `H_width` is accepted but not used by this function.
    """
    ramp = 0.5 + 0.5*xp.tanh((X - aperture_x0)/2.0)   # 0 -> 1 across x
    width = 0.7 + (1.5 - 0.7)*(1.0 - ramp)
    offset = (Y - aperture_y)/width
    depth_factor = 0.2 + 0.8*ramp
    return -H_depth * xp.exp(-0.5 * offset**2) * depth_factor

def gate_C_field(X, Y, C_strength=0.6, theta_deg=0.0):
    """Linear potential -(Ex*X + Ey*Y) for a uniform field of magnitude
    `C_strength` pointing at angle `theta_deg` (degrees) from the x axis."""
    angle = xp.deg2rad(theta_deg)
    field_x = C_strength*xp.cos(angle)
    field_y = C_strength*xp.sin(angle)
    return -(field_x*X + field_y*Y)

def gate_B_tilt(X, Y, B_strength=0.35):
    """Linear tilt of strength `B_strength`, weighted 0.3 in x and 1.0 in y."""
    slope = 0.3*X + 1.0*Y
    return -B_strength * slope

def gate_I_pcet_proxy(X, Y, I_depth=1.4, r0=3.0):
    """Localized Gaussian well of depth `I_depth` centred at (x=r0, y=0)."""
    dist2 = (X - r0)**2 + (Y/1.2)**2
    return -I_depth * xp.exp(-dist2 / 1.2**2)

# ---------- dephasing ----------
def apply_dephasing(psi, E_gamma=0.0, rng=None):
    """Multiply `psi` by a random pure-phase factor drawn independently per sample.

    Each grid point receives exp(i*phi) with phi ~ Normal(0, E_gamma), so the
    modulus of `psi` is unchanged.

    Args:
        psi: complex wavefunction array.
        E_gamma: standard deviation of the phase noise; values <= 0 disable
            dephasing and `psi` is returned as-is.
        rng: optional random generator exposing `standard_normal(shape)`
            (e.g. `numpy.random.default_rng(seed)`). When given, the noise is
            drawn from it so a run can be reproduced. It must produce arrays
            on the same backend as `psi`. When None, the backend's global
            random state is used.

    Returns:
        A new array with the phase noise applied, or `psi` itself when
        dephasing is disabled.
    """
    if E_gamma <= 0.0:
        return psi
    if rng is not None:
        # Scaling a standard normal by E_gamma gives Normal(0, E_gamma) and
        # only needs the one method common to the supported generators.
        phase = E_gamma * rng.standard_normal(psi.shape)
        return psi * np.exp(1j * phase)
    if USE_GPU:
        phase = cp.random.normal(0.0, E_gamma, psi.shape, dtype=psi.real.dtype)
        return psi * xp.exp(1j * phase)
    phase = np.random.normal(0.0, E_gamma, psi.shape)
    return psi * np.exp(1j * phase)

# ---------- initial packet ----------
def gaussian_packet(X, Y, x0=-6.5, y0=0.0, px=2.0, py=0.0, sx=1.2, sy=1.2, hbar=1.0):
    """Gaussian wave packet centred at (x0, y0) with momentum (px, py).

    The result is scaled so that the sum of |psi|^2 over all grid points is 1
    (a discrete sum; no cell-area factor is applied).
    """
    envelope = xp.exp(-0.5*((X-x0)/sx)**2 - 0.5*((Y-y0)/sy)**2)
    plane_wave = xp.exp(1j*((px*X + py*Y)/hbar))
    wave = envelope * plane_wave
    total = xp.sqrt(xp.sum(xp.abs(wave)**2))
    return wave / total

# ---------- propagate ----------
def propagate_tdse(grid: Grid, V_func, T_final=20.0, dt=0.01, dephaser=None, save_every=100):
    """Propagate a Gaussian packet with a split-step Fourier scheme.

    Each step applies half a potential kick, an optional dephasing callback,
    a full kinetic step in Fourier space, then the second half potential kick.

    Args:
        grid: Grid providing the meshes, `hbar` and `m`.
        V_func: callable `(X, Y) -> V` returning the potential on the mesh.
        T_final: total propagation time.
        dt: time step.
        dephaser: optional callable `psi -> psi` applied once per step.
        save_every: |psi|^2 is stored after every `save_every`-th step
            (starting with the first) and after the final step.

    Returns:
        `((X, Y), V, frames, psi)`, where `frames` is the list of stored
        |psi|^2 arrays and `psi` is the final wavefunction.
    """
    hbar, m = grid.hbar, grid.m
    X, Y, K2, dx, dy = grid.build()
    V = V_func(X, Y)

    # Both propagators are constant in time, so build them once instead of
    # re-evaluating the exponentials on every step.
    kinetic_phase = xp.exp(-1j * (hbar**2 * K2) * dt / (2.0*m*hbar))
    half_potential_phase = xp.exp(-1j * V * dt / (2*hbar))

    psi = gaussian_packet(X, Y, hbar=hbar)

    frames = []
    # round() rather than truncation: a ratio like 0.3/0.1 evaluates to
    # 2.999..., which would otherwise silently drop the last step.
    nsteps = int(round(T_final / dt))
    last_step = nsteps - 1
    for n in range(nsteps):
        psi *= half_potential_phase
        if dephaser is not None:
            psi = dephaser(psi)
        psi_k = xp.fft.fft2(psi)
        psi_k *= kinetic_phase
        psi = xp.fft.ifft2(psi_k)
        psi *= half_potential_phase

        if (n % save_every) == 0 or n == last_step:
            frames.append(xp.abs(psi)**2)
    return (X, Y), V, frames, psi

# ---------- metrics ----------
def aperture_mask(X, Y, x_gate=2.0, y_halfwidth=1.0):
    """Boolean mask of the region with x > `x_gate` and |y| < `y_halfwidth`."""
    past_gate = X > x_gate
    within_slit = xp.abs(Y) < y_halfwidth
    return past_gate & within_slit

def compute_metrics(frames, mask, threshold=5e-4):
    """Summarize how much probability sits inside `mask` across saved frames.

    Args:
        frames: sequence of |psi|^2 arrays (NumPy or CuPy).
        mask: boolean array selecting the region of interest.
        threshold: a frame counts as "open" when its masked mass exceeds this.

    Returns:
        dict with
          - `aperture_fraction`: fraction of frames whose masked mass is above
            `threshold` (NaN when there are no frames),
          - `integrated_flux`: sum of the masked mass over all frames,
          - `ema_density`: exponential moving average (alpha = 0.1, starting
            from 0) of the masked mass, evaluated after the last frame.
    """
    # Pull each per-frame mass to a host float exactly once. Using the array's
    # own .sum() keeps this independent of which backend produced the frames.
    masses = [float(f[mask].sum()) for f in frames]

    if masses:
        open_frac = sum(m > threshold for m in masses) / len(masses)
    else:
        # Undefined without frames; report NaN explicitly instead of taking
        # the mean of an empty array (which also emits a RuntimeWarning).
        open_frac = float("nan")

    int_flux = float(np.sum(np.asarray(masses, dtype=float)))

    alpha = 0.1
    ema = 0.0
    for m in masses:
        ema = alpha*m + (1-alpha)*ema
    return dict(aperture_fraction=open_frac, integrated_flux=int_flux, ema_density=ema)

# ---------- builder ----------
def make_V_func(enable_H=False, enable_C=False, enable_I=False, enable_B=False, enable_E=False,
                A_kwargs=None, H_kwargs=None, C_kwargs=None, I_kwargs=None, B_kwargs=None,
                E_gamma=0.0):
    """Assemble the potential and the optional dephaser for one condition.

    The baseline potential A is always included; the H, C, B and I gates are
    added (in that order) when their `enable_*` flag is set. Each `*_kwargs`
    dict is forwarded to the matching potential function; None means "use
    that function's defaults".

    Returns:
        `(V_total, dephaser)`: `V_total(X, Y)` evaluates the summed potential;
        `dephaser` is a `psi -> psi` callable applying phase noise of width
        `E_gamma`, or None when `enable_E` is false.
    """
    A_kwargs = A_kwargs or {}
    H_kwargs = H_kwargs or {}
    C_kwargs = C_kwargs or {}
    I_kwargs = I_kwargs or {}
    B_kwargs = B_kwargs or {}

    def V_total(X, Y):
        # (enabled, gate function, its kwargs), listed in accumulation order.
        optional_gates = (
            (enable_H, gate_H_funnel, H_kwargs),
            (enable_C, gate_C_field, C_kwargs),
            (enable_B, gate_B_tilt, B_kwargs),
            (enable_I, gate_I_pcet_proxy, I_kwargs),
        )
        total = potential_A(X, Y, **A_kwargs)
        for enabled, gate, gate_kwargs in optional_gates:
            if enabled:
                total += gate(X, Y, **gate_kwargs)
        return total

    if not enable_E:
        return V_total, None

    def dephaser(psi):
        return apply_dephasing(psi, E_gamma=E_gamma)

    return V_total, dephaser

# ---------- constants for clean frames ----------
# Single-panel figure/frame size divisible by 16
# Pixel size = figure inches * DPI. 16 is also the macro_block_size this file
# passes to the MP4 writers, so frames of this size need no encoder padding.
FIG_DPI = 128
FIGSIZE = (6.0, 4.0)                 # 6*128=768, 4*128=512
FRAME_W, FRAME_H = 768, 512          # exact pixels

def _render_panel(image, extent, title, vmin=None, vmax=None):
    """Render a 2-D array with axes and a colorbar into a FRAME_H x FRAME_W x 3 uint8 array.

    `image` is drawn as-is: rows map to the vertical (y) axis and columns to
    the horizontal (x) axis. That matches meshes built with
    `meshgrid(..., indexing='xy')`, where axis 0 runs over y and axis 1 over
    x, so the array must NOT be transposed before plotting.
    """
    fig, ax = plt.subplots(figsize=FIGSIZE, dpi=FIG_DPI)
    try:
        im = ax.imshow(image, origin='lower', extent=extent, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel('x'); ax.set_ylabel('y')
        fig.colorbar(im, ax=ax, shrink=0.8)
        rgb = _fig_rgb(fig)
    finally:
        # Release the figure even if rendering fails, so figures cannot pile up.
        plt.close(fig)
    # enforce exact size (should already match, but be safe)
    if rgb.shape[1] != FRAME_W or rgb.shape[0] != FRAME_H:
        rgb = _resize_to(rgb, FRAME_W, FRAME_H)
    return rgb


def _plot_density_frame(tag, f_np, extent, frame_idx, vmax):
    """Render |psi|^2 with axes into a fixed-size RGB uint8 array (FRAME_W x FRAME_H).

    `f_np` is a host array indexed [y, x]; the colour scale runs from 0 to `vmax`.
    """
    title = f"|psi|^2: {tag}  frame {frame_idx:03d}"
    return _render_panel(f_np, extent, title, vmin=0, vmax=vmax)

# ---------- save per-condition outputs (PNG + MP4) ----------
def save_condition_outputs(tag, out_dir, XY, V, frames, vmax=None, write_mp4=True, fps=12):
    """Write the potential image, per-frame density PNGs, a montage and an MP4.

    Everything goes to `<out_dir>/<tag>/`:
      - `<tag>_V.png`          the potential,
      - `<tag>_psi2_NNN.png`   one image per entry of `frames`,
      - `<tag>_montage.png`    all frame images tiled 5 per row,
      - `<tag>.mp4`            the frames as a movie (when `write_mp4`).

    Args:
        tag: condition name, used for the sub-directory and file prefixes.
        out_dir: parent output directory.
        XY: `(X, Y)` coordinate meshes (either backend).
        V: potential on the mesh.
        frames: sequence of |psi|^2 arrays.
        vmax: shared colour-scale maximum; when None, the 99.5th percentile of
            the first frame is used (floored at 1e-8).
        write_mp4: whether to encode the movie.
        fps: movie frame rate.
    """
    X, Y = XY
    cond = os.path.join(out_dir, tag); ensure_dir(cond)
    Xc, Yc = to_cpu(X), to_cpu(Y)
    extent = [float(Xc.min()), float(Xc.max()), float(Yc.min()), float(Yc.max())]

    # save potential (PNG) at fixed size
    v_rgb = _render_panel(to_cpu(V), extent, f"Potential: {tag}")
    iio.imwrite(os.path.join(cond, f"{tag}_V.png"), v_rgb)

    # frames -> PNGs (uint8) with fixed pixel size
    v_local = vmax
    if v_local is None:
        v_local = max(np.percentile(to_cpu(frames[0]), 99.5), 1e-8)

    imgs = []        # PNG paths, in frame order
    frames_u8 = []   # the same frames kept in memory for the movie
    for i, f in enumerate(frames):
        f_np = to_cpu(f).astype(np.float32)
        rgb = _plot_density_frame(tag, f_np, extent, i, v_local)
        fp = os.path.join(cond, f"{tag}_psi2_{i:03d}.png")
        iio.imwrite(fp, rgb)  # already uint8 RGB
        imgs.append(fp)
        frames_u8.append(rgb)

    # montage still (optional)
    montage(imgs, ncols=5, out_path=os.path.join(cond, f"{tag}_montage.png"), dpi=120)

    # MP4 straight from the rendered frames; PNG is lossless, so re-reading
    # the files from disk would give the same pixels.
    if write_mp4:
        mp4_path  = os.path.join(cond, f"{tag}.mp4")
        _write_mp4(frames_u8, mp4_path, fps=fps, macro=16)
        print(f"[{tag}] wrote {mp4_path}")

# ---------- combined 3x3 grid movie ----------
def make_grid_movie(out_dir, condition_tags, grid_shape=(3,3), fps=12, out_name="Fig2_grid.mp4"):
    """Tile the per-condition frame PNGs into one synchronized grid movie.

    For each tag, `<out_dir>/<tag>/<tag>_psi2_*.png` is read in sorted order.
    Tags fill the grid row-major. The movie has as many frames as the
    condition with the fewest PNGs and is written to `<out_dir>/<out_name>`.

    Raises:
        AssertionError: if the number of tags differs from rows * cols.
        RuntimeError: if a condition has no frame PNGs.
    """
    nrows, ncols = grid_shape
    assert len(condition_tags) == nrows*ncols, "grid size must match number of conditions"

    # Collect every condition's frame list and track the shortest one.
    per_tag_files = []
    shortest = None
    for tag in condition_tags:
        tag_dir = os.path.join(out_dir, tag)
        pngs = sorted(glob.glob(os.path.join(tag_dir, f"{tag}_psi2_*.png")))
        if not pngs:
            raise RuntimeError(f"No frames found for condition {tag} in {tag_dir}")
        per_tag_files.append(pngs)
        shortest = len(pngs) if shortest is None else min(shortest, len(pngs))

    # Every panel is assumed to have the size of the first PNG.
    tile_h, tile_w, _ = iio.imread(per_tag_files[0][0]).shape
    canvas_w, canvas_h = tile_w*ncols, tile_h*nrows

    out_path = os.path.join(out_dir, out_name)
    writer_opts = dict(fps=fps, codec='libx264',
                       macro_block_size=16, pixelformat='yuv420p')
    with iio.get_writer(out_path, **writer_opts) as writer:
        for i in range(shortest):
            # Build each grid row left-to-right, then stack the rows.
            rows = [
                np.concatenate(
                    [iio.imread(per_tag_files[r*ncols + c][i]) for c in range(ncols)],
                    axis=1,
                )
                for r in range(nrows)
            ]
            canvas = np.concatenate(rows, axis=0)
            # Safety net: force the exact canvas size if it somehow differs.
            if canvas.shape[1] != canvas_w or canvas.shape[0] != canvas_h:
                canvas = _resize_to(canvas, canvas_w, canvas_h)
            writer.append_data(canvas)

    print(f"[grid] wrote {out_path}  ({nrows}x{ncols}, frames={shortest}, size={canvas_w}x{canvas_h})")

# ---------- RUN ----------
# Driver: runs every condition below, writes per-condition PNG/MP4 outputs
# and one metrics row each, then builds the combined grid movie.
out_dir = "gqr_HCIBE_fig2"
ensure_dir(out_dir)

grid = Grid(nx=256, ny=256, Lx=20.0, Ly=20.0, hbar=1.0, m=1.0)

T_final    = 20.0
dt         = 0.01
save_every = 100  # every 1.0 time unit

# Parameters forwarded to the potential functions. Every condition shares
# them; only the enable flags differ between conditions.
A_params = dict(A_height=8.0, A_width=1.2)
H_params = dict(H_depth=4.0, H_width=3.0, aperture_y=0.0, aperture_x0=-1.0)
C_params = dict(C_strength=0.6, theta_deg=0.0)
B_params = dict(B_strength=0.35)
I_params = dict(I_depth=1.4, r0=3.0)
E_gamma  = 0.03

# Exactly nine conditions for a 3x3 grid
# Each entry is (tag, flags); a missing flag means that gate is disabled.
# The A potential is always present (make_V_func adds it unconditionally).
conditions = [
    ('A_H',     dict(H=True)),
    ('A_C',     dict(C=True)),
    ('A_I',     dict(I=True)),
    ('A_B',     dict(B=True)),
    ('A_HE',    dict(H=True, E=True)),
    ('A_HC',    dict(H=True, C=True)),
    ('A_HCI',   dict(H=True, C=True, I=True)),
    ('A_HCB',   dict(H=True, C=True, B=True)),
    ('A_HCIBE', dict(H=True, C=True, I=True, B=True, E=True)),
]

csv_path = os.path.join(out_dir, "metrics_summary.csv")
# The CSV stays open for the whole sweep; one row is written per condition.
with open(csv_path, "w", newline="") as fcsv:
    writer = csv.writer(fcsv)
    writer.writerow(["condition","aperture_fraction","integrated_flux","ema_density"])

    for tag, flags in conditions:
        print(f"[run] {tag}")
        V_func, dephaser = make_V_func(
            enable_H=flags.get('H',False),
            enable_C=flags.get('C',False),
            enable_I=flags.get('I',False),
            enable_B=flags.get('B',False),
            enable_E=flags.get('E',False),
            A_kwargs=A_params, H_kwargs=H_params, C_kwargs=C_params,
            I_kwargs=I_params, B_kwargs=B_params, E_gamma=E_gamma
        )
        XY, V, frames, psi = propagate_tdse(
            grid, V_func, T_final=T_final, dt=dt, dephaser=dephaser, save_every=save_every
        )
        # Metrics are taken over the region x > 2.0, |y| < 1.0.
        mask = aperture_mask(*XY, x_gate=2.0, y_halfwidth=1.0)
        metrics = compute_metrics(frames, mask, threshold=5e-4)
        writer.writerow([tag, metrics["aperture_fraction"], metrics["integrated_flux"], metrics["ema_density"]])
        # vmax=None: each condition picks its own colour scale from its first frame.
        save_condition_outputs(tag, out_dir, XY, V, frames, vmax=None, write_mp4=True, fps=12)

print("Done writing individual movies.")
print("Metrics CSV:", csv_path)

# ---- Make the combined 3x3 grid movie (synchronized) ----
grid_tags = [t for t,_ in conditions]  # 9 tags = 3x3
make_grid_movie(out_dir, grid_tags, grid_shape=(3,3), fps=12, out_name="Fig2_grid.mp4")

print("All done. Outputs in:", out_dir)

[backend] NumPy (CPU)
[run] A_H
[A_H] wrote gqr_HCIBE_fig2/A_H/A_H.mp4
[run] A_C
[A_C] wrote gqr_HCIBE_fig2/A_C/A_C.mp4
[run] A_I
[A_I] wrote gqr_HCIBE_fig2/A_I/A_I.mp4
[run] A_B
[A_B] wrote gqr_HCIBE_fig2/A_B/A_B.mp4
[run] A_HE
[A_HE] wrote gqr_HCIBE_fig2/A_HE/A_HE.mp4
[run] A_HC
[A_HC] wrote gqr_HCIBE_fig2/A_HC/A_HC.mp4
[run] A_HCI
[A_HCI] wrote gqr_HCIBE_fig2/A_HCI/A_HCI.mp4
[run] A_HCB
[A_HCB] wrote gqr_HCIBE_fig2/A_HCB/A_HCB.mp4
[run] A_HCIBE
[A_HCIBE] wrote gqr_HCIBE_fig2/A_HCIBE/A_HCIBE.mp4
Done writing individual movies.
Metrics CSV: gqr_HCIBE_fig2/metrics_summary.csv
[grid] wrote gqr_HCIBE_fig2/Fig2_grid.mp4  (3x3, frames=21, size=2304x1536)
All done. Outputs in: gqr_HCIBE_fig2


In [None]:
# === Extract Fig.2, frame 12 (zero-based) as PNG ===
# - First tries to read from the composite MP4: gqr_fig2_hires/Fig2_grid_luxury.mp4
# - If not present, rebuilds the same 3x3 layout from per-condition PNGs at frame 012.
# - Writes: gqr_fig2_hires/Fig2_frame012_grid.png

import os, glob
import numpy as np
from PIL import Image, ImageOps, ImageDraw, ImageFont
import imageio.v2 as iio

# NOTE(review): the sweep cell earlier in this notebook writes its outputs to
# "gqr_HCIBE_fig2", not to this ROOT. Unless the frames were copied here, both
# the MP4 lookup and the per-condition PNG lookup below find nothing.
ROOT = "gqr_fig2_hires"
COMPOSITE_MP4 = os.path.join(ROOT, "Fig2_grid_luxury.mp4")
OUT_PNG = os.path.join(ROOT, "Fig2_frame012_grid.png")
FRAME_INDEX = 12  # zero-based

# Tags in the same order as the 3x3 composite you ran
# (row-major: the first three tags form the top row).
TAGS = [
    "A_H",
    "A_C",
    "A_I",
    "A_B",
    "A_HE",
    "A_HC",
    "A_HCI",
    "A_HCB",
    "A_HCIBE",
]

# Tile size you used in the luxury composite (keep consistent for pixel-perfect match)
# Source PNGs are resized to this size when the grid is rebuilt from tiles.
TILE_W, TILE_H = 1280, 960
GRID_COLS, GRID_ROWS = 3, 3

def save_png(arr_uint8, path):
    """Write an H x W x 3 uint8 RGB array to `path` as an optimized PNG and log it."""
    Image.fromarray(arr_uint8, mode="RGB").save(path, format="PNG", optimize=True)
    print(f"[write] {path}")

def try_extract_from_mp4():
    """Try to save frame FRAME_INDEX of COMPOSITE_MP4 to OUT_PNG.

    Returns:
        True if the frame was extracted and written; False if the movie is
        missing, too short, unreadable, or any other error occurred (a
        warning is printed in those cases, except for a missing file). The
        caller uses False as the signal to rebuild the image from tiles.
    """
    if not os.path.exists(COMPOSITE_MP4):
        return False
    try:
        rdr = iio.get_reader(COMPOSITE_MP4)
        try:
            n = rdr.count_frames() if hasattr(rdr, "count_frames") else None
            if n is not None and FRAME_INDEX >= n:
                print(f"[warn] Composite has only {n} frames; cannot get index {FRAME_INDEX}.")
                return False

            # If random access isn't supported, iterate
            try:
                frame = rdr.get_data(FRAME_INDEX)
            except Exception:
                frame = None
                for i, fr in enumerate(rdr):
                    if i == FRAME_INDEX:
                        frame = fr
                        break
        finally:
            # Close the reader on every path, including when reading raises,
            # so the underlying decoder is not left open.
            rdr.close()

        if frame is None:
            print("[warn] Could not access the requested frame in MP4.")
            return False

        # Save directly: drop any alpha channel, or replicate a grayscale
        # frame into three channels.
        frame_rgb = frame[..., :3] if frame.ndim == 3 else np.stack([frame]*3, axis=-1)
        save_png(frame_rgb.astype(np.uint8), OUT_PNG)
        return True
    except Exception as e:
        # Deliberately best-effort: any failure falls back to the tile rebuild.
        print(f"[warn] MP4 extraction failed: {e}")
        return False

def rebuild_from_tiles():
    """Compose the grid image for FRAME_INDEX from per-condition PNGs under ROOT.

    For each tag in TAGS, `<ROOT>/<tag>/<tag>_psi2_<FRAME_INDEX>.png` is used
    when present; otherwise the available frame whose index is closest to
    FRAME_INDEX is used. Tiles are resized to TILE_W x TILE_H, pasted
    row-major onto a black canvas, and the result is written to OUT_PNG.

    Raises:
        RuntimeError: if a tag has no frame PNGs at all.
    """
    tiles = []
    for tag in TAGS:
        tag_dir = os.path.join(ROOT, tag)
        wanted = os.path.join(tag_dir, f"{tag}_psi2_{FRAME_INDEX:03d}.png")
        available = sorted(glob.glob(os.path.join(tag_dir, f"{tag}_psi2_*.png")))

        if os.path.exists(wanted):
            chosen = wanted
        elif not available:
            raise RuntimeError(f"No frames found for {tag}")
        else:
            # Frame index is the number between the last "_" and the extension.
            frame_ids = [int(os.path.basename(c).split("_")[-1].split(".")[0]) for c in available]
            nearest = int(np.argmin(np.abs(np.array(frame_ids) - FRAME_INDEX)))
            chosen = available[nearest]
            print(f"[info] {tag}: exact frame not found, using {os.path.basename(chosen)}")

        tile = Image.open(chosen).convert("RGB").resize((TILE_W, TILE_H), Image.BILINEAR)
        tiles.append(tile)

    # Paste tiles row-major onto a black sheet.
    sheet = Image.new("RGB", (TILE_W * GRID_COLS, TILE_H * GRID_ROWS), (0,0,0))
    for i, tile in enumerate(tiles):
        row, col = divmod(i, GRID_COLS)
        sheet.paste(tile, (col*TILE_W, row*TILE_H))

    save_png(np.asarray(sheet, dtype=np.uint8), OUT_PNG)

# Try MP4 first; rebuild if needed
# try_extract_from_mp4() returns False on any failure, so the fallback also
# covers a missing or unreadable movie. rebuild_from_tiles() raises if a
# condition has no frame PNGs under ROOT.
if try_extract_from_mp4():
    print("[ok] Extracted frame from composite MP4.")
else:
    print("[fallback] Building the 3x3 frame from per-condition PNGs…")
    rebuild_from_tiles()
    print("[ok] Rebuilt from tiles.")

[fallback] Building the 3x3 frame from per-condition PNGs…


RuntimeError: No frames found for A_H

In [None]:
# === Build Fig.2 (frame 12) from whatever outputs you already have ===
# - Searches recursively under likely roots for per-condition PNGs
# - Tags order matches the 3×3 used earlier
# - Saves: <OUT_ROOT>/Fig2_frame012_grid.png  (labels burned in)

import os, sys, glob, re
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# ---- config you can tweak ----
FRAME_INDEX = 12  # zero-based
# Grid order is row-major: the first GRID_COLS tags form the top row.
TAGS = [
    "A_H",
    "A_C",
    "A_I",
    "A_B",
    "A_HE",
    "A_HC",
    "A_HCI",
    "A_HCB",
    "A_HCIBE",
]

# candidate roots to search (add/remove as needed)
# Every root is searched recursively. The broad entries ("/content", ".")
# overlap with the specific ones, so the same file can be found via several roots.
CANDIDATE_ROOTS = [
    "/content/gqr_fig2_hires",
    "/content/gqr_HCIBE_fig2",
    "/content/gqr_deepfast_fixed",
    "/content/gqr_deepfast",
    "/content",
    ".",
]

# output root (will be created if missing)
OUT_ROOT = "/content/gqr_fig2_hires"
os.makedirs(OUT_ROOT, exist_ok=True)
OUT_PNG = os.path.join(OUT_ROOT, "Fig2_frame012_grid.png")

# tile and grid geometry (match your composite look)
# Each source PNG is resized to TILE_W x TILE_H before being pasted.
TILE_W, TILE_H = 1280, 960
GRID_COLS, GRID_ROWS = 3, 3

# -------------- helpers --------------
def list_pngs_for_tag(tag, roots):
    """Return the frame PNGs for `tag` found anywhere under `roots`, in frame order.

    Files laid out as `.../<tag>/<tag>_psi2_NNN.png` are preferred; only if
    none exist under any root is a bare `<tag>_psi2_NNN.png` accepted in any
    directory.

    The same file reached through overlapping roots (e.g. "/content" and ".")
    is reported once, using the spelling from the earliest root. Results are
    sorted by frame number, then by path, so the order is deterministic;
    names without a parseable frame number sort last.

    Args:
        tag: condition tag, e.g. "A_HC".
        roots: iterable of directories to search recursively.

    Returns:
        List of paths (possibly empty).
    """
    # {3,}: the writer zero-pads to three digits, so frames >= 1000 have more.
    frame_re = re.compile(r"_psi2_(\d{3,})\.png$")

    def frame_num(p):
        m = frame_re.search(os.path.basename(p))
        return int(m.group(1)) if m else 10**9

    def collect(*tail):
        """Glob `<root>/**/<tail...>` under every root, in root order."""
        found = []
        for root in roots:
            found.extend(glob.glob(os.path.join(root, "**", *tail), recursive=True))
        return found

    name_glob = f"{tag}_psi2_*.png"
    # exact layout first, then fall back to any matching file name
    candidates = collect(tag, name_glob) or collect(name_glob)

    # De-duplicate by resolved location, keeping the first spelling seen.
    unique = {}
    for p in candidates:
        unique.setdefault(os.path.realpath(p), p)

    return sorted(unique.values(), key=lambda p: (frame_num(p), p))

def pick_frame_path(candidates, idx):
    """Choose the file for frame `idx` from `candidates`.

    Returns `(path, frame_number, is_exact)`:
      - the first candidate named `..._psi2_<idx>.png` (idx zero-padded to
        three digits), with `is_exact=True`;
      - otherwise the candidate whose three-digit frame number is nearest to
        `idx` (ties go to the lower frame number), with `is_exact=False`;
      - `(None, None, False)` when no candidate has a parseable frame number.
    """
    exact_re = re.compile(fr"_psi2_{idx:03d}\.png$")
    hit = next((p for p in candidates if exact_re.search(os.path.basename(p))), None)
    if hit is not None:
        return hit, idx, True

    # No exact match: rank parseable candidates by (distance, frame, path).
    any_frame_re = re.compile(r"_psi2_(\d{3})\.png$")
    ranked = []
    for p in candidates:
        m = any_frame_re.search(os.path.basename(p))
        if m:
            n = int(m.group(1))
            ranked.append((abs(n - idx), n, p))
    if not ranked:
        return None, None, False

    _, n_closest, best = min(ranked)
    return best, n_closest, False

def load_and_resize(path, w, h):
    """Open the image at `path`, convert it to RGB and resize it to `w` x `h`
    pixels with bilinear resampling."""
    rgb = Image.open(path).convert("RGB")
    return rgb.resize((w, h), Image.BILINEAR)

def draw_label(img, text, pad=10):
    """Burn `text` into the top-left corner of `img` (modified in place).

    The label is white with a one-pixel black drop shadow, offset `pad`
    pixels from the corner. DejaVu Sans at 36 pt is used when it can be
    loaded; otherwise PIL's built-in default font is used.
    """
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 36)
    except Exception:
        font = ImageFont.load_default()

    pen = ImageDraw.Draw(img)
    # Shadow first (shifted by 1 px), then the foreground text on top.
    for shift, color in ((1, (0,0,0)), (0, (255,255,255))):
        pen.text((pad+shift, pad+shift), text, fill=color, font=font)

# -------------- build the grid --------------
# For every tag: locate its frame PNGs, pick FRAME_INDEX (or the nearest
# available frame), resize, label, and remember which file was used.
tiles = []
report = []

for tag in TAGS:
    cands = list_pngs_for_tag(tag, CANDIDATE_ROOTS)
    if not cands:
        raise RuntimeError(f"No frames found anywhere for tag '{tag}'. "
                           f"Checked roots: {', '.join(CANDIDATE_ROOTS)}")
    p, used_idx, exact = pick_frame_path(cands, FRAME_INDEX)
    if p is None:
        raise RuntimeError(f"No parseable *_psi2_###.png files for tag '{tag}'. "
                           f"Found {len(cands)} unparseable candidates.")
    im = load_and_resize(p, TILE_W, TILE_H)
    # A trailing " ~" in the label marks a tile that is not the exact frame.
    draw_label(im, f"{tag}  (frame {used_idx:03d}{'' if exact else ' ~'})")
    tiles.append(im)
    # Report paths relative to /content when that directory exists.
    report.append(f"{tag:8s} -> {os.path.relpath(p, start='/content') if os.path.exists('/content') else p}")

# compose 3×3
# Tiles are pasted row-major onto a black canvas.
CANVAS_W, CANVAS_H = TILE_W*GRID_COLS, TILE_H*GRID_ROWS
canvas = Image.new("RGB", (CANVAS_W, CANVAS_H), (0,0,0))
for i, im in enumerate(tiles):
    r, c = divmod(i, GRID_COLS)
    canvas.paste(im, (c*TILE_W, r*TILE_H))

canvas.save(OUT_PNG, format="PNG", optimize=True)
print(f"[write] {OUT_PNG}")

# List the source file chosen for each tile.
print("\n[resolved inputs]")
for line in report:
    print("  ", line)

[write] /content/gqr_fig2_hires/Fig2_frame012_grid.png

[resolved inputs]
   A_H      -> gqr_HCIBE_fig2/A_H/A_H_psi2_012.png
   A_C      -> gqr_HCIBE_fig2/A_C/A_C_psi2_012.png
   A_I      -> gqr_HCIBE_fig2/A_I/A_I_psi2_012.png
   A_B      -> gqr_HCIBE_fig2/A_B/A_B_psi2_012.png
   A_HE     -> gqr_HCIBE_fig2/A_HE/A_HE_psi2_012.png
   A_HC     -> gqr_HCIBE_fig2/A_HC/A_HC_psi2_012.png
   A_HCI    -> gqr_HCIBE_fig2/A_HCI/A_HCI_psi2_012.png
   A_HCB    -> gqr_HCIBE_fig2/A_HCB/A_HCB_psi2_012.png
   A_HCIBE  -> gqr_HCIBE_fig2/A_HCIBE/A_HCIBE_psi2_012.png
