In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
from functools import partial

# Make the local sae-auto-interp checkout importable before importing from it.
sys.path.append("/root/sae-auto-interp")

from nnsight import LanguageModel

from sae_auto_interp.features import FeatureDataset, FeatureCache, pool_max_activation_windows, sample
from sae_auto_interp.config import FeatureConfig, ExperimentConfig
from sae_auto_interp.get_activations import get_activations
from sae_auto_interp.utils import load_tokenized_data, display
from sae_auto_interp.clients import OpenRouter
from sae_auto_interp.explainers import SimpleExplainer
from sae_auto_interp.explainers import SimpleExplainer

In [33]:
# --- Run configuration: everything a reader might want to tune lives here ---

# Context length, batch size and token budget passed to get_activations;
# CTX_LEN is also reused when re-tokenizing the dataset with load_tokenized_data.
CTX_LEN = 128
BATCH_SIZE = 32
N_TOKENS = 1_000_000

# Model whose activations are cached, and the dataset/split it is run on.
MODEL_NAME = "google/gemma-2b-it"
DATASET_NAME = "jacobcd52/college_math_cleaned"
DATASET_SPLIT = "train"

# Indices of the SAE features to cache: 0 through 99.
FEATURE_IDX_LIST = [*range(100)]

# Location of the SAE. The config and weights filenames share one stem and
# differ only in their suffix, so both are derived from it to keep them in sync.
SAE_REPO = "jacobcd52/gemma-2b-it-ssae-college_math_cleaned"
SAE_FILE_STEM = "gemma-2b-it_layer12_college_math_cleaned_l1=10_expansion=2_tokens=8192000_gsae_id=layer_12_stepan"
SAE_CFG_FILE = SAE_FILE_STEM + "_cfg.json"
SAE_WEIGHTS_FILE = SAE_FILE_STEM + ".safetensors"

In [34]:
# Run the model over the dataset and cache SAE feature activations.
# Every argument comes from the configuration cell. The call returns the model
# and the SAE width, both of which later cells reuse.
# NOTE(review): the recorded output shows this taking roughly five minutes and
# saving splits under /root/sae-auto-interp/splits - confirm before re-running.
activation_kwargs = dict(
    sae_repo=SAE_REPO,
    sae_weights_file=SAE_WEIGHTS_FILE,
    sae_cfg_file=SAE_CFG_FILE,
    feature_idx_list=FEATURE_IDX_LIST,
    dataset_name=DATASET_NAME,
    dataset_split=DATASET_SPLIT,
    model_name=MODEL_NAME,
    batch_size=BATCH_SIZE,
    ctx_len=CTX_LEN,
    n_tokens=N_TOKENS,
)
model, sae_width = get_activations(**activation_kwargs)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

(…)8192000_gsae_id=layer_12_stepan_cfg.json:   0%|          | 0.00/2.75k [00:00<?, ?B/s]

module path .model.layers.12
dict_keys(['.model.layers.12'])


Caching features:   0%|          | 0/244 [00:00<?, ?it/s]You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Caching features: 100%|██████████| 244/244 [05:11<00:00,  1.28s/it, Total Tokens=999,424]


Total tokens processed: 999,424
saving split at  /root/sae-auto-interp/splits/.model.layers.12
saving split at  /root/sae-auto-interp/splits/.model.layers.12
saving split at  /root/sae-auto-interp/splits/.model.layers.12
saving split at  /root/sae-auto-interp/splits/.model.layers.12
saving split at  /root/sae-auto-interp/splits/.model.layers.12


In [35]:
# Settings for building per-feature examples; `width` is the SAE width
# returned by get_activations.
# NOTE(review): n_splits=5 lines up with the five "saving split" messages in
# the caching output - presumably the two must agree; confirm.
cfg = FeatureConfig(width=sae_width, min_examples=200, max_examples=10_000, example_ctx_len=40, n_splits=5)

experiment_cfg = ExperimentConfig(n_quantiles=2)  # TODO change?

# Dataset over the cached activations saved under the splits directory.
feature_dataset = FeatureDataset(raw_dir="/root/sae-auto-interp/splits", cfg=cfg)

# Tokenize the dataset again with the same context length, dataset and split
# that were passed to get_activations.
tokens = load_tokenized_data(CTX_LEN, model.tokenizer, DATASET_NAME, DATASET_SPLIT)

# Bind the fixed arguments of the example constructor and the sampler, then
# hand both callables to FeatureDataset.load.
constructor = partial(
    pool_max_activation_windows,
    tokens=tokens,
    ctx_len=cfg.example_ctx_len,
    max_examples=cfg.max_examples,
)
sampler = partial(sample, cfg=experiment_cfg)

# Take only the first item that load() yields; the iterator is kept so later
# cells can pull more.
record_batches = feature_dataset.load(constructor=constructor, sampler=sampler)
loaded_data_iter = iter(record_batches)
records = next(loaded_data_iter)

# Quick sanity check: how many records came back, and what the first one is.
print("length of records", len(records))
first_record = records[0]
print("first feature:", first_record.feature)
display(first_record, model.tokenizer, n=4)

Loading .model.layers.12: 814it [01:05, 12.36it/s]


length of records 623
first feature: .model.layers.12_feature0


In [41]:
client = OpenRouter('anthropic/claude-3.5-sonnet', api_key="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")
explainer = SimpleExplainer(
    client,
    model.tokenizer,
    max_new_tokens=50,
    temperature=0.0,
)

explainer_result = await explainer(records[3])
display(records[3], model.tokenizer)
print(explainer_result.explanation)

Neuron activates at the beginning of sentences in mathematical or scientific texts, potentially helping to structure technical content.


In [42]:
# Inspect another record (index 5) with a larger `n` than the earlier cells.
# NOTE(review): `n` presumably caps how many examples are shown - confirm
# against sae_auto_interp.utils.display.
display(
    records[5],
    model.tokenizer,
    n=10,
)