# Install

In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Fetch the InterpBench model files from the Hugging Face Hub. Git LFS is
# initialised first because the repository stores its large files through LFS.
!git lfs install
!git clone https://huggingface.co/cybershiptrooper/InterpBench

Updated git hooks.
Git LFS initialized.
Cloning into 'InterpBench'...
remote: Enumerating objects: 225, done.[K
remote: Counting objects: 100% (221/221), done.[K
remote: Compressing objects: 100% (205/205), done.[K
remote: Total 225 (delta 71), reused 0 (delta 0), pack-reused 4 (from 1)[K
Receiving objects: 100% (225/225), 385.53 KiB | 3.04 MiB/s, done.
Resolving deltas: 100% (71/71), done.
Filtering content: 100% (55/55), 82.13 MiB | 35.62 MiB/s, done.


# Imports and setup

In [3]:
import pickle
import torch
import os
import shutil
from transformer_lens import HookedTransformerConfig, HookedTransformer
from transformer_lens import HookedTransformer
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer

from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, decoders, trainers
from transformers import PreTrainedTokenizerFast

from functools import partial
import ipywidgets as widgets
from IPython.display import display
import pandas as pd

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
from sae_lens import SAEConfig, SAE, TrainingSAEConfig, TrainingSAE, ActivationsStore, CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig
from sae_lens.training.sae_trainer import SAETrainer
from sae_utils import make_gated_sae_lens_config, train_sae

import sae_lens
# Debugging aid: list the names exposed by the installed sae_lens trainer module.
print(dir(sae_lens.training.sae_trainer))

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print("Using device:", device)

# Turn off parallelism inside the Hugging Face `tokenizers` library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

['ActivationsStore', 'Adam', 'Any', 'FINETUNING_PARAMETERS', 'HookedRootModule', 'L1Scheduler', 'LanguageModelSAERunnerConfig', 'SAETrainer', 'TrainSAEOutput', 'TrainStepOutput', 'TrainingSAE', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', '__version__', '_log_feature_sparsity', '_update_sae_lens_training_version', 'cast', 'contextlib', 'dataclass', 'get_lr_scheduler', 'run_evals', 'torch', 'tqdm', 'wandb']
Using device: cuda


# Load and configure benchmark

In [5]:
# Benchmark case 3; `task` and `hl_model` are reused by later cells.
# NOTE(review): the saved output of this cell is a ModuleNotFoundError, yet later
# cells use `task` and `hl_model` successfully — the output is stale; re-run to refresh.
import circuits_benchmark.benchmark.cases.case_3 as case_3

task = case_3.Case3()
# High-level model of the case; used below to map tracr inputs to token ids.
hl_model = task.get_hl_model()

ModuleNotFoundError: No module named 'iit.model_pairs.ll_model'

In [6]:
from circuits_benchmark.transformers.tracr_circuits_builder import build_tracr_circuits
# Build the tracr circuit description at "acdc_hooks" granularity and print, for each
# high-level operation, the set of low-level hook points mapped to it.
tracr_output = task.get_tracr_output()
tracr_circuits = build_tracr_circuits(tracr_output.graph, tracr_output.craft_model, granularity="acdc_hooks")
for k, item in tracr_circuits.alignment.hl_to_ll_mapping.items():
    print(f'Operation {k} is handled in {item}')

Operation tokens is handled in {hook_embed}
Operation indices is handled in {blocks.1.attn.hook_k[0], blocks.1.attn.hook_q[0], hook_pos_embed}
Operation is_x_3 is handled in {blocks.0.hook_mlp_in, blocks.0.hook_mlp_out}
Operation frac_prevs_1 is handled in {blocks.1.hook_v_input[0], blocks.1.attn.hook_v[0], blocks.1.attn.hook_result[0]}
Operation frac_prevs_1 is handled in {blocks.1.hook_resid_post}
Operation select_2 is handled in {blocks.1.attn.hook_q[0], blocks.1.attn.hook_k[0]}


In [7]:
# Load the low-level model for this task from the cloned InterpBench repository.
task_idx = 3
dir_name = f"InterpBench/{task_idx}"

# SECURITY NOTE: unpickling can execute arbitrary code; only load files you trust
# (here: the InterpBench repository cloned at the top of the notebook).
# The context manager makes sure the file handle is closed again.
with open(f"{dir_name}/ll_model_cfg.pkl", "rb") as cfg_file:
    cfg_dict = pickle.load(cfg_file)
cfg = HookedTransformerConfig.from_dict(cfg_dict)
model = HookedTransformer(cfg)

# Load the checkpoint onto the CPU first so weights saved on a GPU can also be
# read on a machine without one; load_state_dict copies them into the model's parameters.
weights = torch.load(f"{dir_name}/ll_model.pth", map_location="cpu")
model.load_state_dict(weights)

<All keys matched successfully>

In [8]:
# Generate the task's input/output pairs between its minimum and maximum sequence length
# and show the first few of each.
# NOTE(review): `total_data` is not used by any visible cell.
total_data = task.get_total_data_len()
input_data, output_data = task.gen_all_data(task.get_min_seq_len(), task.get_max_seq_len())
print(len(input_data), len(output_data))
print(input_data[:4])
print(output_data[:4])


320 320
[['BOS', 'a', 'a', 'a', 'PAD'], ['BOS', 'a', 'a', 'b', 'PAD'], ['BOS', 'a', 'a', 'c', 'PAD'], ['BOS', 'a', 'a', 'x', 'PAD']]
[['BOS', 0.0, 0.0, 0.0, 'PAD'], ['BOS', 0.0, 0.0, 0.0, 'PAD'], ['BOS', 0.0, 0.0, 0.0, 'PAD'], ['BOS', 0.0, 0.0, 0.3333333333333333, 'PAD']]


In [9]:
# load high level model
# import circuits_benchmark.utils.iit.correspondence as correspondence
# from circuits_benchmark.utils.iit.dataset import get_unique_data
# import iit.model_pairs as mp


In [10]:
from datasets import Dataset
# Create dataset of case inputs
# dataset = get_unique_data(task, max_len=10_000)
# Map the string inputs to token ids with the high-level model's vocabulary mapping.
tokenized_data = hl_model.map_tracr_input_to_tl_input(input_data)
print(tokenized_data)

# Gather plain-Python columns: the raw string tokens, the token ids (tensor -> nested list),
# and the expected outputs stringified so every label has the same type.
string_tokens_list = input_data
tokens_list = tokenized_data.tolist()
labels_list = [str(label) for label in output_data]

# Create a dictionary from the lists
data_dict = {
    "string_tokens": string_tokens_list,
    "tokens": tokens_list,
    "labels": labels_list
}

# Create a Hugging Face dataset
hf_dataset = Dataset.from_dict(data_dict)

print(hf_dataset)
print(hf_dataset[3])

tensor([[0, 2, 2, 2, 1],
        [0, 2, 2, 3, 1],
        [0, 2, 2, 4, 1],
        ...,
        [0, 5, 5, 5, 3],
        [0, 5, 5, 5, 4],
        [0, 5, 5, 5, 5]])
Dataset({
    features: ['string_tokens', 'tokens', 'labels'],
    num_rows: 320
})
{'string_tokens': ['BOS', 'a', 'a', 'x', 'PAD'], 'tokens': [0, 2, 2, 5, 1], 'labels': "['BOS', 0.0, 0.0, 0.3333333333333333, 'PAD']"}


In [11]:
# Sanity check: (number of rows, number of columns) of the dataset.
print(hf_dataset.shape)

(320, 3)


In [12]:
# create tokenizer
# Define your simple vocabulary
# NOTE(review): id 1 is named 'UNK' here, but the task data above labels the token with id 1
# as 'PAD'; decoded strings therefore show 'UNK' where the task data says 'PAD'.
vocab = {'BOS': 0, 'UNK': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
# comes from task.get_vocab() and hl_model.map_tracr_input_to_tl_input

# Create a Tokenizer with a WordLevel model
tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token="UNK"))

# Set the normalizer and pre-tokenizer (no decoder is configured)
tokenizer.normalizer = normalizers.Sequence([normalizers.Lowercase(), normalizers.StripAccents()])
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

# Convert to Hugging Face tokenizer
hf_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)

# Add the special tokens to the Hugging Face tokenizer
# NOTE(review): '[CLS]', '[SEP]', '[PAD]' and '[MASK]' are not in `vocab`; presumably they receive
# new ids outside the 6-entry vocabulary — confirm they are never passed to the model.
hf_tokenizer.add_special_tokens({
    'unk_token': 'UNK',
    'bos_token': 'BOS',
    'cls_token': '[CLS]',
    'sep_token': '[SEP]',
    'pad_token': '[PAD]',
    'mask_token': '[MASK]'
})

# Test the tokenizer: round-trip a string containing every in-vocabulary symbol except UNK
encoded = hf_tokenizer.encode("BOS a b c x")
decoded = hf_tokenizer.decode(encoded)
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")

Encoded: [0, 2, 3, 4, 5]
Decoded: BOS a b c x


In [13]:
# Attach the tokenizer to the model; the activation store below reads `model.tokenizer`.
model.tokenizer = hf_tokenizer #attach to model.

In [14]:
# Cache the hook activations of the loaded low-level model on the whole dataset.
_, cache = model.run_with_cache(tokenized_data)
# Run the high-level model on the same inputs and compare against the expected outputs.
output = hl_model(tokenized_data) #TODO: why are these different calls?
print(output[:5], output_data[:5])

tensor([[[0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.3333],
         [0.2500]],

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000]]], device='cuda:0', grad_fn=<SliceBackward0>) [['BOS', 0.0, 0.0, 0.0, 'PAD'], ['BOS', 0.0, 0.0, 0.0, 'PAD'], ['BOS', 0.0, 0.0, 0.0, 'PAD'], ['BOS', 0.0, 0.0, 0.3333333333333333, 'PAD'], ['BOS', 0.0, 0.0, 0.0, 'PAD']]


In [15]:
# List the hook names available in the activation cache (candidate SAE hook points).
print(cache)

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.hook_mlp_in', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post

# SAE-lens

In [16]:
class RepeatActivationsStore(ActivationsStore):
    """ActivationsStore variant that loops over a finite dataset instead of stopping.

    Whenever the underlying iterator is exhausted it is re-created from the
    start, so more training tokens than the dataset contains can be drawn.
    The data is replayed in its existing order; nothing is reshuffled here.
    """

    def get_batch_tokens(self, batch_size: int | None = None):
        """Stream a batch of token sequences from the dataset.

        Args:
            batch_size: Number of sequences to draw. When falsy (``None`` or
                ``0``), ``self.store_batch_size_prompts`` is used instead.

        Returns:
            The drawn sequences stacked along dim 0 and moved to the device of
            ``self.model.W_E``.
        """
        if not batch_size:
            batch_size = self.store_batch_size_prompts
        sequences = []
        # The sequence iterator yields fully formed sequences, so they only need stacking.
        for _ in range(batch_size):
            try:
                sequences.append(next(self.iterable_sequences))
            except StopIteration:
                # Dataset exhausted: rebuild the sequence iterator and continue from the start.
                self.iterable_sequences = self._iterate_tokenized_sequences()
                sequences.append(next(self.iterable_sequences))

        return torch.stack(sequences, dim=0).to(self.model.W_E.device)

    def _next_dataset_entry(self):
        """Return the tokens column of the next dataset row, restarting the iterator when exhausted."""
        try:
            return next(self.iterable_dataset)[self.tokens_column]
        except StopIteration:
            # Start over from the first row (same order as before, no shuffle).
            self.iterable_dataset = iter(self.dataset)
            return next(self.iterable_dataset)[self.tokens_column]

    def _get_next_dataset_tokens(self) -> torch.Tensor:
        """Return the next dataset row as a token tensor on ``self.device``.

        Untokenized rows are tokenized with ``self.model.to_tokens``; tokenized
        rows are wrapped in a long tensor and have a leading BOS token removed
        when ``self.prepend_bos`` is false. Increments ``self.n_dataset_processed``.
        """
        device = self.device
        s = self._next_dataset_entry()
        if not self.is_dataset_tokenized:
            tokens = (
                self.model.to_tokens(
                    s,
                    truncate=False,
                    move_to_device=True,
                    prepend_bos=self.prepend_bos,
                )
                .squeeze(0)
                .to(device)
            )
            assert (
                len(tokens.shape) == 1
            ), f"tokens.shape should be 1D but was {tokens.shape}"
        else:
            tokens = torch.tensor(
                s,
                dtype=torch.long,
                device=device,
                requires_grad=False,
            )
            if (
                not self.prepend_bos
                and tokens[0] == self.model.tokenizer.bos_token_id  # type: ignore
            ):
                # The row already starts with BOS but the store is configured without it.
                tokens = tokens[1:]
        self.n_dataset_processed += 1
        return tokens

## Residual stream 0 -- tokens and position

In [17]:

def make_sae_lens_config(hook_name: str, hook_layer: int, l1_coeff: float, training_tokens: int = 1_500_000):
    """Build the SAE-lens runner config used for every SAE in this notebook.

    Reads the notebook globals ``model`` (for ``d_model``), ``device``,
    ``task_idx`` (for the checkpoint directory) and ``sae_lens`` (for the version).

    Args:
        hook_name: Name of the model hook whose activations the SAE is trained on.
        hook_layer: Layer index the hook belongs to.
        l1_coeff: Coefficient of the L1 sparsity penalty.
        training_tokens: Total number of tokens to train on.

    Returns:
        A ``LanguageModelSAERunnerConfig`` for a pre-tokenized dataset with context size 5.
    """
    # Expand $HOME explicitly: Python does not expand shell variables in paths, so the
    # unexpanded string would create a directory literally named "$HOME" in the working dir.
    checkpoint_path = os.path.expandvars(
        f"$HOME/persistent-storage/tracr_saes/task_{task_idx}_checkpoints"
    )

    return LanguageModelSAERunnerConfig(
        # Data Generating Function (Model + Training Distribution)
        model_name = "case3",
        model_class_name = "HookedTransformer",
        hook_name = hook_name,
        hook_eval = "NOT_IN_USE",
        hook_layer = hook_layer,
        hook_head_index = None,
        dataset_path = "",
        dataset_trust_remote_code = False,
        streaming = False,
        is_dataset_tokenized = True,
        context_size = 5,
        use_cached_activations = False,
        cached_activations_path = None,  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"

        # SAE Parameters
        d_in = model.cfg.d_model,
        d_sae = None,
        b_dec_init_method = "geometric_median",
        expansion_factor = 4,
        activation_fn = "relu",  # relu, tanh-relu
        normalize_sae_decoder = True,
        noise_scale = 0.0,
        from_pretrained_path = None,
        apply_b_dec_to_input = False,
        decoder_orthogonal_init = False,
        decoder_heuristic_init = False,
        init_encoder_as_decoder_transpose = False,

        # Activation Store Parameters
        training_tokens = training_tokens,
        finetuning_tokens = 0,
        store_batch_size_prompts = 4,
        normalize_activations = "none",  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)

        # Misc
        device = device,
        act_store_device = "with_model",  # will be set by post init if with_model
        seed = 42,
        dtype = "float32",  # type: ignore #
        prepend_bos = False,

        # Performance - see compilation section of lm_runner.py for info
        autocast = False,  # autocast to autocast_dtype during training
        autocast_lm = False,  # autocast lm during activation fetching
        compile_llm = False,  # use torch.compile on the LLM
        llm_compilation_mode = None,  # which torch.compile mode to use
        compile_sae = False,  # use torch.compile on the SAE
        sae_compilation_mode = None,

        # Training Parameters

        ## Batch size
        train_batch_size_tokens = 320//4,

        ## Adam
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,

        ## Loss Function
        mse_loss_normalization = None,
        l1_coefficient = l1_coeff,
        lp_norm = 1,
        scale_sparsity_penalty_by_decoder_norm = False,
        l1_warm_up_steps = 0,

        ## Learning Rate Schedule
        lr = 3e-4,
        lr_scheduler_name = "constant",  # constant, cosineannealing, cosineannealingwarmrestarts
        lr_warm_up_steps = 0,
        lr_end = None,  # only used for cosine annealing, default is lr / 10
        lr_decay_steps = 0,
        n_restart_cycles = 1,  # used only for cosineannealingwarmrestarts

        ## FineTuning
        finetuning_method = None,  # scale, decoder or unrotated_decoder

        # Resampling protocol args
        use_ghost_grads = True,  # ghost grads are enabled for every run in this notebook
        feature_sampling_window = 2000,
        dead_feature_window = 1000,  # measured in training steps, like feature_sampling_window
        dead_feature_threshold = 1e-8,

        # Evals
        n_eval_batches = 10,
        eval_batch_size_prompts = None,  # useful if evals cause OOM

        # WANDB
        log_to_wandb = True,
        log_activations_store_to_wandb = False,
        log_optimizer_state_to_wandb = False,
        wandb_project = "benchmark_saes",
        wandb_id = None,
        run_name = None,
        wandb_entity = None,
        wandb_log_frequency = 10,
        eval_every_n_wandb_logs = 100000000000, # Make this a really big number; currently fails because it tries to compute CE loss.
        # Misc
        resume = False,
        n_checkpoints = 5,
        checkpoint_path = checkpoint_path,
        verbose = True,
        model_kwargs = dict(),
        model_from_pretrained_kwargs = dict(),
        sae_lens_version = str(sae_lens.__version__),
        sae_lens_training_version = str(sae_lens.__version__),
    )

# Config for an SAE on the residual stream entering block 0 (token + positional embeddings).
runner_cfg = make_sae_lens_config( "blocks.0.hook_resid_pre", 0, l1_coeff=1e-1, training_tokens=1_500_000)

Run name: 48-L1-0.1-LR-0.0003-Tokens-1.500e+06
n_tokens_per_buffer (millions): 0.0004
Lower bound: n_contexts_per_buffer (millions): 8e-05
Total training steps: 18750
Total wandb updates: 1875
n_tokens_per_feature_sampling_window (millions): 0.8
n_tokens_per_dead_feature_window (millions): 0.4
We will reset the sparsity calculation 9 times.
Number tokens in sparsity calculation window: 1.60e+05
Using Ghost Grads.


In [18]:
# Imported here so this cell also runs on a fresh kernel: the notebook's only other
# `import wandb` sits in the later "Loading from wandb" section.
import wandb

model.tokenizer = hf_tokenizer
# Activation store that replays the 320-row dataset for as long as training needs tokens.
store = RepeatActivationsStore.from_config(model, runner_cfg, dataset=hf_dataset)
sae = TrainingSAE(runner_cfg)
# NOTE(review): `save_checkpoint` is not defined in any visible cell of this notebook;
# it must exist in the kernel before this cell runs — confirm where it comes from.
trainer = SAETrainer(model, sae, store, save_checkpoint, cfg = runner_cfg)

if runner_cfg.log_to_wandb:
    wandb.init(
        project=runner_cfg.wandb_project,
        config=runner_cfg,
        name=runner_cfg.run_name,
        id=runner_cfg.wandb_id,
    )
try:
    trainer.fit()
finally:
    # Close the wandb run even if training fails or is interrupted, so a re-run
    # of this cell starts a clean run.
    wandb.finish()

## Inference

In [29]:
def get_top_k(activations, k=20):
    """Find the k largest activations of every feature over all token positions.

    Args:
        activations: Tensor of shape (batch, ctx, feat). It may be non-contiguous.
        k: Number of top entries to return per feature. Clamped to ``batch * ctx``
            so small inputs do not make ``torch.topk`` fail.

    Returns:
        Tuple ``(top_values, batch_indices, ctx_indices)`` of CPU tensors, each of
        shape ``(k, feat)``. Row ``i`` holds, for every feature, its i-th largest
        activation and the batch / context position where it occurred.
    """
    batch, ctx, feat = activations.shape
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed tensors.
    flat = activations.reshape(batch * ctx, feat)

    # Never request more rows than exist.
    k = min(k, flat.shape[0])

    # Top-k along the flattened (batch * ctx) axis, independently for each feature.
    top = torch.topk(flat, k=k, dim=0)

    # A flat row index r corresponds to batch r // ctx and context position r % ctx.
    original_batch_indices = top.indices // ctx
    original_ctx_indices = top.indices % ctx

    return top.values.cpu(), original_batch_indices.cpu(), original_ctx_indices.cpu()


# Function to update the DataFrame based on the selected feature
def update_dataframe(feat, top_values, batch_indices, ctx_indices, live_features):
    """Display the top-activating examples of one live SAE feature as a table.

    Reads the notebook globals ``tokens`` (token ids of the dataset) and
    ``hf_tokenizer`` (used to decode them).

    Args:
        feat: Position into ``live_features`` (as chosen in the dropdown).
        top_values: Top activations per feature, one column per feature.
        batch_indices: Dataset row of each entry in ``top_values``.
        ctx_indices: Context position of each entry in ``top_values``.
        live_features: Indices of the features selectable in the dropdown.
    """
    # Translate the dropdown position into the actual feature index.
    feat = live_features[feat]

    print(f'feature {feat}')

    activations = top_values[:, feat]
    best_tokens = tokens[batch_indices[:, feat]]
    active_idx = ctx_indices[:, feat]
    # Decode every top-activating sequence in full. The number of rows follows
    # `top_values`, so any k passed to the top-k step works (no fixed 20).
    full_strs = [hf_tokenizer.decode(t) for t in best_tokens]

    info = [
        {
            'tokens': toks,
            'activation': act.item(),
            'tok_idx': idx.item(),
        }
        for act, idx, toks in zip(activations, active_idx, full_strs)
    ]

    # Create DataFrame
    df = pd.DataFrame(info)
    display(df.head(10))

In [27]:
# Build the id tensor directly as int64; torch.Tensor(...) would create float32 values
# first and only then cast them to integers.
tokens = torch.tensor(hf_dataset['tokens'], dtype=torch.long)
# Cache all hook activations for the full dataset; the sequences already start with BOS.
logits, cache = model.run_with_cache(tokens, prepend_bos=False)

In [97]:
# Example usage
activations = sae.encode(cache[sae.cfg.hook_name])
top_values, batch_indices, ctx_indices = get_top_k(activations)
print("Top values shape:", top_values.shape)
print("Batch indices shape:", batch_indices.shape)
print("Context indices shape:", ctx_indices.shape)

print(activations.shape)
# print(batch_indices)

Top values shape: torch.Size([20, 48])
Batch indices shape: torch.Size([20, 48])
Context indices shape: torch.Size([20, 48])
torch.Size([320, 5, 48])


In [110]:

# A feature counts as "live" when its largest activation (row 0 of the top-k table) is non-zero.
live_features = top_values[0,:].nonzero().flatten()
print(live_features.shape[0])

# Define the dropdown menu for 'feat'
# The options are positions into `live_features`, not raw feature ids;
# update_dataframe maps them back to feature ids.
feat_dropdown = widgets.Dropdown(
    options=range(live_features.shape[0]),
    value=0,
    description='Feature:',
)


# Bind everything except the dropdown value so the widget only has to supply `feat`.
dataframe_func = partial(
    update_dataframe, 
    top_values=top_values, 
    batch_indices=batch_indices, 
    ctx_indices=ctx_indices,
    live_features=live_features
)

# Create an interactive output widget
# NOTE(review): this rebinds `output`, which earlier held the high-level model's output.
output = widgets.interactive_output(
    dataframe_func, 
    {
        'feat': feat_dropdown,
    }
)

# Display the dropdown menu and output
display(feat_dropdown, output)

21


Dropdown(description='Feature:', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19…

Output()

For model https://wandb.ai/evanhanders/benchmark_saes/artifacts/model/sae_case3_blocks.0.hook_resid_pre_48/v5

* Feature 0: x at idx=1
* 2: x at idx=1
* 3: UNK at idx=4
* 4: a at idx=3
* 6: BOS
* 7: c at idx=4
* 14: x at idx=3
* 15: UNK at idx=4
* 16: b at idx=2
* 17: x at idx=4
* 23: b at idx=3
* 26: x at idx=3
* 28: x at idx=2
* 30: c at idx=3
* 35: a at idx=1
* 38: c at idx=2
* 39: x at idx=4
* 41: x at idx=1
* 42: x at idx=2
* 46: b at idx=4
* 47: BOS

There are a lot of double-features:
* 2 & 41 (x @ 1)
* 3 & 15 (UNK @ 4)
* 6 & 47 (BOS)
* 14 & 26 (x @ 3)
* 17 & 39 (x @ 4)
* 28 & 42 (x @ 2)

Then there are these single features, ordered by letter:
* 35: a @ 1
* 16: b @ 2
* 23: b @ 3
* 46: b @ 4
* 38: c @ 2
* 30: c @ 3
* 7: c @ 4

## Block 1 attention output

In [32]:
# Config for an SAE on the attention output of block 1.
hook_name = "blocks.1.hook_attn_out"
layer = 1
runner_cfg = make_sae_lens_config( hook_name, layer, l1_coeff=1e-1, training_tokens=1_500_000)

Run name: 48-L1-0.1-LR-0.0003-Tokens-1.500e+06
n_tokens_per_buffer (millions): 0.0004
Lower bound: n_contexts_per_buffer (millions): 8e-05
Total training steps: 18750
Total wandb updates: 1875
n_tokens_per_feature_sampling_window (millions): 0.8
n_tokens_per_dead_feature_window (millions): 0.4
We will reset the sparsity calculation 9 times.
Number tokens in sparsity calculation window: 1.60e+05
Using Ghost Grads.


In [33]:
# Imported here so this cell also runs on a fresh kernel: the notebook's only other
# `import wandb` sits in the later "Loading from wandb" section.
import wandb

model.tokenizer = hf_tokenizer
# Activation store that replays the 320-row dataset for as long as training needs tokens.
store = RepeatActivationsStore.from_config(model, runner_cfg, dataset=hf_dataset)
sae = TrainingSAE(runner_cfg)
# NOTE(review): `save_checkpoint` is not defined in any visible cell of this notebook;
# it must exist in the kernel before this cell runs — confirm where it comes from.
trainer = SAETrainer(model, sae, store, save_checkpoint, cfg = runner_cfg)

if runner_cfg.log_to_wandb:
    wandb.init(
        project=runner_cfg.wandb_project,
        config=runner_cfg,
        name=runner_cfg.run_name,
        id=runner_cfg.wandb_id,
    )
try:
    trainer.fit()
finally:
    # Close the wandb run even if training fails or is interrupted, so a re-run
    # of this cell starts a clean run.
    wandb.finish()

  yield torch.tensor(
3700| MSE Loss 0.007 | L1 0.150:  20%|████████                                 | 296000/1500000 [01:06<04:32, 4414.97it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/xfae89eo/300080


7500| MSE Loss 0.004 | L1 0.124:  40%|████████████████▍                        | 600000/1500000 [02:14<03:19, 4520.27it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/xfae89eo/600080


11200| MSE Loss 0.005 | L1 0.180:  60%|███████████████████████▉                | 896000/1500000 [03:21<02:18, 4354.29it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/xfae89eo/900080


15000| MSE Loss 0.006 | L1 0.175:  80%|███████████████████████████████▏       | 1200000/1500000 [04:28<01:05, 4595.00it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/xfae89eo/1200080


18700| MSE Loss 0.003 | L1 0.143: 100%|██████████████████████████████████████▉| 1496000/1500000 [05:36<00:00, 4512.89it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/xfae89eo/final_1500000


18700| MSE Loss 0.003 | L1 0.143: 100%|██████████████████████████████████████▉| 1496000/1500000 [05:37<00:00, 4434.40it/s]


VBox(children=(Label(value='0.050 MB of 0.050 MB uploaded (0.009 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
details/current_l1_coefficient,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/auxiliary_reconstruction_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/ghost_grad_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▄▂▃▂▂▂▂▂▁▂▂▂▁▂▁▁▂▁▁▂▂▁▂▂▂▂▂▁▂▂▁▂▂▂▂▂▁▂▂
losses/mse_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/explained_variance,▁█████▇█████████████████████████████████
metrics/explained_variance_std,█▂▁▂▂▁▂▂▂▁▁▂▂▁▂▁▁▁▁▁▂▁▁▁▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁

0,1
details/current_l1_coefficient,0.1
details/current_learning_rate,0.0003
details/n_training_tokens,1500000.0
losses/auxiliary_reconstruction_loss,0.0
losses/ghost_grad_loss,0.00034
losses/l1_loss,1.42284
losses/mse_loss,0.00425
losses/overall_loss,0.14687
metrics/explained_variance,0.99573
metrics/explained_variance_std,0.00645


## Inference

In [35]:
# Example usage
activations = sae.encode(cache[sae.cfg.hook_name])
top_values, batch_indices, ctx_indices = get_top_k(activations)
print("Top values shape:", top_values.shape)
print("Batch indices shape:", batch_indices.shape)
print("Context indices shape:", ctx_indices.shape)

print(activations.shape)
# print(batch_indices)

Top values shape: torch.Size([20, 48])
Batch indices shape: torch.Size([20, 48])
Context indices shape: torch.Size([20, 48])
torch.Size([320, 5, 48])


In [36]:
# A feature counts as "live" when its largest activation (row 0 of the top-k table) is non-zero.
live_features = top_values[0,:].nonzero().flatten()
print(f'Live features: {live_features.shape[0]}')

# Define the dropdown menu for 'feat'
# The options are positions into `live_features`, not raw feature ids;
# update_dataframe maps them back to feature ids.
feat_dropdown = widgets.Dropdown(
    options=range(live_features.shape[0]),
    value=0,
    description='Feature:',
)

# Bind everything except the dropdown value so the widget only has to supply `feat`.
dataframe_func = partial(
    update_dataframe, 
    top_values=top_values, 
    batch_indices=batch_indices, 
    ctx_indices=ctx_indices,
    live_features=live_features
)

# Create an interactive output widget
output = widgets.interactive_output(
    dataframe_func, 
    {
        'feat': feat_dropdown,
    }
)

# Display the dropdown menu and output
display(feat_dropdown, output)

Live features: 20


Dropdown(description='Feature:', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19…

Output()

UNK / BOS:

* 30: Fires on UNK
* 32: Fires on UNK

On x:

* 7: x occurring after a bunch of other x's.
* 8: x @ 1.
* 9: x @ 2 with x @ 1.
* 23: x @ 4 when context starts with c or a?
* 25: x @ 1.
* 33: Fires at the end of strings of x's
* 44: Fires on x @ 1
* 45: Mostly fires on the end of strings of 3 x's?
* 47: fires on c @ 4 when there are no x's in context.

On not-x:

* 3: Non-x after full string of x's.
* 5: Non-x after full string of non-x.
* 13: a @ 1
* 18: a @ 1
* 26: Fires on not-x when the beginning of the sequence is all x's.
* 29: Fires on not-x @ 4 when seq starts (not-x) x x
* 39: just like 29.
* 41: Fires on a @ 1

Other:

* 22: Fires near the end of sequences like (non-x) x x

## Block 0 MLP

In [19]:
# Config for an SAE on the MLP output of block 0.
hook_name = "blocks.0.hook_mlp_out"
runner_cfg = make_sae_lens_config( hook_name, 0, l1_coeff=1e-1, training_tokens=1_500_000)

Run name: 48-L1-0.1-LR-0.0003-Tokens-1.500e+06
n_tokens_per_buffer (millions): 0.0004
Lower bound: n_contexts_per_buffer (millions): 8e-05
Total training steps: 18750
Total wandb updates: 1875
n_tokens_per_feature_sampling_window (millions): 0.8
n_tokens_per_dead_feature_window (millions): 0.4
We will reset the sparsity calculation 9 times.
Number tokens in sparsity calculation window: 1.60e+05
Using Ghost Grads.


In [21]:
# Imported here so this cell also runs on a fresh kernel: the notebook's only other
# `import wandb` sits in the later "Loading from wandb" section.
import wandb

model.tokenizer = hf_tokenizer
# Activation store that replays the 320-row dataset for as long as training needs tokens.
store = RepeatActivationsStore.from_config(model, runner_cfg, dataset=hf_dataset)
sae = TrainingSAE(runner_cfg)
# NOTE(review): `save_checkpoint` is not defined in any visible cell of this notebook;
# it must exist in the kernel before this cell runs — confirm where it comes from.
trainer = SAETrainer(model, sae, store, save_checkpoint, cfg = runner_cfg)

if runner_cfg.log_to_wandb:
    wandb.init(
        project=runner_cfg.wandb_project,
        config=runner_cfg,
        name=runner_cfg.run_name,
        id=runner_cfg.wandb_id,
    )
try:
    trainer.fit()
finally:
    # Close the wandb run even if training fails or is interrupted, so a re-run
    # of this cell starts a clean run.
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mevanhanders[0m. Use [1m`wandb login --relogin`[0m to force relogin


  yield torch.tensor(
3700| MSE Loss 0.009 | L1 0.092:  20%|████████                                 | 296000/1500000 [00:51<03:32, 5661.09it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/7dwx0lya/300080


7500| MSE Loss 0.005 | L1 0.090:  40%|████████████████▍                        | 600000/1500000 [01:43<02:34, 5817.62it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/7dwx0lya/600080


11200| MSE Loss 0.005 | L1 0.086:  60%|███████████████████████▉                | 896000/1500000 [02:34<01:45, 5730.62it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/7dwx0lya/900080


15000| MSE Loss 0.005 | L1 0.068:  80%|███████████████████████████████▏       | 1200000/1500000 [03:28<00:50, 5991.75it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/7dwx0lya/1200080


18700| MSE Loss 0.005 | L1 0.065: 100%|██████████████████████████████████████▉| 1496000/1500000 [04:20<00:00, 5750.64it/s]

saving $HOME/persistent-storage/tracr_saes/task_3_checkpoints/7dwx0lya/final_1500000


18700| MSE Loss 0.005 | L1 0.065: 100%|██████████████████████████████████████▉| 1496000/1500000 [04:21<00:00, 5717.20it/s]


VBox(children=(Label(value='0.050 MB of 0.050 MB uploaded (0.009 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
details/current_l1_coefficient,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/auxiliary_reconstruction_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/ghost_grad_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▄▃▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂▂▁▂▂▂▂▂▁▂▂▁▂▂▂▂▂▂▂▂▂▂▂
losses/mse_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▂▂▁▂▁▁▂▁▂▂▂▁▂▂▁▁▂▁▁
metrics/explained_variance,▁▆▇█████████████████████████████████████
metrics/explained_variance_std,█▅▃▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁

0,1
details/current_l1_coefficient,0.1
details/current_learning_rate,0.0003
details/n_training_tokens,1500000.0
losses/auxiliary_reconstruction_loss,0.0
losses/ghost_grad_loss,0.00033
losses/l1_loss,0.78859
losses/mse_loss,0.00398
losses/overall_loss,0.08316
metrics/explained_variance,0.99307
metrics/explained_variance_std,0.00353


## Inference

In [23]:
# Example usage
activations = sae.encode(cache[sae.cfg.hook_name])
top_values, batch_indices, ctx_indices = get_top_k(activations)
print("Top values shape:", top_values.shape)
print("Batch indices shape:", batch_indices.shape)
print("Context indices shape:", ctx_indices.shape)

print(activations.shape)
# print(batch_indices)

Top values shape: torch.Size([20, 48])
Batch indices shape: torch.Size([20, 48])
Context indices shape: torch.Size([20, 48])
torch.Size([320, 5, 48])


In [31]:
# A feature counts as "live" when its largest activation (row 0 of the top-k table) is non-zero.
live_features = top_values[0,:].nonzero().flatten()
print(live_features.shape[0])

# Define the dropdown menu for 'feat'
# The options are positions into `live_features`, not raw feature ids;
# update_dataframe maps them back to feature ids.
feat_dropdown = widgets.Dropdown(
    options=range(live_features.shape[0]),
    value=0,
    description='Feature:',
)

# Bind everything except the dropdown value so the widget only has to supply `feat`.
dataframe_func = partial(
    update_dataframe, 
    top_values=top_values, 
    batch_indices=batch_indices, 
    ctx_indices=ctx_indices,
    live_features=live_features
)

# Create an interactive output widget
output = widgets.interactive_output(
    dataframe_func, 
    {
        'feat': feat_dropdown,
    }
)

# Display the dropdown menu and output
display(feat_dropdown, output)

18


Dropdown(description='Feature:', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), value…

Output()

UNK / BOS:

* 0: UNK on 4
* 10: UNK on 4
* 22: UNK on 4
* 25: BOS

Counting x:

* 13: x on 1
* 36: x on 1
* 38: x on 1
* 30: x on 3 (no x before)
* 33: x on 2 (with x on 1)
* 39: x on 2 (with x on 1)
* 11: x on 4 (fires more strongly for more x's early in context?)

Not counting x:

* 3: b on 3 (no x in 1-3)
* 4: c on 3 (no x in 1-3)
* 17: c on 2 (no x in 1)
* 19: b on 1
* 23: b on 1
* 35: c on 4 (with no x before)
* 37: b or a on 4 (with many x's before)

# Loading from wandb

In [25]:
import wandb

# Initialize the wandb API
api = wandb.Api()



# Latest SAE artifact trained on blocks.0.hook_mlp_out
artifact = api.artifact('evanhanders/benchmark_saes/sae_case3_blocks.0.hook_mlp_out_48:latest')

# Download the artifact to a specified directory; download() returns the local path
artifact_mlp_out = artifact.download("./mlp_out_0")


# Latest SAE artifact trained on blocks.1.hook_resid_pre
artifact = api.artifact('evanhanders/benchmark_saes/sae_case3_blocks.1.hook_resid_pre_48:latest')

# Download the artifact to a specified directory
artifact_residual_1 = artifact.download("./hook_resid_pre_1")

[34m[1mwandb[0m:   2 of 2 files downloaded.  
[34m[1mwandb[0m:   2 of 2 files downloaded.  


In [29]:
from sae_lens import SAE

# Load both downloaded SAEs from their local artifact directories onto the compute device.
sae_mlp = SAE.load_from_pretrained(artifact_mlp_out, device=device)
sae_residual = SAE.load_from_pretrained(artifact_residual_1, device=device)

In [43]:
# The dataset sequences already begin with BOS (id 0), so turn off BOS prepending on the model config.
model.cfg.tokenizer_prepends_bos=False
model.cfg.default_prepend_bos=False

In [49]:
# Build the id tensor directly as int64; torch.Tensor(...) would create float32 values
# first and only then cast them to integers.
tokens = torch.tensor(hf_dataset['tokens'], dtype=torch.long)
# Cache all hook activations for the full dataset; the sequences already start with BOS.
logits, cache = model.run_with_cache(tokens, prepend_bos=False)

In [63]:
# Feature activations of each loaded SAE, read from the cache at that SAE's own hook point.
activations_mlp = sae_mlp.encode(cache[sae_mlp.cfg.hook_name])
activations_resid_1 = sae_residual.encode(cache[sae_residual.cfg.hook_name])

MLP-0 features:

* 1: on tok 4 when 1-3 are all x
* 2: weak 'b' on tok 1
* 6: 'x' on pos 2
* 7: very weak, pos 4, when there are no x's?
* 12: 'x' on pos 1
* 16: weak 'b' on pos 1
* 18: BOS
* 19: strong fire on x on tok 4
* 20: weak 'c' on tok 3
* 21: BOS
* 22: weak 'c' on tok 2
* 23: weak 'x' on pos 4 or 2
* 26: BOS
* 28: BOS
* 30: weak 'c' on tok 3
* 33: 'x' on tok 3
* 34: 'x' on tok 1
* 39: 'b' on tok 3
* 40: 'x' on tok 2
* 42: 'x' on tok 3
* 43: BOS
* 46: 'c' on tok 1
