In [7]:
import numpy as np
from uuid import uuid4
import gc
import torch
import datasets
from datasets import load_dataset
from sae_lens import SAE
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from sklearn.decomposition import PCA, SparsePCA

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load Dataset

In [3]:
ds = load_dataset("domenicrosati/TruthfulQA")
ds = ds.remove_columns(['Type', 'Category', 'Question', 'Best Answer', 'Source'])
data = ds

README.md: 0.00B [00:00, ?B/s]

train.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/817 [00:00<?, ? examples/s]

### Load SAE's

### Load Model

In [5]:
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", device=device)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]



Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


In [8]:
TEXT_COL = "Correct Answers"
BATCH    = 16

model.eval()
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# (do this before we start adding SAE columns)
data = datasets.DatasetDict({
    name: split.remove_columns([c for c in split.column_names if c != TEXT_COL])
    for name, split in data.items()
})

for i in range(3, 28, 4):
    layer   = f"blocks.{i}.hook_resid_post"
    NEW_COL = f"correct_sae_{i}_acts"

    # Load SAE for this layer
    sae = SAE.from_pretrained(
        release="llama-3.1-8b-instruct-andyrdt",
        sae_id=f"resid_post_layer_{i}_trainer_1",
        device="cuda",
    ).eval()

    @torch.inference_mode()
    def token2sae(texts):                    
        toks = model.to_tokens(texts)
        _, cache = model.run_with_cache(
            toks,
            names_filter=lambda n: n == layer,  # cache only this hook
            stop_at_layer=i + 1,                # stop early
        )
        resid = cache[layer]                    # [B, L, d_model]
        reps  = sae.encode(resid)               # [B, L, d_sae]
        acts  = reps.mean(dim=1)                # [B, d_sae]
        out   = acts.float().cpu().numpy().tolist()

        del cache, resid, reps, acts
        return {NEW_COL: out}

    # Map PER SPLIT, while extending that split's existing schema with the new column
    new_splits = {}
    for name, split in data.items():
        feats = split.features.copy()                           # keep all existing cols
        feats[TEXT_COL] = datasets.Value("string")              # ensure text type is set
        feats[NEW_COL]  = datasets.Sequence(datasets.Value("float32"))  # add this layer's vector

        new_splits[name] = split.map(
            token2sae,
            batched=True,
            batch_size=BATCH,
            input_columns=[TEXT_COL],
            features=feats,
            writer_batch_size=32,
            load_from_cache_file=False,
            desc=f"SAE layer {i}",
            new_fingerprint=str(uuid4())
        )

    # Replace the whole DatasetDict with the updated splits
    data = datasets.DatasetDict(new_splits)

    # Free VRAM for next SAE
    del sae
    torch.cuda.empty_cache()


# features: ['Correct Answers', 'correct_sae_3_acts', 'correct_sae_7_acts', ..., 'correct_sae_27_acts']


SAE layer 3:   0%|          | 0/817 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/948 [00:00<?, ?B/s]

resid_post_layer_7/trainer_1/ae.pt:   0%|          | 0.00/4.30G [00:00<?, ?B/s]

SAE layer 7:   0%|          | 0/817 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/951 [00:00<?, ?B/s]

resid_post_layer_11/trainer_1/ae.pt:   0%|          | 0.00/4.30G [00:00<?, ?B/s]

SAE layer 11:   0%|          | 0/817 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/951 [00:00<?, ?B/s]

resid_post_layer_15/trainer_1/ae.pt:   0%|          | 0.00/4.30G [00:00<?, ?B/s]

SAE layer 15:   0%|          | 0/817 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/951 [00:00<?, ?B/s]

resid_post_layer_19/trainer_1/ae.pt:   0%|          | 0.00/4.30G [00:00<?, ?B/s]

SAE layer 19:   0%|          | 0/817 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/951 [00:00<?, ?B/s]

resid_post_layer_23/trainer_1/ae.pt:   0%|          | 0.00/4.30G [00:00<?, ?B/s]

SAE layer 23:   0%|          | 0/817 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/951 [00:00<?, ?B/s]

resid_post_layer_27/trainer_1/ae.pt:   0%|          | 0.00/4.30G [00:00<?, ?B/s]

SAE layer 27:   0%|          | 0/817 [00:00<?, ? examples/s]

#### Quick Sanity check
check if new features exist

In [9]:
len(data['train'][0]['correct_sae_3_acts'])

131072

In [10]:
repo_id = "mksethi/sae-acts-llama31-8b-it"

In [None]:
data.push_to_hub(
    repo_id,
    private=True,
    max_shard_size="2GB",
    commit_message="Add SAE pooled activations (layers 3..27)"
)