In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

use_mps = torch.backends.mps.is_available()
device_lm = torch.device("mps" if use_mps else "cpu")  # fallback CPU if needed

model_id = "Qwen/Qwen2.5-1.5B-Instruct"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# fp16 is the sweet spot on MPS; bf16 also works on newer PyTorch, but fp16 is safer today.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="sdpa",  # MPS-friendly attention
)
model.to(device_lm)
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

In [4]:
prompt = "Explain why the sky is blue in simple terms."
inp = tok(prompt, return_tensors="pt").to(device_lm)
out = model.generate(**inp, max_new_tokens=50)
print(tok.decode(out[0], skip_special_tokens=True))

Explain why the sky is blue in simple terms. The sky appears blue because of a phenomenon called Rayleigh scattering, which occurs when sunlight passes through Earth's atmosphere and interacts with the gases present in it. Blue light has shorter wavelengths than other colors of visible light, such as red or green. When


In [5]:
from datasets import load_dataset

# True streaming keeps RAM flat; you can also point this at your own text iterator.
ds = load_dataset("chrisociepa/wikipedia-pl-20230401", split="train", streaming=True)

max_len = 512    # start modest on MPS
stride  = 256

def token_blocks():
    for item in ds:
        ids = tok(item["text"], return_tensors="pt", truncation=False).input_ids[0]
        for s in range(0, len(ids), stride):
            chunk = ids[s:s+max_len]
            if len(chunk) > 1:
                yield chunk.unsqueeze(0)  # [1, T]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [6]:
from torch import Tensor

L = 12  # pick a mid layer first

act_buf: Tensor = None
def grab_hook(_m, _inp, out):
    # out: [B, T, d_model] on MPS; keep it fp16 and detached
    global act_buf
    act_buf = out.detach()

handle = model.model.layers[L].register_forward_hook(grab_hook)

In [7]:
import torch.nn as nn, torch.optim as optim

device_sae = device_lm  # run SAE on MPS too

class TopKSAE(nn.Module):
    def __init__(self, d_in, d_lat, k):
        super().__init__()
        self.enc = nn.Linear(d_in, d_lat, bias=True)
        self.dec = nn.Linear(d_lat, d_in, bias=True)
        self.dec.weight = nn.Parameter(self.enc.weight.T)  # tied weights
        self.k = k

    def forward(self, x):
        # x: [N, d_in] (fp16 on MPS)
        pre = self.enc(x)                             # [N, d_lat]
        vals, idx = torch.topk(pre, k=self.k, dim=-1)
        z = torch.zeros_like(pre)
        # ReLU only on selected values
        z.scatter_(dim=-1, index=idx, src=torch.relu(vals))
        xhat = self.dec(z)
        return xhat, z

In [8]:
from tqdm import tqdm

sae, opt = None, None
micro_bs = 4  # tiny microbatch fits MPS memory
accum = 8  # gradient accumulation
n_steps = 100
step = 0
k = 64  # active features/token
torch.set_float32_matmul_precision("medium")  # helps SDPA on Apple

pbar = tqdm(total=n_steps, desc="SAE training", initial=step)

for chunk in token_blocks():

    if step >= n_steps: break

    # ---- LM forward (cheap; no grad) ----
    with torch.no_grad():
        _ = model(input_ids=chunk.to(device_lm))  # fills act_buf via hook

    H = act_buf.squeeze(0)  # [T, d_model], still on MPS, fp16
    T, d = H.shape
    if sae is None:
        sae = TopKSAE(d_in=d, d_lat=8192, k=k).to(device_sae, dtype=torch.float16)
        opt = optim.AdamW(sae.parameters(), lr=3e-3)

    perm = torch.randperm(T, device=device_sae)
    running = 0
    for i in range(0, T, micro_bs):
        idx = perm[i:i + micro_bs]
        xb = H.index_select(0, idx)  # stays on MPS

        with torch.autocast(device_type="mps", dtype=torch.float16):
            xhat, z = sae(xb)
            recon = ((xhat - xb) ** 2).mean()

        (recon / accum).backward()
        running += 1
        if running % accum == 0:
            opt.step();
            opt.zero_grad(set_to_none=True)
            step += 1
            pbar.update(1)
            if step % 500 == 0:
                print(f"step {step}: recon={float(recon):.4f}")
            if step >= n_steps:
                break

handle.remove()
pbar.close()


SAE training: 100%|██████████| 100/100 [00:04<00:00, 21.23it/s]


In [9]:
from collections import defaultdict


@torch.no_grad()
def score_features(n_steps=3000, threshold=None):
    # Running sums; create lazily to stay lean
    sums, counts, maxes, fires = defaultdict(float), defaultdict(int), defaultdict(float), defaultdict(int)
    total_tokens = 0

    step = 0
    pbar = tqdm(total=n_steps, desc="Scoring features")
    for chunk in token_blocks():
        if step >= n_steps: break
        _ = model(input_ids=chunk.to(device_lm))  # fills act_buf via hook
        H = act_buf.squeeze(0)  # [T, d_model], fp16 on MPS
        T, d = H.shape

        # SAE encode (no reconstruction needed here)
        with torch.autocast(device_type="mps", dtype=torch.float16):
            _, Z = sae(H)  # [T, n_latent] (dense or sparse)

        # If your SAE is Top-K and returns dense z with zeros elsewhere, this is cheap.
        # Otherwise set a threshold τ to count "fires".
        if threshold is None:
            # auto τ as 90th percentile per-batch to keep firing rare:
            tau = torch.quantile(Z.detach().to(torch.float32).flatten(), 0.90).item()
        else:
            tau = threshold

        Z32 = Z.to(torch.float32)  # stable stats on fp32
        total_tokens += T

        # Reduce along tokens in small slices to avoid large temps
        bs = 2048
        for i in range(0, T, bs):
            z = Z32[i:i + bs]  # [b, n_latent]
            sums_batch = z.sum(0)  # [n_latent]
            max_batch = z.max(0).values
            fires_batch = (z > tau).sum(0)

            for j, val in enumerate(sums_batch):
                v = float(val);
                m = float(max_batch[j]);
                f = int(fires_batch[j])
                sums[j] += v
                counts[j] += z.shape[0]
                if m > maxes[j]: maxes[j] = m
                fires[j] += f

        step += 1
        pbar.update(1)

    pbar.close()

    # Aggregate stats
    stats = {}
    for k in sums.keys():
        stats[k] = {
            "mean": sums[k] / max(1, counts[k]),
            "max": maxes[k],
            "fire_rate": fires[k] / max(1, total_tokens)
        }
    return stats


stats = score_features(n_steps=3)  # ~quick skim


Scoring features: 100%|██████████| 3/3 [00:13<00:00,  4.64s/it]


In [15]:
import math

F = 200  # focus on 200 most promising features
# Sort by a combined score (max * sqrt(fire_rate))—balances salience & selectivity
top_feat = sorted(stats.keys(),
                  key=lambda k: stats[k]["max"] * math.sqrt(max(1e-8, stats[k]["fire_rate"])),
                  reverse=True)[:F]

top_feat

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,
