In [None]:
# ---- 1) Token hook (idempotent: safe to re-run this cell) ----
# Latest output of model.norm, written by token_hook on every forward pass.
token_outputs = {}


def token_hook(module, inp, out):
    """Forward hook that stores a detached CPU copy of ``out`` in ``token_outputs``.

    Args:
        module: the hooked module (unused).
        inp: the module's inputs (unused).
        out: the module's output tensor; stored under ``token_outputs["tokens"]``.
    """
    token_outputs["tokens"] = out.detach().cpu()  # (B, N, C)


# Keep the handle so re-running this cell replaces the hook instead of stacking
# another one on model.norm (each extra copy would repeat the detach().cpu()).
_prev_token_hook = globals().get("_token_hook_handle")
if _prev_token_hook is not None:
    _prev_token_hook.remove()
_token_hook_handle = model.norm.register_forward_hook(token_hook)

# ---- 2) Utilities ----
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

# Side length of the patch grid (14 for vit_tiny_patch16_224). The panel code
# below assumes a square GRID x GRID layout of patch tokens, so reject anything
# else here instead of silently mis-slicing tokens later.
_grid_size = tuple(int(g) for g in model.patch_embed.grid_size)
if len(set(_grid_size)) != 1:
    raise ValueError(
        f"expected a square patch grid, got model.patch_embed.grid_size={_grid_size}"
    )
GRID = _grid_size[0]

def _to_img(m, H=128, W=128):
    """Min-max scale a numpy array to 0..255 and return it as a resized PIL image.

    A (near-)constant input (max <= min + 1e-8) becomes all zeros rather than
    dividing by ~0. The 8-bit image is resized to ``W`` x ``H`` with bilinear
    resampling.
    """
    arr = m.astype(np.float32)
    lo, hi = arr.min(), arr.max()
    near_constant = hi <= lo + 1e-8
    scaled = (
        np.zeros_like(arr, dtype=np.float32)
        if near_constant
        else (arr - lo) / (hi - lo)
    )
    pixels = (scaled * 255).astype(np.uint8)
    return Image.fromarray(pixels).resize((W, H), resample=Image.BILINEAR)

def _tile(imgs, cols=3):
    """Paste PIL images row-major onto one new grayscale ("L") canvas.

    ``cols`` images go in each row. Every cell is sized like the first image,
    so the inputs are presumably all the same size. Returns the canvas.
    """
    cell_w, cell_h = imgs[0].size
    n_rows = (len(imgs) + cols - 1) // cols  # ceil(len(imgs) / cols)
    sheet = Image.new("L", (cols * cell_w, n_rows * cell_h))
    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        sheet.paste(tile, (col * cell_w, row * cell_h))
    return sheet

# ---- 3) Panel builders from tokens only (no attention) ----
def panels_from_tokens(tokens_B_N_C, grid, num_prefix_tokens=1):
    """Build three token-similarity panels from ViT tokens (no attention needed).

    Args:
        tokens_B_N_C: (B, N, C) tensor, e.g. captured by the model.norm forward
            hook. Index 0 is the CLS token. The first ``num_prefix_tokens``
            tokens are non-patch tokens; the next ``grid * grid`` are patches.
        grid: side length of the square patch grid.
        num_prefix_tokens: number of non-patch tokens before the patches.
            1 (the default) means a single CLS token; use a larger value for
            models that also carry register or distillation tokens.

    Returns:
        Three float32 numpy arrays, not yet min-max normalized (P = grid*grid):
          G_cos  : (P, P) cosine Gram of the batch-averaged patch tokens, in [-1, 1].
          R_corr : (P, P) Pearson correlation between patch tokens across the
                   C channels, approximately in [-1, 1].
          CLSmap : (grid, grid) cosine similarity of the batch-averaged CLS
                   token vs each patch token.

    Raises:
        ValueError: if the input is not 3-D, ``num_prefix_tokens`` < 1, or
            there are fewer than ``num_prefix_tokens + grid*grid`` tokens.
    """
    if tokens_B_N_C.ndim != 3:
        raise ValueError(
            f"expected a (B, N, C) tensor, got shape {tuple(tokens_B_N_C.shape)}"
        )
    if num_prefix_tokens < 1:
        raise ValueError("num_prefix_tokens must be >= 1 (CLS sits at index 0)")
    N = tokens_B_N_C.shape[1]
    P = grid * grid
    if N < num_prefix_tokens + P:
        raise ValueError(
            f"need at least {num_prefix_tokens + P} tokens "
            f"({num_prefix_tokens} prefix + {grid}x{grid} patches), got {N}"
        )

    # Work in float32: half-precision models hand the hook fp16 tensors, for
    # which some CPU kernels are slow or missing, and callers expect float32.
    tokens = tokens_B_N_C.detach().float()

    # Patch tokens only, averaged over the batch to stabilize the panels.
    Tpatch = tokens[:, num_prefix_tokens:num_prefix_tokens + P, :]  # (B, P, C)
    Tmean = Tpatch.mean(dim=0)                                        # (P, C)

    # 1) Cosine Gram (P, P)
    Tn = F.normalize(Tmean, dim=1)                                    # (P, C)
    G = (Tn @ Tn.T).clamp(-1, 1)                                      # (P, P)

    # 2) Pearson correlation between tokens across C (P, P).
    #    Center each token (row) over its channels, divide by its unbiased std
    #    (torch default); Xn @ Xn.T / (C - 1) is then the correlation matrix.
    X = Tmean - Tmean.mean(dim=1, keepdim=True)                       # (P, C)
    std = X.std(dim=1, keepdim=True) + 1e-6
    Xn = X / std
    R = (Xn @ Xn.T) / (Xn.shape[1] - 1)                               # (P, P)

    # 3) CLS affinity map (grid, grid)
    CLS = tokens[:, 0:1, :].mean(dim=0)                               # (1, C)
    CLSn = F.normalize(CLS, dim=1)                                    # (1, C)
    sim = (CLSn @ Tn.T).squeeze(0)                                    # (P,)
    CLSmap = sim.view(grid, grid)                                     # (grid, grid)

    return G.cpu().numpy(), R.cpu().numpy(), CLSmap.cpu().numpy()

# ---- 4) Build a single HxWx3 frame + a preview tile ----
def build_frame_from_model(step, fixed_imgs, H=128, W=128):
    """Render the token-similarity panels for ``fixed_imgs`` as one frame.

    Runs a no-grad forward pass in eval mode so the ``model.norm`` token hook
    captures fresh tokens. It then turns them into the three panels from
    ``panels_from_tokens`` and rasterises each one to ``H`` x ``W`` pixels.
    The model's previous train/eval mode is restored afterwards.

    Args:
        step: unused; accepted so existing callers keep working.
        fixed_imgs: input batch, moved to ``DEVICE`` before the forward pass.
        H, W: height/width in pixels of every panel.

    Returns:
        frame: (H, W, 3) float32 array whose channels are (cosine Gram,
            Pearson correlation, CLS map), each min-max scaled to [0, 1].
        preview: PIL image with the three panels tiled side by side.

    Raises:
        RuntimeError: if the forward pass did not trigger the token hook.
    """
    was_training = model.training
    model.eval()
    # Drop tokens left by an earlier forward pass (e.g. a training step) so a
    # missing hook cannot silently hand back stale activations.
    token_outputs.pop("tokens", None)
    try:
        with torch.no_grad():
            model(fixed_imgs.to(DEVICE))  # fires token hook
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
    if "tokens" not in token_outputs:
        raise RuntimeError(
            "token hook did not fire; register token_hook on model.norm first"
        )
    toks = token_outputs["tokens"]  # (B, N, C)
    G, R, CLSmap = panels_from_tokens(toks, GRID)

    # Rasterise each panel once; reuse the images for both frame and preview.
    panel_imgs = [_to_img(panel, H, W) for panel in (G, R, CLSmap)]
    frame = np.stack(
        [np.array(im) / 255.0 for im in panel_imgs], axis=-1
    ).astype("float32")  # HxWx3
    preview = _tile(panel_imgs, cols=3)
    return frame, preview
