In [25]:
%pip install torch torchvision timm pillow numpy

Note: you may need to restart the kernel to use updated packages.


In [14]:
import warnings
warnings.filterwarnings("ignore")
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)

In [26]:
import sys 
from pathlib import Path

# Get the absolute path to the base directory (one level up from notebook dir)
BASE_DIR = Path(__file__).resolve().parent.parent if "__file__" in globals() else Path.cwd().parent

# Add to sys.path if not already present
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))

# Now you can import directly
from zeromodel import ZeroModel

In [16]:
import math, time, functools, numpy as np
from collections import deque
from PIL import Image, ImageDraw
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchvision.datasets as dsets
import timm

## CIFAR-10 dataloaders (fixed eval batch)

## ViT model + hooks to capture attention & token embeddings

In [27]:
# notebook_emitter.py
import os, json, time, numpy as np

class ZMFileEmitter:
    def __init__(self, run_dir="artifacts/zm_run1", fps=8):
        self.run_dir = run_dir
        os.makedirs(f"{run_dir}/frames", exist_ok=True)
        self.meta_f = open(f"{run_dir}/meta.jsonl", "a", buffering=1)
        # announce run (ZeroModel tailer watches for run.json)
        with open(f"{run_dir}/run.json","w") as f:
            json.dump({"status":"open","fps":fps}, f)

    def send(self, step:int, frame_hwC:np.ndarray, tags:dict=None):
        # frame: float32 [H,W,C] in [0,1]
        path = f"{self.run_dir}/frames/{step:06d}.npy"
        np.save(path, frame_hwC.astype("float32"))
        rec = {"step": step, "path": path, "tags": tags or {}, "ts": time.time()}
        self.meta_f.write(json.dumps(rec) + "\n")

    def close(self):
        self.meta_f.flush(); self.meta_f.close()
        with open(f"{self.run_dir}/run.json","w") as f:
            json.dump({"status":"closed"}, f)


In [28]:
# =========================
# 0) Imports & Utilities
# =========================
import os, json, math, time, numpy as np
from collections import deque
from contextlib import contextmanager
from PIL import Image, ImageDraw

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchvision.datasets as dsets
import timm
from timm.models.vision_transformer import Attention

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# --- Simple local GIF logger (optional preview) ---
class GifLogger:
    def __init__(self, path="cifar_vit.gif", fps=8):
        self.frames=[]; self.path=path; self.duration=int(1000/fps)
    def add(self, pil_img, text=None):
        im = pil_img.convert("P", palette=Image.ADAPTIVE, colors=256)
        if text:
            draw = ImageDraw.Draw(im); draw.text((4,4), text, fill=255)
        self.frames.append(im)
    def save(self):
        if self.frames:
            os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
            self.frames[0].save(self.path, save_all=True, append_images=self.frames[1:],
                                optimize=True, duration=self.duration, loop=0)

# --- ZeroModel file-drop emitter (frames + metadata) ---
class ZMFileEmitter:
    def __init__(self, run_dir="artifacts/zm_run1", fps=8):
        self.run_dir = run_dir
        os.makedirs(f"{run_dir}/frames", exist_ok=True)
        self.meta_path = f"{run_dir}/meta.jsonl"
        self.meta_f = open(self.meta_path, "a", buffering=1)
        with open(f"{run_dir}/run.json","w") as f:
            json.dump({"status":"open","fps":fps}, f)
    def send(self, step:int, frame_hwC:np.ndarray, tags:dict=None):
        path = f"{self.run_dir}/frames/{step:06d}.npy"
        np.save(path, frame_hwC.astype("float32"))
        rec = {"step": step, "path": path, "tags": tags or {}, "ts": time.time()}
        self.meta_f.write(json.dumps(rec) + "\n")
    def close(self):
        self.meta_f.flush(); self.meta_f.close()
        with open(f"{self.run_dir}/run.json","w") as f:
            json.dump({"status":"closed"}, f)

# --- image helper ---
def to_img(m, H=128, W=128):
    m = m.astype(np.float32)
    m = (m - m.min()) / (m.max() - m.min() + 1e-8)
    im = Image.fromarray((m*255).astype(np.uint8))
    return im.resize((W, H), resample=Image.BILINEAR)

def tile(imgs, cols=3):
    w, h = imgs[0].size
    rows = (len(imgs)+cols-1)//cols
    canvas = Image.new("L", (cols*w, rows*h))
    for i, im in enumerate(imgs):
        r, c = divmod(i, cols)
        canvas.paste(im, (c*w, r*h))
    return canvas


Device: cuda


In [29]:
def token_gram(tokens):  # tokens: (B,N,C)
    # use fixed eval batch → take mean over batch
    t = tokens.mean(dim=0)         # (N,C)
    t = F.normalize(t, dim=1)
    G = t @ t.T                    # (N,N), values in [-1,1]
    return G

def to_img(m, H=128, W=128):
    m = m.astype(np.float32)
    m = (m - m.min()) / (m.max() - m.min() + 1e-8)
    im = Image.fromarray((m*255).astype(np.uint8))
    return im.resize((W, H), resample=Image.BILINEAR)

def tile(imgs, cols=2):
    if not imgs: return Image.new("L",(1,1))
    w,h = imgs[0].size
    rows = (len(imgs)+cols-1)//cols
    canvas = Image.new("L",(cols*w, rows*h))
    for i,im in enumerate(imgs):
        r,c = divmod(i,cols)
        canvas.paste(im,(c*w, r*h))
    return canvas

class GifLogger:
    def __init__(self, path="cifar_vit.gif", fps=8):
        self.frames=[]; self.path=path; self.duration=int(1000/fps)
    def add(self, pil_img, text=None):
        im = pil_img.convert("P", palette=Image.ADAPTIVE, colors=256)
        if text:
            draw = ImageDraw.Draw(im); draw.text((4,4), text, fill=255)
        self.frames.append(im)
    def save(self):
        if self.frames:
            self.frames[0].save(self.path, save_all=True, append_images=self.frames[1:],
                                optimize=True, duration=self.duration, loop=0)


In [30]:
# =========================
# 1) Data: CIFAR-10 @ 224
# =========================
BATCH = 128

tfm_train = T.Compose([
    T.Resize(224),
    T.RandomCrop(224, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor()
])
tfm_eval  = T.Compose([
    T.Resize(224),
    T.ToTensor()
])

train_ds = dsets.CIFAR10(root="./data", train=True,  download=True, transform=tfm_train)
test_ds  = dsets.CIFAR10(root="./data", train=False, download=True, transform=tfm_eval)

train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=2, drop_last=True)
eval_dl  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=2, drop_last=True)

fixed_eval_imgs, fixed_eval_labels = next(iter(eval_dl))
print("Fixed eval batch:", fixed_eval_imgs.shape)


Fixed eval batch: torch.Size([128, 3, 224, 224])


In [31]:
# =========================
# 2) Model: ViT tiny 224
# =========================
model = timm.create_model("vit_tiny_patch16_224", pretrained=False, num_classes=10)
model = model.to(DEVICE)
print("Grid size (patch grid):", model.patch_embed.grid_size)  # (14, 14)
GRID = int(model.patch_embed.grid_size[0])  # should be 14
assert GRID == model.patch_embed.grid_size[1]


Grid size (patch grid): (14, 14)


In [32]:
# --- Drop-in replacement: attention wrapper WITH device/dtype sync & attn_mask support ---
import torch
import torch.nn as nn
from contextlib import contextmanager
from timm.models.vision_transformer import Attention

class AttentionWithWeights(Attention):
    """Timm-compatible attention that records probs as (B, Nq, H, Nk) in self.last_attn."""
    def forward(self, x, attn_mask=None):
        B, N, C = x.shape
        # qkv: (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each: (B, heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        if attn_mask is not None:
            attn = attn + attn_mask  # broadcast if provided
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Store as (B, Nq, H, Nk) to match cls_to_patch_attn / mean_patch_to_patch
        self.last_attn = attn.permute(0, 2, 1, 3).detach()  # (B, N, heads, N)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

@contextmanager
def patch_vit_attn(vit_model):
    """
    Temporarily replace all blocks[i].attn with AttentionWithWeights,
    copying weights & moving to the SAME device/dtype as the original.
    """
    originals = []
    try:
        for blk in vit_model.blocks:
            a: Attention = blk.attn
            w = AttentionWithWeights(
                dim=a.qkv.in_features,
                num_heads=a.num_heads,
                qkv_bias=getattr(a, 'qkv_bias', True),
                qk_norm=getattr(a, 'qk_norm', False),
                attn_drop=getattr(a.attn_drop, 'p', 0.0),
                proj_drop=getattr(a.proj_drop, 'p', 0.0),
                norm_layer=getattr(a, 'norm_layer', None),
                attn_head_dim=getattr(a, 'attn_head_dim', None),
            )
            # copy state
            w.qkv.load_state_dict(a.qkv.state_dict())
            w.proj.load_state_dict(a.proj.state_dict())
            if hasattr(a, 'q_norm') and hasattr(w, 'q_norm'):
                w.q_norm.load_state_dict(a.q_norm.state_dict())
            if hasattr(a, 'k_norm') and hasattr(w, 'k_norm'):
                w.k_norm.load_state_dict(a.k_norm.state_dict())

            # >>> MOVE to SAME device & dtype as original <<<
            dev = next(a.parameters()).device
            dtype = next(a.parameters()).dtype
            w.to(device=dev, dtype=dtype)

            originals.append(a)
            blk.attn = w
        yield
    finally:
        # restore originals
        for blk, a in zip(vit_model.blocks, originals):
            blk.attn = a


In [None]:
# =========================
# 4) Log a structured frame
# =========================
def build_frame_from_model(step, fixed_imgs, H=128, W=128):
    model.eval()
    with torch.no_grad(), patch_vit_attn(model):
        _ = model(fixed_imgs.to(DEVICE))  # runs with capturing attention

        # choose a mid block for visualization (or last if fewer blocks)
        blk = model.blocks[3] if len(model.blocks) > 3 else model.blocks[-1]
        A = blk.attn.last_attn.detach().cpu()  # (B, Nq, H, Nk)

        # Panels:
        panelA = np.array(to_img(I don't deal with this like I just I don't deal with this like I just (A, GRID)[0].numpy(), H, W)) / 255.0
        panelB = np.array(to_img(mean_patch_to_patch(A)[0].numpy(), H, W)) / 255.0

        # ensure token hook fired
        if "tokens" not in token_outputs:
            _ = model(fixed_imgs.to(DEVICE))
        toks = token_outputs["tokens"]                     # (B, N, C)
        panelC = np.array(to_img(token_gram_from_tokens(toks).cpu().numpy(), H, W)) / 255.0

        # stack into HxWx3 frame for ZeroModel, and a grayscale tile grid for preview
        frame = np.stack([panelA, panelB, panelC], axis=-1).astype("float32")  # HxWx3
        preview = tile([to_img(panelA), to_img(panelB), to_img(panelC)], cols=3)
    model.train()
    return frame, preview


: 

In [24]:
# =========================
# 5) Train + Stream to ZeroModel
# =========================
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
loss_fn = nn.CrossEntropyLoss()

gif = GifLogger("artifacts/cifar_vit_preview.gif", fps=8)   # local preview gif
em  = ZMFileEmitter("artifacts/zm_run1", fps=8)             # ZeroModel stream

LOG_EVERY = 25
STEPS = 500

model.train()
it = iter(train_dl)

for step in range(1, STEPS+1):
    try:
        x, y = next(it)
    except StopIteration:
        it = iter(train_dl); x, y = next(it)
    x, y = x.to(DEVICE), y.to(DEVICE)

    opt.zero_grad()
    logits = model(x)
    loss = loss_fn(logits, y)
    loss.backward()
    opt.step()

    if step % LOG_EVERY == 0:
        frame_hwC, panel_preview = build_frame_from_model(step, fixed_eval_imgs)
        # 1) stream to ZeroModel (H x W x C floats in [0,1])
        em.send(step, frame_hwC, tags={"loss": float(loss.item())})
        # 2) add a tile to local preview GIF
        gif.add(panel_preview, text=f"step {step}  loss {loss.item():.3f}")

em.close()
gif.save()
print("Done. Preview GIF:", "artifacts/cifar_vit_preview.gif")
print("ZeroModel run dir:", "artifacts/zm_run1")


RuntimeError: Expected all tensors to be on the same device, but got mat1 is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA_addmm)