In [17]:
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer
from sae_lens import SAE
from transformer_lens import HookedTransformer
import torch

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 on CUDA, float32 on CPU. This dtype is passed to the model loader.
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

### Load in SAE

In [9]:
# Load the pretrained SAE identified by `release` / `sae_id`. It is first
# instantiated on the CPU and then moved to `device`.
# NOTE(review): the 3-tuple unpacking assumes the installed sae_lens version
# returns three values from `from_pretrained` — confirm against the pinned version.
sae, cfg, _ = SAE.from_pretrained(
    release="gemma-2b-it-res-jb",
    sae_id="blocks.12.hook_resid_post",
    device = "cpu"
)
sae = sae.to(device)

In [13]:
# Name of the activation to cache from the model. It is the same string that
# is used as `sae_id` when the SAE is loaded.
hook = 'blocks.12.hook_resid_post'

### Load in Hooked Model

In [None]:
# Load the hooked model directly onto `device` using the `dtype` selected earlier.
model = HookedTransformer.from_pretrained("google/gemma-2b-it", device=device, dtype=dtype)

#### Load in Tokenizer

# Standalone Hugging Face tokenizer for the same checkpoint name; this is the
# tokenizer that `tokenize_batch` calls.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
# Reuse the EOS token as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [18]:
def tokenize_batch(batch):
    """Tokenize the ``"input"`` column of a batch with the module-level tokenizer.

    Sequences are truncated to at most 256 tokens and no padding is applied.

    Args:
        batch: Mapping that provides the prompts under the ``"input"`` key.

    Returns:
        Whatever the module-level ``tokenizer`` returns for these prompts.
    """
    encode_options = {"truncation": True, "max_length": 256, "padding": False}
    prompts = batch["input"]
    return tokenizer(prompts, **encode_options)

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [None]:
# EOS token id of the tokenizer attached to the model, or None when that
# attribute does not exist.
# NOTE(review): `eos_id` is not referenced in the cells visible here — confirm
# it is still needed.
eos_id = getattr(model.tokenizer, "eos_token_id", None)

### Load in Data

# Training split of the "eli5" configuration of facebook/kilt_tasks.
dataset = load_dataset("facebook/kilt_tasks", "eli5", split="train")


In [None]:
# Apply `tokenize_batch` to the dataset in batches and drop the "id" and
# "meta" columns from the result.
ds_tok = dataset.map(
    tokenize_batch,
    batched=True,
    remove_columns = ["id", "meta"],
    desc="Tokenizing prompts"
)

### Compute SAE Features

In [None]:
@torch.no_grad()
def seq2vec(text: str):
    """Embed a string as its mean SAE feature activation.

    The text is tokenized with a BOS token prepended, the model is run while
    caching only the activation named by the module-level ``hook``, the cached
    activation is encoded with ``sae``, and the features are averaged over
    dim 1.

    When ``hook`` has the form ``blocks.<n>.<...>``, the forward pass is
    stopped right after block ``n``. Later layers cannot influence the cached
    activation, so the result is unchanged while the remaining layers are
    skipped. For any other hook name the full forward pass is run.

    Args:
        text: The string to embed.

    Returns:
        The SAE features averaged over dim 1, on the SAE's device and dtype.
        NOTE(review): dim 1 is assumed to be the token axis and dim 0 a batch
        axis of size 1 — confirm. The BOS position is part of the average.
    """
    toks = model.to_tokens(text, prepend_bos=True).to(device)

    # Derive the exclusive stop layer from the hook name; None means "run all".
    hook_parts = hook.split(".")
    if len(hook_parts) > 2 and hook_parts[0] == "blocks" and hook_parts[1].isdigit():
        stop_at_layer = int(hook_parts[1]) + 1
    else:
        stop_at_layer = None

    _, cache = model.run_with_cache(
        toks,
        names_filter=hook,
        stop_at_layer=stop_at_layer,
    )
    # No explicit detach is needed: gradients are disabled by the decorator.
    acts = cache[hook].to(sae.device, dtype=sae.dtype)
    feats = sae.encode(acts)
    return feats.mean(dim=1)

In [None]:
@torch.inference_mode()
def add_sae(batch):
    """Compute one SAE feature vector per answer in a batch.

    Intended for use with a batched ``Dataset.map``: the returned column has
    exactly one entry per row of the incoming batch.

    Args:
        batch: Mapping that provides the answers under the ``"answer_text"`` key.

    Returns:
        ``{"sae_acts": feats}`` where each entry is either ``None`` (for an
        empty or missing answer) or the output of ``seq2vec`` with its leading
        dimension squeezed, cast to float16 and converted to a Python list.
    """
    feats = []
    for ans in batch["answer_text"]:
        if not ans:
            # Keep row alignment with a placeholder and skip the model call.
            # Without the `continue`, a second entry would be appended for this
            # row and the column would be longer than the batch.
            feats.append(None)
            continue
        v = seq2vec(ans)
        feats.append(v.squeeze(0).to(torch.float16).cpu().tolist())
    return {"sae_acts": feats}
    
    

In [None]:
# 1) Extract a clean answer_text column
def extract_answer_text(batch):
    """Build a plain-string ``"answer_text"`` column from the ``"output"`` column.

    Each output may be a list (only its first element is inspected), a dict,
    or a string. Dicts contribute their ``"answer"`` value, falling back to
    ``"text"``; strings are used as they are. Anything else, including an
    empty list, yields an empty string. Results are whitespace-stripped.

    Args:
        batch: Mapping that provides the raw outputs under the ``"output"`` key.

    Returns:
        ``{"answer_text": [...]}`` with one string per entry of ``batch["output"]``.
    """

    def _pick(entry):
        # Pull the text out of a single dict/str entry; None for other types.
        if isinstance(entry, dict):
            return entry.get("answer") or entry.get("text") or ""
        if isinstance(entry, str):
            return entry
        return None

    answers = []
    for output in batch["output"]:
        if isinstance(output, list):
            # Only the first element of a non-empty list is considered.
            raw = _pick(output[0]) if output else None
        else:
            raw = _pick(output)
        answers.append((raw or "").strip())
    return {"answer_text": answers}

# Add the "answer_text" column to the tokenized dataset, processing in batches.
ds_norm = ds_tok.map(extract_answer_text, batched=True, desc="Extract answer text")
print(ds_norm[0]["answer_text"])  # sanity check: should be a plain str


In [None]:
# Run `add_sae` over the dataset in batches of 512 rows; its return value
# supplies the "sae_acts" column.
ds_sae = ds_norm.map(
    add_sae,
    batched=True,
    batch_size=512,
    desc="Creating SAE dataset!"
)