- https://github.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb
- https://transformer-circuits.pub/2023/monosemantic-features
    - https://transformer-circuits.pub/2023/monosemantic-features#setup-interface

- A basic introduction to SAEs.
    - What is SAE Lens?
    - Choosing an SAE to analyse and loading it with SAE Lens.
    - The SAE Class and its config.
- SAE Features.
    - What is a feature dashboard?
    - Loading feature dashboards on Neuronpedia.
    - Downloading Autointerp and searching via explanations.
- Feature inference
    - Using the **`HookedSAETransformer`** class to **decompose activations into features**.
    - Comparing Features across related prompts.
- Making Feature Dashboards
    - Max Activating Examples
    - Feature Activation Histograms
    - Logit Weight Distributions.
    - Extension: Reproducing `Not all language model features are linear`
- SAE based Analysis Methods (Advanced)
    - Steering model generation with SAE Features
    - Ablating SAE Features
    - Gradient-based Attribution for Circuit Detection

In [1]:
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

In [2]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x79920efde410>

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

### Loading a pretrained Sparse Autoencoder

In [4]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

- Joseph's Open Source GPT2 Small Residual (gpt2-small-res-jb)
- Joseph's Feature Splitting (gpt2-small-res-jb-feature-splitting)
- Gemma SAEs (gemma-2b-res-jb): layers 0 and 6 are on Neuronpedia and good; layers 12 and 17 aren't very good currently.

In [9]:
df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T
df.drop(
    columns=[
        "expected_var_explained",
        "expected_l0",
        "config_overrides",
        "conversion_func",
    ],
    inplace=True,
)

In [10]:
df.head()

| | release | repo_id | model | saes_map | neuronpedia_id |
|---|---|---|---|---|---|
| gemma-2b-it-res-jb | gemma-2b-it-res-jb | jbloom/Gemma-2b-IT-Residual-Stream-SAEs | gemma-2b-it | {'blocks.12.hook_resid_post': 'gemma_2b_it_blo... | {'blocks.12.hook_resid_post': 'gemma-2b-it/12-... |
| gemma-2b-res-jb | gemma-2b-res-jb | jbloom/Gemma-2b-Residual-Stream-SAEs | gemma-2b | {'blocks.0.hook_resid_post': 'gemma_2b_blocks.... | {'blocks.0.hook_resid_post': 'gemma-2b/0-res-j... |
| gemma-scope-27b-pt-res | gemma-scope-27b-pt-res | google/gemma-scope-27b-pt-res | gemma-2-27b | {'layer_10/width_131k/average_l0_106': 'layer_... | {'layer_10/width_131k/average_l0_106': None, '... |
| gemma-scope-27b-pt-res-canonical | gemma-scope-27b-pt-res-canonical | google/gemma-scope-27b-pt-res | gemma-2-27b | {'layer_10/width_131k/canonical': 'layer_10/wi... | {'layer_10/width_131k/canonical': 'gemma-2-27b... |
| gemma-scope-2b-pt-att | gemma-scope-2b-pt-att | google/gemma-scope-2b-pt-att | gemma-2-2b | {'layer_0/width_16k/average_l0_104': 'layer_0/... | {'layer_0/width_16k/average_l0_104': None, 'la... |


In [17]:
print("SAEs in the GPT2 Small Resid Pre release")
for k, v in df.loc[df.release == "gpt2-small-res-jb", "saes_map"].values[0].items():
    print(f"SAE id: {k} for hook point: {v}")
print("-" * 50)
print("SAEs in the feature splitting release")
for k, v in (
    df.loc[df.release == "gpt2-small-res-jb-feature-splitting", "saes_map"]
    .values[0]
    .items()
):
    print(f"SAE id: {k} for hook point: {v}")
print("-" * 50)
print("SAEs in the Gemma base model release")
for k, v in df.loc[df.release == "gemma-2b-res-jb", "saes_map"].values[0].items():
    print(f"SAE id: {k} for hook point: {v}")

SAEs in the GPT2 Small Resid Pre release
SAE id: blocks.0.hook_resid_pre for hook point: blocks.0.hook_resid_pre
SAE id: blocks.1.hook_resid_pre for hook point: blocks.1.hook_resid_pre
SAE id: blocks.2.hook_resid_pre for hook point: blocks.2.hook_resid_pre
SAE id: blocks.3.hook_resid_pre for hook point: blocks.3.hook_resid_pre
SAE id: blocks.4.hook_resid_pre for hook point: blocks.4.hook_resid_pre
SAE id: blocks.5.hook_resid_pre for hook point: blocks.5.hook_resid_pre
SAE id: blocks.6.hook_resid_pre for hook point: blocks.6.hook_resid_pre
SAE id: blocks.7.hook_resid_pre for hook point: blocks.7.hook_resid_pre
SAE id: blocks.8.hook_resid_pre for hook point: blocks.8.hook_resid_pre
SAE id: blocks.9.hook_resid_pre for hook point: blocks.9.hook_resid_pre
SAE id: blocks.10.hook_resid_pre for hook point: blocks.10.hook_resid_pre
SAE id: blocks.11.hook_resid_pre for hook point: blocks.11.hook_resid_pre
SAE id: blocks.11.hook_resid_post for hook point: blocks.11.hook_resid_post
---------------

In [16]:
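# SAE widths for expansion factors 2**i applied to d_in = 768 (the GPT-2 small residual stream width)
for i in range(8):
    print(i, 2**i, 2**i * 768)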

0 1 768
1 2 1536
2 4 3072
3 8 6144
4 16 12288
5 32 24576
6 64 49152
7 128 98304


In [18]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# The cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (e.g. instantiating an activation store).
# Note that this is not the same as the SAE's config dict; it is whatever was in the HF repo, from which the SAE config dict can be extracted.
# The feature sparsities, which are stored in HF for convenience, are returned as well.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # <- Release name
    sae_id="blocks.7.hook_resid_pre",  # <- SAE id (not always a hook point!)
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


blocks.7.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [64]:
model.tokenizer.vocab_size

50257

In [41]:
print(sae.cfg.__dict__)

{'architecture': 'standard', 'd_in': 768, 'd_sae': 24576, 'activation_fn_str': 'relu', 'apply_b_dec_to_input': True, 'finetuning_scaling_factor': False, 'context_size': 128, 'model_name': 'gpt2-small', 'hook_name': 'blocks.7.hook_resid_pre', 'hook_layer': 7, 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'Skylion007/openwebtext', 'dataset_trust_remote_code': True, 'normalize_activations': 'none', 'dtype': 'torch.float32', 'device': 'cuda', 'sae_lens_training_version': None, 'activation_fn_kwargs': {}, 'neuronpedia_id': 'gpt2-small/7-res-jb', 'model_from_pretrained_kwargs': {'center_writing_weights': True}, 'seqpos_slice': (None,)}


In [35]:
cfg_dict['dataset_path'], cfg_dict['prepend_bos']

('Skylion007/openwebtext', True)

In [38]:
cfg_dict['normalize_activations'], cfg_dict['dtype'], cfg_dict['device']

('none', 'torch.float32', 'cuda')

In [32]:
cfg_dict['hook_point'], cfg_dict['hook_point_layer'], cfg_dict['hook_point_head_index']

('blocks.7.hook_resid_pre', 7, None)
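
The raw `cfg_dict` from the HF repo uses key names (`hook_point`, `hook_point_layer`, `hook_point_head_index`) that differ from the attributes printed from `sae.cfg` (`hook_name`, `hook_layer`, `hook_head_index`). A minimal sketch of reconciling the two follows. The mapping is inferred only from the outputs shown in this notebook, and `normalize_cfg_keys` is a hypothetical helper, not part of SAE Lens.

```python
# Hypothetical mapping, inferred from the cfg_dict keys and sae.cfg fields printed above.
LEGACY_TO_CURRENT = {
    "hook_point": "hook_name",
    "hook_point_layer": "hook_layer",
    "hook_point_head_index": "hook_head_index",
}


def normalize_cfg_keys(raw_cfg):
    """Return a copy of raw_cfg with legacy key names replaced by current ones."""
    return {LEGACY_TO_CURRENT.get(key, key): value for key, value in raw_cfg.items()}


# Stand-in for the relevant part of cfg_dict (values copied from the outputs above).
raw_cfg = {
    "hook_point": "blocks.7.hook_resid_pre",
    "hook_point_layer": 7,
    "hook_point_head_index": None,
    "dataset_path": "Skylion007/openwebtext",
}
normalized = normalize_cfg_keys(raw_cfg)
print(normalized["hook_name"], normalized["hook_layer"], normalized["hook_head_index"])
```

Keys that are not in the mapping (such as `dataset_path`) pass through unchanged.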

In [27]:
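# 'd_in': 768 (width of the activations the SAE reads)
# 'd_sae': 24576 (d_in * 32, i.e. an expansion factor of 32)
# 'activation_fn_str': 'relu'
# 'apply_b_dec_to_input': True (the decoder bias b_dec is subtracted from the input before encoding)
# 'context_size': 128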
24576 / 768

32.0
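
The same two numbers also give a rough parameter count. The sketch below assumes the checkpoint holds only an encoder matrix, a decoder matrix and the two bias vectors, all in float32; that layout is an assumption rather than something read from the weights file. Under it, the estimate comes out near the 151M size shown for `sae_weights.safetensors` in the download log above.

```python
d_in, d_sae = 768, 24576      # from sae.cfg
bytes_per_param = 4           # float32, per the 'dtype' field in the config

expansion_factor = d_sae // d_in

# Assumed parameter layout: W_enc (d_in x d_sae), W_dec (d_sae x d_in), b_enc, b_dec.
n_params = d_in * d_sae + d_sae * d_in + d_sae + d_in
size_mb = n_params * bytes_per_param / 1e6

print(expansion_factor)       # 32
print(n_params)               # 37774080
print(round(size_mb, 1))      # 151.1
```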

In [42]:
from datasets import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (229134 > 1024). Running this sequence through the model will result in indexing errors


In [48]:
print(dataset[0]['text'][:100], dataset[0]['meta'])

It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi {'pile_set_name': 'Pile-CC'}


In [60]:
sae.cfg.prepend_bos, sae.cfg.context_size

(True, 128)

In [56]:
model.tokenizer.encode('<|endoftext|>')

[50256]

In [59]:
token_dataset[0]['tokens'].shape

torch.Size([128])
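
Each row of `token_dataset` is `context_size` tokens long and, because `prepend_bos` is `True`, starts with the BOS token. For GPT-2 this is the same `<|endoftext|>` id (50256) encoded above. The sketch below is a simplified, dependency-free illustration of that chunking idea, not the `transformer_lens` implementation; `chunk_with_bos` and the toy token ids are made up for the example.

```python
def chunk_with_bos(token_ids, context_size, bos_id):
    """Split one long token stream into rows of length context_size, each starting with bos_id."""
    body_len = context_size - 1              # one slot per row is reserved for BOS
    n_rows = len(token_ids) // body_len      # the incomplete tail is dropped
    return [
        [bos_id] + token_ids[r * body_len : (r + 1) * body_len]
        for r in range(n_rows)
    ]


EOT_ID = 50256  # <|endoftext|>, used by GPT-2 as both BOS and document separator

# Three toy "documents", already tokenized (ids are arbitrary).
docs = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]

# Concatenate the documents into one stream, separated by the end-of-text token.
stream = []
for doc in docs:
    stream.extend(doc + [EOT_ID])

rows = chunk_with_bos(stream, context_size=4, bos_id=EOT_ID)
for row in rows:
    print(row)
```

Note that a row can begin in the middle of a document, which is why the sequences in `token_dataset` do not line up with entries of `dataset`.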

### What are SAE Features?

- An SAE feature represents **a pattern or concept** that the **autoencoder has learned to detect** in the input data.
    - These features often correspond to **meaningful semantic, syntactic, or otherwise interpretable elements of text**, and each corresponds to a **linear direction** in activation space.
- SAEs are trained on the **activations** of a specific part of the model. After training, the features show up as **activations in the hidden layer** of the SAE, which is **much wider** than the source activation vector and produces one hidden activation per feature.
    - As such, the hidden activations represent a decomposition of the entangled/superimposed features found in the original model activations.
    - Ideally, these activations are **sparse**: only a few of the many possible hidden activations are nonzero **for a given piece of input**. This sparsity tends to make the features easier to interpret.
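
To make the bullets above concrete, here is a dependency-free toy sketch of a standard ReLU SAE forward pass: subtract the decoder bias, project into a wider hidden layer, apply ReLU, then decode back. `ToySAE`, its dimensions and its random weights are illustrative assumptions. It is untrained, so its features are meaningless and its reconstruction is poor; only the shapes and the flow of computation carry over.

```python
import random


def matvec(x, W):
    """Multiply row vector x (length n) by matrix W (n rows, m columns)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


class ToySAE:
    """Untrained toy autoencoder with a ReLU hidden layer wider than its input."""

    def __init__(self, d_in, d_sae, seed=0):
        rng = random.Random(seed)
        self.W_enc = [[rng.gauss(0.0, 0.5) for _ in range(d_sae)] for _ in range(d_in)]
        self.W_dec = [[rng.gauss(0.0, 0.5) for _ in range(d_in)] for _ in range(d_sae)]
        self.b_enc = [-0.5] * d_sae   # a negative bias pushes more features to exactly zero
        self.b_dec = [0.0] * d_in

    def encode(self, x):
        # Subtract the decoder bias first (the behaviour assumed for apply_b_dec_to_input=True).
        centered = [xi - bi for xi, bi in zip(x, self.b_dec)]
        pre_acts = matvec(centered, self.W_enc)
        return [max(0.0, p + b) for p, b in zip(pre_acts, self.b_enc)]

    def decode(self, feats):
        # Each active feature adds its decoder row (a direction in activation space).
        return [r + b for r, b in zip(matvec(feats, self.W_dec), self.b_dec)]


toy_sae = ToySAE(d_in=4, d_sae=16)
activation = [0.3, -1.2, 0.8, 0.1]            # stand-in for a model activation vector
feature_acts = toy_sae.encode(activation)      # one hidden activation per feature
reconstruction = toy_sae.decode(feature_acts)
l0 = sum(1 for f in feature_acts if f > 0.0)   # number of active features
print(f"{l0} of {len(feature_acts)} features active")
```

In the pretrained SAE loaded earlier the corresponding sizes are `d_in = 768` and `d_sae = 24576`, and training is what drives the number of active features per token down.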

In [61]:
from IPython.display import IFrame

# get a random feature from the SAE
feature_idx = torch.randint(0, sae.cfg.d_sae, (1,)).item()

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gpt2-small", sae_id="7-res-jb", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)


html = get_dashboard_html(
    sae_release="gpt2-small", sae_id="7-res-jb", feature_idx=feature_idx
)
IFrame(html, width=1200, height=600)
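
The two path components used above (`gpt2-small` and `7-res-jb`) match the `neuronpedia_id` field printed from `sae.cfg` earlier (`'gpt2-small/7-res-jb'`). A small sketch of building the embed URL directly from that field follows; `dashboard_url_from_id` is a hypothetical helper, and the query parameters are simply copied from `html_template`.

```python
def dashboard_url_from_id(neuronpedia_id, feature_idx, height=300):
    """Build a Neuronpedia embed URL from an id of the form '<model>/<sae>'."""
    model_id, sae_id = neuronpedia_id.split("/", 1)
    return (
        f"https://neuronpedia.org/{model_id}/{sae_id}/{feature_idx}"
        f"?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height={height}"
    )


url = dashboard_url_from_id("gpt2-small/7-res-jb", feature_idx=0)
print(url)
```

With the SAE loaded in this notebook, the first argument could be `sae.cfg.neuronpedia_id`, provided that field is not `None` for the chosen release.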