# Demo Notebook

Steps:
1. Download SAE with SAE Lens.
2. Create a dataset consistent with that SAE. 
3. Fold the SAE decoder norm weights so that feature activations are "correct".
4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.
5. Run the SAE generator for the features you want.

# Set Up

In [1]:
import torch
from sae_lens import SAE 
from transformer_lens import HookedTransformer
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner
from huggingface_hub import notebook_login

In [2]:
# Uncomment to log in to the Hugging Face Hub if a download below requires authentication.
# notebook_login()

## Step 1. Download / Initialize SAE

In [3]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Pick the compute device: MPS when available, otherwise CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

model = HookedTransformer.from_pretrained("gpt2", device=device)

# SAE.from_pretrained returns three values; only the SAE itself is kept here.
# The second value is the cfg dict as found in the HF repo (not the SAE's own
# config dict, though that can be extracted from it) and is handy for analysis,
# e.g. instantiating an activation store. The third is the feature sparsities,
# which are stored in HF for convenience.
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-hook-z-kk",
    sae_id="blocks.5.hook_z",
    device=device,
)

# Fold the W_dec norm into the weights so that feature activations are accurate.
sae.fold_W_dec_norm()



Device: cuda




Loaded pretrained model gpt2 into HookedTransformer


  weights = torch.load(file_path, map_location=device)


## Step 2. Get token dataset

In [4]:
from sae_lens import ActivationsStore

# Create the activation store from the model/SAE pair so that the token data
# matches what this SAE expects.
activations_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=8,
    device=device,
)

# Some SAEs require the activation norm to be estimated and folded into the
# weights; SAE Lens provides both steps.
needs_norm_folding = sae.cfg.normalize_activations == "expected_average_only_in"
if needs_norm_folding:
    norm_scaling_factor = activations_store.estimate_norm_scaling_factor(
        n_batches_for_norm_estimate=30
    )
    sae.fold_activation_norm_scaling_factor(norm_scaling_factor)



In [5]:
from tqdm import tqdm 

def get_tokens(
    activations_store: ActivationsStore,
    n_prompts: int,
):
    """Collect shuffled token batches from ``activations_store``.

    Args:
        activations_store: Source of token batches; ``get_batch_tokens()`` is
            called once per iteration.
        n_prompts: Number of batches to fetch. Despite the name, this counts
            batches, so the result has ``n_prompts`` times the per-batch row
            count along dim 0.

    Returns:
        All fetched batches concatenated along dim 0, with rows shuffled within
        each batch and then shuffled again across the whole tensor.
    """
    shuffled_batches = []
    for _ in tqdm(range(n_prompts)):
        batch = activations_store.get_batch_tokens()
        n_rows = batch.shape[0]
        # Permute the rows of this batch; the slice keeps every row.
        shuffled_batches.append(batch[torch.randperm(n_rows)][:n_rows])

    stacked = torch.cat(shuffled_batches, dim=0)
    # Final shuffle so rows from different batches are interleaved.
    return stacked[torch.randperm(stacked.shape[0])]

# The second argument is the number of batches pulled from the store, not the
# number of prompts: in the recorded run, 256 batches give 4096 prompts of 128
# tokens each, which is plenty for a demo.
token_dataset = get_tokens(activations_store, 256)
print(f"Token dataset size: {token_dataset.shape}")

  0%|          | 0/256 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 256/256 [00:02<00:00, 120.11it/s]

Token dataset size: torch.Size([4096, 128])





In [6]:
insert = model.to_tokens("Stalinists shriek in the ears of the police that")
# Overwrite the start of the first prompt with the hand-written example. The
# slice length comes from the tokenised prompt (capped at the dataset's context
# length) instead of a hard-coded 13, so editing the text above cannot cause a
# shape mismatch.
n_insert_tokens = min(insert.shape[-1], token_dataset.shape[-1])
token_dataset[0, :n_insert_tokens] = insert[0, :n_insert_tokens]

## Step 5. Generate Feature Dashboards

In [7]:
import gc

# Drop unreferenced Python objects first so the memory they hold can be handed
# back by the cache flush that follows.
gc.collect()

# Flush the allocator cache of whichever accelerator backend is available; the
# device-selection cell may have picked MPS rather than CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [8]:
from pathlib import Path

# Generate dashboards for the first 64 features only.
test_feature_idx_gpt = list(range(64))

feature_vis_config_gpt = SaeVisConfig(
    hook_point=sae.cfg.hook_name,
    features=test_feature_idx_gpt,
    minibatch_size_features=32,
    minibatch_size_tokens=256,  # this is number of prompts at a time.
    verbose=True,
    # Reuse the device chosen in the set-up cell instead of hard-coding "cuda",
    # so the notebook also runs on MPS or CPU machines.
    device=device,
    # cache_dir=Path("demo_activations_cache"),  # this will enable us to skip running the model for subsequent features.
    dtype="bfloat16",
    use_dfa=True,
)

data = SaeVisRunner(feature_vis_config_gpt).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset,
)

Forward passes to cache data for vis:   0%|          | 0/32 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/64 [00:00<?, ?it/s]

shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([4096, 128])
shape of feat_acts: torch.Size([

In [11]:
# Sanity check: type of the first DFA entry stored under key 0 of feature_data_dict.
type(data.feature_data_dict[0].dfa_data[0])

dict

In [14]:
# Sanity check: type of the first element of "dfaValues" in the first DFA entry.
type(data.feature_data_dict[0].dfa_data[0]["dfaValues"][0])

float

## Test Neuronpedia Runner

In [1]:
import os

from sae_dashboard.neuronpedia.neuronpedia_runner import (
    NeuronpediaRunner,
    NeuronpediaRunnerConfig,
)

# python neuronpedia.py generate --sae-set=res-jb --sae-path=/opt/Gemma-2b-Residual-Stream-SAEs/gemma_2b_blocks.10.hook_resid_post_16384 --dataset-path=Skylion007/openwebtext --log-sparsity=-6 --dtype= --feat-per-batch=128 --n-prompts=24576 --n-context-tokens=128 --n-prompts-in-forward-pass=128 --resume-from-batch=0 --end-at-batch=-1

# LOCAL PATHS
# Relative paths; both folders are deleted at the start of the next cell.
NP_OUTPUT_FOLDER = "neuronpedia_outputs/"
ACT_CACHE_FOLDER = "cached_activations"

# NP SET NAME
# SAE_SET / SAE_PATH match the release / sae_id loaded earlier in this notebook.
NP_SET_NAME = "att-kk"
SAE_SET = "gpt2-small-hook-z-kk"
SAE_PATH = "blocks.5.hook_z"

# DATASET
HF_DATASET_PATH = "Skylion007/openwebtext"


# With this value the recorded run logs
# "Skipping sparsity because sparsity_threshold was set to 1".
SPARSITY_THRESHOLD = 1

# IMPORTANT
SAE_DTYPE = "float32"
MODEL_DTYPE = "float32"

# PERFORMANCE SETTING
# 256 prompts x 128 tokens = 32768 tokens in total (see the recorded run log).
N_PROMPTS = 256
N_TOKENS_IN_PROMPT = 128
N_PROMPTS_IN_FORWARD_PASS = 128
NUM_FEATURES_PER_BATCH = 256

In [2]:
import shutil

# Delete output files and cached activations if present. shutil.rmtree replaces
# the former `rm -rf` shell calls: no shell interpolation of the path, and it
# works on any platform. ignore_errors=True keeps the "missing folder is fine"
# behaviour of `rm -rf`.
for stale_dir in (NP_OUTPUT_FOLDER, ACT_CACHE_FOLDER):
    shutil.rmtree(stale_dir, ignore_errors=True)

# Features are processed NUM_FEATURES_PER_BATCH at a time. start_batch=0 and
# end_batch=0 keep this demo short: the recorded run wrote a single file,
# batch-0.json.
cfg = NeuronpediaRunnerConfig(
    sae_set=SAE_SET,
    sae_path=SAE_PATH,
    np_set_name=NP_SET_NAME,
    from_local_sae=False,
    huggingface_dataset_path=HF_DATASET_PATH,
    sae_dtype=SAE_DTYPE,
    model_dtype=MODEL_DTYPE,
    outputs_dir=NP_OUTPUT_FOLDER,
    sparsity_threshold=SPARSITY_THRESHOLD,
    n_prompts_total=N_PROMPTS,
    n_tokens_in_prompt=N_TOKENS_IN_PROMPT,
    n_prompts_in_forward_pass=N_PROMPTS_IN_FORWARD_PASS,
    n_features_at_a_time=NUM_FEATURES_PER_BATCH,
    start_batch=0,
    end_batch=0,
    use_dfa=True,
    use_wandb=False,
    # TESTING ONLY
    # end_batch=6,
)

runner = NeuronpediaRunner(cfg)
runner.run()


  weights = torch.load(file_path, map_location=device)


Overriding sae dtype to float32
Device Count: 1
SAE Device: cuda
Model Device: cuda
Model Num Devices: None
Activation Store Device: cuda
Dataset Path: Skylion007/openwebtext
Forward Pass size: 128
Total number of tokens: 32768
Total number of contexts (prompts): 256
SAE Config on disk:
{
  "architecture": "standard",
  "d_in": 768,
  "d_sae": 49152,
  "dtype": "torch.float32",
  "device": "cuda",
  "model_name": "gpt2-small",
  "hook_name": "blocks.5.attn.hook_z",
  "hook_layer": 5,
  "hook_head_index": null,
  "activation_fn_str": "relu",
  "activation_fn_kwargs": {},
  "apply_b_dec_to_input": true,
  "finetuning_scaling_factor": false,
  "sae_lens_training_version": null,
  "prepend_bos": true,
  "dataset_path": "Skylion007/openwebtext",
  "dataset_trust_remote_code": true,
  "context_size": 128,
  "normalize_activations": "none"
}
SAE does not have from_pretrained_kwargs. Standard TransformerLens Loading
SAE DType: float32
Model DType: float32




Loaded pretrained model gpt2-small into HookedTransformer
Skipping sparsity because sparsity_threshold was set to 1
Tokens don't exist, making them.


  0%|          | 0/32 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 32/32 [00:00<00:00, 129.83it/s]
0it [00:00, ?it/s]

DFA flag set to True


Forward passes to cache data for vis:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/256 [00:00<?, ?it/s]

Processed batch of 256 sequences
Original indices bold: tensor([   19,    18,    16,    17,     1,     0,     2,     3,    11,    10,
            8,     9,    13,    12,    14,    15,     7,     6,     4,     5,
        30942,  6599, 23186, 28453, 20722])
Indices dict: {'TOP ACTIVATIONS<br>MAX = 0.000': tensor([[ 0, 19],
        [ 0, 18],
        [ 0, 16],
        [ 0, 17],
        [ 0,  1],
        [ 0,  0],
        [ 0,  2],
        [ 0,  3],
        [ 0, 11],
        [ 0, 10],
        [ 0,  8],
        [ 0,  9],
        [ 0, 13],
        [ 0, 12],
        [ 0, 14],
        [ 0, 15],
        [ 0,  7],
        [ 0,  6],
        [ 0,  4],
        [ 0,  5]]), 'INTERVAL 0.000 - 0.000<br>CONTAINS 100.000%': tensor([[241,  94],
        [ 51,  71],
        [181,  18],
        [222,  37],
        [161, 114]])}
Processed batch of 256 sequences
Original indices bold: tensor([   19,    18,    16,    17,     1,     0,     2,     3,    11,    10,
            8,     9,    13,    12,    14,    15, 

Saving feature-centric vis:   0%|          | 0/256 [00:00<?, ?it/s]

192it [00:47,  4.06it/s]

Output written to neuronpedia_outputs/gpt2-small_gpt2-small-hook-z-kk_blocks.5.attn.hook_z/batch-0.json





In [3]:
import json
import os
from typing import Any, Dict, List

import pytest

def load_json_file(file_path: str) -> Dict[str, Any]:
    """Read the JSON document stored at ``file_path`` and return the parsed object.

    The file is opened explicitly as UTF-8 so that non-ASCII text (for example
    token strings) is decoded the same way regardless of the platform's default
    locale encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

In [4]:
# Resolve the runner's output folder relative to NP_OUTPUT_FOLDER (the
# outputs_dir given to the runner) instead of a machine-specific absolute path.
output_dir = os.path.join(
    NP_OUTPUT_FOLDER, "gpt2-small_gpt2-small-hook-z-kk_blocks.5.attn.hook_z"
)

# Find all batch files; sorted because os.listdir returns entries in arbitrary
# order and the first one is picked below.
batch_files = sorted(
    f for f in os.listdir(output_dir) if f.startswith("batch-") and f.endswith(".json")
)

assert len(batch_files) > 0, "No batch files found in the output directory"

file_path = os.path.join(output_dir, batch_files[0])
batch_data = load_json_file(file_path)

In [8]:
# List the fields stored in the first activation record of the first feature entry.
batch_data["features"][0]["activations"][0].keys()

dict_keys(['bin_min', 'bin_max', 'bin_contains', 'tokens', 'values', 'dfa_values', 'dfa_maxValue', 'dfa_targetIndex'])

In [20]:
# Number of tokens in the first activation record of feature entry 15.
len(batch_data["features"][15]["activations"][0]['tokens'])

127

In [25]:
seq_idx = 6

# Look up the activation record for feature entry 15 / sequence `seq_idx` once,
# then read every field from it.
activation_record = batch_data["features"][15]["activations"][seq_idx]
tokens = activation_record["tokens"]
dfa_values = activation_record["dfa_values"]

print(f" Max DFA Value Index: {activation_record['dfa_targetIndex']}")
# Print each token next to its DFA value, rounded to four decimals.
for token, dfa_value in zip(tokens, dfa_values):
    print(f"Token: {token}: {round(dfa_value, 4)}")

 Max DFA Value Index: 19
Token: Port: 0.5154
Token: -: -0.0036
Token: au: -0.0009
Token: -: -0.0012
Token: Prince: -0.0006
Token: ,: -0.0005
Token:  Haiti: 0.0011
Token:  (: -0.005
Token: CNN: -0.0016
Token: ): 0.0024
Token:  --: -0.0002
Token:  Earthquake: -0.0003
Token:  victims: -0.0227
Token: ,: 0.0186
Token:  wr: 0.0609
Token: ithing: -0.0221
Token:  in: 0.4329
Token:  pain: 0.5258
Token:  and: 0.525
Token:  grasping: 0.1648
Token:  at: 0.0
Token:  life: 0.0
Token: ,: 0.0
Token:  watched: 0.0
Token:  doctors: 0.0
Token:  and: 0.0
Token:  nurses: 0.0
Token:  walk: 0.0
Token:  away: 0.0
Token:  from: 0.0
Token:  a: 0.0
Token:  field: 0.0
Token:  hospital: 0.0
Token:  Friday: 0.0
Token:  night: 0.0
Token:  after: 0.0
Token:  a: 0.0
Token:  Belgian: 0.0
Token:  medical: 0.0
Token:  team: 0.0
Token:  evacuated: 0.0
Token:  the: 0.0
Token:  area: 0.0
Token: ,: 0.0
Token:  saying: 0.0
Token:  it: 0.0
Token:  was: 0.0
Token:  concerned: 0.0
Token:  about: 0.0
Token:  security: 0.0
Token: 