# SAEs with nnsight
A short demo of loading SAEs and caching their feature activations with nnsight.

- Refer to this repo for more info on basic usage, stats, and downloads for Sam's pythia-70m SAEs: https://github.com/saprmarks/dictionary_learning
- They're also available on my Google Drive: https://drive.google.com/drive/folders/14fPh8gf16bLGJ0MD1vAyy6bog1lr2dSf?usp=sharing
- Here's an example of using Sam's SAEs in Google Colab: https://colab.research.google.com/drive/1C9RjyB8Tia4CY9UOZVX4o3MmxdDirrqD?usp=sharing

In [1]:
import torch as t
from nnsight import LanguageModel

import sys
sys.path.append('..')
from dictionary_learning import AutoEncoder

In [2]:
# Run configuration: target device, SAE geometry (model width D_MODEL,
# dictionary width D_SAE), and the location of the pretrained dictionaries.
DEVICE = "cuda:0"
DEBUGGING = True
D_MODEL = 512
D_SAE = 32768
DICT_ID = 10
DICT_PATH = "/share/projects/dictionary_circuits/autoencoders"

# Extra keyword arguments for model.trace(...): both nnsight options are switched
# on together while DEBUGGING and off together otherwise.
tracer_kwargs = dict.fromkeys(('validate', 'scan'), bool(DEBUGGING))

In [6]:
# Load model
model = LanguageModel(
    "EleutherAI/pythia-70m-deduped",
    device_map = DEVICE,
    dispatch = True,
)

# Load submodules and dictionaries.
# Hidden states are read at: the embedding output, each layer's attention
# output, each layer's MLP output, and each full layer's output ("resid").
embed = model.gpt_neox.embed_in
attns = [layer.attention for layer in model.gpt_neox.layers]
mlps = [layer.mlp for layer in model.gpt_neox.layers]
resids = list(model.gpt_neox.layers)
submodules = attns + mlps + resids + [embed]


def _load_dictionary(dict_dir):
    """Return an AutoEncoder(D_MODEL, D_SAE) on DEVICE with pretrained weights.

    Weights are read from
    ``{DICT_PATH}/pythia-70m-deduped/{dict_dir}/{DICT_ID}_{D_SAE}/ae.pt``.
    ``map_location=DEVICE`` makes checkpoints that were serialized on another
    device (e.g. a different GPU index, or a GPU when running on CPU) load
    straight onto DEVICE instead of failing.
    """
    ae = AutoEncoder(D_MODEL, D_SAE).to(DEVICE)
    state_dict = t.load(
        f'{DICT_PATH}/pythia-70m-deduped/{dict_dir}/{DICT_ID}_{D_SAE}/ae.pt',
        map_location=DEVICE,
    )
    ae.load_state_dict(state_dict)
    return ae


# (submodule, short name, checkpoint directory) for every hooked location,
# registered embed first, then attn/mlp/resid for each layer in turn.
dict_specs = [(embed, 'embed', 'embed')]
for i, (attn, mlp, resid) in enumerate(zip(attns, mlps, resids)):
    dict_specs += [
        (attn, f'attn{i}', f'attn_out_layer{i}'),
        (mlp, f'mlp{i}', f'mlp_out_layer{i}'),
        (resid, f'resid{i}', f'resid_out_layer{i}'),
    ]

dictionaries = {}
submodule_to_name = {}
for submodule, name, dict_dir in dict_specs:
    dictionaries[submodule] = _load_dictionary(dict_dir)
    submodule_to_name[submodule] = name

name_to_submodule = {v: k for k, v in submodule_to_name.items()}

# Run through a test input to figure out which hidden states are tuples.
# NOTE: the exact type comparison is deliberate. torch.Size is a subclass of
# tuple, so isinstance(shape, tuple) would also be True for plain tensor
# outputs. Presumably nnsight reports a plain tuple (of shapes) only for
# tuple-valued outputs -- re-check this if nnsight is upgraded.
is_tuple = {}
with model.trace("_"):
    for submodule in submodules:
        is_tuple[submodule] = type(submodule.output.shape) == tuple

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [7]:
# Cache feature activations with nnsight
prompt = "Simple input"

feature_activations = {}
reconstruction_errors = {}

with t.no_grad(), model.trace(prompt, **tracer_kwargs):
    for submodule in submodules:
        # For submodules flagged in is_tuple, the hidden state is element 0.
        out = submodule.output
        acts = out[0] if is_tuple[submodule] else out

        # Encode/decode through this submodule's SAE; keep the residual the
        # SAE fails to reconstruct.
        sae = dictionaries[submodule]
        recon, feats = sae.forward(acts, output_features=True)
        err = acts - recon

        # .save() marks the values so they can be read after the trace
        # context exits (as the inspection cell below does).
        feature_activations[submodule] = feats.save()
        reconstruction_errors[submodule] = err.save()

In [9]:
# Inspect cached feature activations for one submodule, looked up by its short name
name = "resid0"
feature_activations[name_to_submodule[name]].shape # (batch_size, seq_len, D_SAE) -- last dim is the dictionary width, not the model width

torch.Size([1, 2, 32768])