## A study of SAE features

Problem description: indirect object identification

- Writing functions for attention SAEs to see what source token contributes to the activation

## Setup and imports

### Imports

Directly from the tutorial

In [1]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
# from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

# Select the compute device: "mps" if available, otherwise "cuda" if available, otherwise "cpu".
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

# There's a single utils & tests file for both parts 3.1 & 3.2
# import part31_superposition_and_saes.utils as utils
# from plotly_utils import imshow, line

### Helper functions

In [4]:
import inspect

def s(tensor):
    """
    Print the shape of `tensor`, labelled with the expression it was called with.

    The label is recovered by inspecting the caller's source line and extracting the argument text of
    the `s(...)` call (nested parentheses are handled). If the source line is unavailable (for example
    code run through `exec`) or no call can be located, the label falls back to "tensor".

    Args:
        tensor: A PyTorch tensor or any object with a .shape attribute

    Example:
        attnout = torch.randn(32, 768)
        s(attnout)  # Output: Shape of [attnout]: torch.Size([32, 768])
    """
    import re

    var_name = "tensor"  # fallback label when the call expression cannot be recovered

    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        # code_context is None when the caller's source cannot be found
        code_context = inspect.getframeinfo(caller).code_context if caller is not None else None
        calling_line = code_context[0].strip() if code_context else ""

        # The lookbehind stops names such as `prints(` or `obj.s(` from being mistaken for this helper
        call = re.search(r"(?<![\w.])s\(", calling_line)
        if call:
            # Walk to the parenthesis that closes the call so arguments like `x.sum(0)` stay intact
            depth = 1
            for pos in range(call.end(), len(calling_line)):
                char = calling_line[pos]
                if char == "(":
                    depth += 1
                elif char == ")":
                    depth -= 1
                    if depth == 0:
                        extracted = calling_line[call.end():pos].strip()
                        if extracted:
                            var_name = extracted
                        break
    finally:
        # Drop the frame references explicitly to avoid keeping reference cycles alive
        del frame, caller

    if hasattr(tensor, 'shape'):
        print(f"Shape of [{var_name}]: {tensor.shape}")
    else:
        print(f"{var_name} has no shape attribute. Type: {type(tensor)}")


### Load in the language model

We load GPT-2-small and Gemma-2-2B, together with their corresponding pretrained SAEs. Specifically, for Gemma we use the following SAE release:

```python
gemmascope_sae_release = "gemma-scope-2b-pt-res-canonical"
gemmascope_sae_id = "layer_20/width_16k/canonical"
gemma_2_2b_sae = SAE.from_pretrained(gemmascope_sae_release, gemmascope_sae_id, device=str(device))[0]
```

Also expand to see detailed architectural info on `gemma_2_2b`.

In [5]:
# Disable gradient tracking globally; nothing in this notebook calls backward().
t.set_grad_enabled(False)

# Load GPT-2 small as a HookedSAETransformer on the selected device.
gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# Load the SAE with id "blocks.7.hook_resid_pre" from the "gpt2-small-res-jb" release.
# The call is unpacked into the SAE itself, a config dict and a sparsity object.
# NOTE(review): the recorded cell output suggests loading the model with
# `from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)` for this SAE — confirm.
gpt2_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=str(device),
)

# print(tabulate(gpt2_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

Loaded pretrained model gpt2-small into HookedTransformer




This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)



In [4]:
## the task of arithmetic does not work well for gpt2_small. Switching to gemma
HUGGINGFACE_KEY = 0
import os
os.environ["HUGGINGFACE_KEY"] = HUGGINGFACE_KEY

USING_GEMMA = os.environ.get("HUGGINGFACE_KEY") is not None

if USING_GEMMA:
    !huggingface-cli login --token $HUGGINGFACE_KEY
    gemma_2_2b = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)

    gemmascope_sae_release = "gemma-scope-2b-pt-res-canonical"
    gemmascope_sae_id = "layer_20/width_16k/canonical"
    gemma_2_2b_sae = SAE.from_pretrained(gemmascope_sae_release, gemmascope_sae_id, device=str(device))[0]
else:
    print("Please supply your Hugging Face API key before running this cell")

TypeError: str expected, not int

In [7]:
## Try loading SAEs at other layers as well?
# Print every attribute of the Gemma model config as a two-column (name, value) table.
print(tabulate(gemma_2_2b.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

┌────────────────────────────────────┬─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ name                               │ value                                                                                                                                                                                                                                                                                                                                                                                                           │
├────────────────────────────────────┼────────────────────────────────────────────────────────────────────────────────

In [9]:
# Print every attribute of the Gemma SAE config as a two-column (name, value) table.
print(tabulate(gemma_2_2b_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

┌──────────────────────────────┬──────────────────────────────────┐
│ name                         │ value                            │
├──────────────────────────────┼──────────────────────────────────┤
│ architecture                 │ jumprelu                         │
│ d_in                         │ 2304                             │
│ d_sae                        │ 16384                            │
│ activation_fn_str            │ relu                             │
│ apply_b_dec_to_input         │ False                            │
│ finetuning_scaling_factor    │ False                            │
│ context_size                 │ 1024                             │
│ model_name                   │ gemma-2-2b                       │
│ hook_name                    │ blocks.20.hook_resid_post        │
│ hook_layer                   │ 20                               │
│ hook_head_index              │                                  │
│ prepend_bos                  │ True           

## Check performance

### Generate the test dataset

The prompts are stored in the file `arithmetic_questions.txt`, and the answers are returned in `ans_list` as a list of `AnsConfig` namedtuples. To generate prompts, use

```python
q_list, a_list = prompt_generator(with_instructions=False, n_batch=10)
# gemma_2_2b.generate(q_list[3], max_new_tokens=10, do_sample=False)
qs_tkns = gemma_2_2b.to_tokens(q_list, padding_side="left")
ans_tkns = gemma_2_2b.generate(qs_tkns, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
```

In [13]:
# Number of tokens generated per prompt; later cells also use it as the width of the trailing
# token window that is decoded and analysed.
MAX_NEW_TOKENS = 4

In [14]:
from collections import namedtuple

# One arithmetic problem: the two operands, the operation word used in the prompt, and the correct result.
AnsConfig = namedtuple("AnsConfig", ["a", "b", "operation", "ans"])

def prompt_generator(
    n_range: int = 100,
    op: list[str] | None = None,
    n_batch: int = 100,
    return_type: Literal["string", "token"] = "string",
    write_to_file: bool = False,
    file_path: str = "addition_prompts.txt",
    with_instructions: bool = False,
) -> tuple[list[str], list[AnsConfig]]:
    """
    Generates a list of arithmetic questions and their answers.

    Args:
        n_range:            Operands are drawn uniformly from [0, n_range).
        op:                 Operation words to sample from; supported values are "plus", "minus" and
                            "times". Defaults to ["plus"].
        n_batch:            Number of prompts to generate.
        return_type:        Currently unused; prompts are always returned as strings.
        write_to_file:      If True, the prompts are written to `file_path`, one per line.
        file_path:          Destination file, only touched when `write_to_file` is True.
        with_instructions:  If True, each prompt is preceded by one solved example using the same
                            operation, e.g. "3 plus 5 is 8, 4 plus 6 is".

    Returns:
        (q_list, ans_list): the prompt strings and, for each prompt, an `AnsConfig` holding the
        operands, the operation and the correct answer (as Python ints).

    Raises:
        ValueError: if `op` contains an unsupported operation.
    """
    supported_ops = {
        "plus": lambda x, y: x + y,
        "minus": lambda x, y: x - y,
        "times": lambda x, y: x * y,
    }

    ops = ["plus"] if op is None else list(op)
    unsupported = [name for name in ops if name not in supported_ops]
    if unsupported:
        # Fail loudly instead of hitting an unbound (or stale) `answer` further down.
        raise ValueError(f"Unsupported operation(s) {unsupported}; choose from {sorted(supported_ops)}")

    # The four draws are always made, in this order, so seeded runs give the same operands
    # whether or not `with_instructions` is set.
    a = t.randint(0, n_range, (n_batch,)).tolist()
    b = t.randint(0, n_range, (n_batch,)).tolist()
    a_instr = t.randint(0, n_range, (n_batch,)).tolist()
    b_instr = t.randint(0, n_range, (n_batch,)).tolist()

    ans_list = []
    q_list = []

    for i in range(n_batch):
        operation = random.choice(ops)
        compute = supported_ops[operation]

        question = f"{a[i]} {operation} {b[i]} is"
        if with_instructions:
            # Prepend one solved example that uses the same operation.
            solved = f"{a_instr[i]} {operation} {b_instr[i]} is {compute(a_instr[i], b_instr[i])}"
            question = f"{solved}, {question}"
        q_list.append(question)

        # log the correct answer
        ans_list.append(
            AnsConfig(
                a=a[i],
                b=b[i],
                operation=operation,
                ans=compute(a[i], b[i]),
            )
        )

    # Only open (and therefore truncate) the file when the caller asked for it.
    if write_to_file:
        with open(file_path, "w") as f:
            f.writelines(question + "\n" for question in q_list)

    return q_list, ans_list

### Testing functionality
# Generate 10 prompts without a solved example, together with their expected answers.
q_list, a_list = prompt_generator(with_instructions=False, n_batch=10)
# gemma_2_2b.generate(q_list[3], max_new_tokens=10, do_sample=False)
# Tokenise with left padding, then generate MAX_NEW_TOKENS new tokens per prompt without sampling.
qs_tkns = gemma_2_2b.to_tokens(q_list, padding_side="left")
ans_tkns = gemma_2_2b.generate(qs_tkns, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)

  0%|          | 0/4 [00:00<?, ?it/s]

In [15]:
### ans = gemma_2_2b.to_str_tokens(ans_tkns)
### check the answer is correct
import re

# Decode only the last MAX_NEW_TOKENS tokens of every row, i.e. the generated continuation.
ans = gemma_2_2b.to_string(ans_tkns[:, -MAX_NEW_TOKENS:].tolist())

# A prediction counts as correct when the expected number appears as a whole number in the
# continuation. The digit lookarounds stop e.g. an expected "12" from matching inside "123" or "312",
# which a plain substring test would accept.
correct = [
    re.search(rf"(?<!\d){re.escape(str(expected.ans))}(?!\d)", generated) is not None
    for expected, generated in zip(a_list, ans)
]
correct = t.tensor(correct)
acc = correct.sum() / correct.shape[0]
print(f"Accuracy: {acc.item():.2%}")

Accuracy: 80.00%


In [None]:
# Comparing the accuracy of instructed vs non-instructed prompts:
# Draw two independent batches of 100 prompts each: one with a solved example prepended
# (`with_instructions=True`) and one without. This cell only generates the prompts.
qs_instr, a_list_instr = prompt_generator(with_instructions=True, n_batch=100)
qs_no_instr, a_list_no_instr = prompt_generator(with_instructions=False, n_batch=100)

### Examine the model's behaviour

#### GPT-2-small does not perform well

In [None]:
# One solved example followed by the question the model should complete.
prompt = (
    "3 plus 5 is 8, 4 plus 6 is"
)
# NOTE(review): the result of this forward pass is not assigned and is not the cell's last
# expression, so it is discarded — confirm whether the call is still needed.
gpt2(prompt, return_type="logits", prepend_bos=True)

# Sample a continuation of up to 20 new tokens (temperature 0.2, top_p 0.9).
gpt2.generate(
    prompt,
    temperature=0.2,
    top_p=0.9,
    do_sample=True,
    max_new_tokens=20,
)

In [None]:
# Read the stored prompts, one per line (readlines() keeps the trailing newline of each line).
with open("arithmetic_questions.txt") as f:
    question_lines = f.readlines()

# Tokenise all prompts and drop the last token column.
# NOTE(review): presumably this removes the trailing-newline token; if prompts tokenise to
# different lengths the last column may hold padding instead — confirm.
questions = gpt2.to_tokens(question_lines)[:, :-1]
print(questions.shape)

logits: Tensor = gpt2(questions, return_type="logits")

# Greedy next-token prediction at the final position of every prompt.
answer_logits = logits[:, -1, :].argmax(-1) # prediction of the model
answers = gpt2.to_str_tokens(answer_logits)
print(answers)

#### Gemma-2.2b performs with success probability close to 1

Observation: the tokenizer mostly treats each number as a single token, and none of the words in the prompt spans multiple tokens.

- Comparison: success probability with/without the previous prompt? (97% compared to 87% for 100 uniformly random prompts)

In [None]:
# Generate up to 100 new tokens without sampling for a single hand-written prompt and print the result.
prompt = "The answer to 9 plus 9 is"
str_ans = gemma_2_2b.generate(prompt, max_new_tokens=100, do_sample=False)
print(str_ans)

## Analysis with SAE

In [16]:
# find max-activating latents:

def get_topk_activated_latents(
    sae: SAE | None = None,
    k: int = 100,
    prompts: list[str] | Tensor | None = None, # use the answers to examine which latents are activated at near the answer tokens
    plot_hist: bool = False,
):
    """
    Finds, for every prompt, the `k` SAE latents with the largest mean pre-activation over the last
    MAX_NEW_TOKENS token positions.

    Args:
        sae:        SAE to attach to `gemma_2_2b`. Defaults to the notebook-level `gemma_2_2b_sae`.
        k:          Number of latents to keep per prompt.
        prompts:    Strings or a token tensor to run. Defaults to the notebook-level `ans_tkns`
                    (prompts plus generated answers).
        plot_hist:  If True, also show a histogram of the averaged latent activations.

    Returns:
        The result of `topk` over the latent dimension: `.values` and `.indices`, one row per prompt.

    The defaults are resolved when the function is called rather than when it is defined, so
    re-generating `ans_tkns` or re-loading the SAE does not leave this function using stale objects.
    """
    if sae is None:
        sae = gemma_2_2b_sae
    if prompts is None:
        prompts = ans_tkns

    _, cache = gemma_2_2b.run_with_cache_with_saes(
        prompts,
        saes=[sae],
        return_type="logits",
    )

    # NOTE(review): this reads `hook_sae_acts_pre`; confirm that ranking by pre-activation values
    # rather than `hook_sae_acts_post` is intended.
    latent_acts = cache[
        f"{sae.cfg.hook_name}.hook_sae_acts_pre"
    ][:, -MAX_NEW_TOKENS:, :].mean(1)

    top_activations = latent_acts.topk(k, dim=-1)

    if plot_hist:
        px.histogram(
            latent_acts.cpu().numpy(),
            title=f"Mean latent activations over the last {MAX_NEW_TOKENS} token positions",
            labels={"value": "Activation"},
            width=800,
        ).update_layout(showlegend=False).show()

    return top_activations

top_activations = get_topk_activated_latents(gemma_2_2b_sae, k=6, plot_hist=False)

In [29]:
# fetch the top latents that has been activated
# Count how often each latent index appears among the per-prompt top-k lists and report the most
# frequent ones. The printed label is derived from the same constant as the cut-off so the two
# cannot drift apart.
from collections import Counter

N_TOP_LATENTS = 5

top_act_inds = top_activations.indices.flatten().tolist()
occ = Counter(top_act_inds)
top_latents = occ.most_common(N_TOP_LATENTS)
top_3 = top_latents  # original variable name kept so that cells referring to `top_3` keep working
print(f"Top {N_TOP_LATENTS} most activated latents:", top_latents)

Top 3 most activated latents: [(12132, 9), (9768, 9), (8684, 6), (435, 5), (7900, 4)]


## Other functions

...that are not used at the current stage but might come in handy later; taken from the tutorial.

In [None]:
def get_k_largest_indices(
    x: Float[Tensor, "batch seq"],
    k: int,
    buffer: int = 0,
    no_overlap: bool = True,
) -> Int[Tensor, "k 2"]:
    """
    Locates the k largest entries of `x` and returns them as (batch, seqpos) index pairs, largest first.

    Args:
        buffer:     Positions closer than `buffer` to either end of a sequence are never selected, so
                    that every selected token has that much context on both sides.
        no_overlap: When True, a selected position suppresses all remaining candidates in the same
                    sequence that lie within `buffer` of it.
    """
    assert 2 * buffer < x.size(1), "Buffer is too large for the sequence length"
    assert (k <= x.size(0)) or (not no_overlap), "Not enough sequences to have a different token in each sequence"

    # Trim `buffer` positions from both ends; the column offset is added back below.
    trimmed = x[:, buffer:-buffer] if buffer > 0 else x
    width = trimmed.size(1)

    # Rank all entries of the trimmed tensor, best first, and convert flat positions to (row, col).
    order = trimmed.flatten().argsort(-1, descending=True)
    rows = order // width
    cols = order % width + buffer

    if not no_overlap:
        return t.stack((rows, cols), dim=1)[:k]

    picked = t.empty((0, 2), device=trimmed.device).long()
    while len(picked) < k:
        best_row, best_col = rows[0], cols[0]
        picked = t.cat((picked, t.tensor([[best_row, best_col]], device=trimmed.device)))
        # Keep candidates from other sequences, or from this one if further than `buffer` away.
        keep = (rows != best_row) | ((cols - best_col).abs() > buffer)
        rows, cols = rows[keep], cols[keep]
    return picked

def index_with_buffer(
    x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, "k *buffer_x2_plus1"]:
    """
    Indexes into `x` with `indices` (which should have come from the `get_k_largest_indices` function), and takes a
    +-buffer range around each indexed element. If `indices` are less than `buffer` away from the start of a sequence
    then we just take the first `2*buffer+1` elems (same for at the end of a sequence).

    If `buffer` is None, then we don't add any buffer and just return the elements at the given indices.

    `indices` is never modified: the window centres are clamped out of place, so callers can keep
    using their index tensor after this call.
    """
    rows, cols = indices.unbind(dim=-1)
    if buffer is None:
        return x[rows, cols]

    # `rows` / `cols` are views of `indices`; clamp() returns a new tensor instead of writing
    # through those views.
    centers = cols.clamp(min=buffer, max=x.size(1) - buffer - 1)
    offsets = t.arange(-buffer, buffer + 1, device=cols.device)
    window_cols = centers.unsqueeze(-1) + offsets  # (k, 2*buffer+1)
    # `rows` is broadcast along the window dimension.
    return x[rows.unsqueeze(-1), window_cols]


def show_top_logits(
    model: HookedSAETransformer,
    sae: SAE,
    latent_idx: int,
    k: int = 10,
) -> None:
    """
    Prints the k tokens whose logits are raised the most, and the k tokens whose logits are lowered
    the most, when the decoder direction of latent `latent_idx` is projected through `model.W_U`.
    """

    def print_ranked(ranked, tokens) -> None:
        # One line per entry: 1-based rank, token string, logit with two decimals.
        for rank, (value, tok) in enumerate(zip(ranked.values, tokens), start=1):
            print(f"{rank}. {tok} ({value:.2f})")

    with t.inference_mode():
        # in the SAE for gpt2 there's only one instance
        direction = sae.W_dec[latent_idx]
        # adding bias seems to ruin the model...
        logit_effect = (direction @ model.W_U).squeeze()

        highest = logit_effect.topk(k=k)
        lowest = logit_effect.topk(k=k, largest=False)

        highest_tokens = model.to_str_tokens(highest.indices)
        lowest_tokens = model.to_str_tokens(lowest.indices)

        print(f"Top {k} logits:")
        print_ranked(highest, highest_tokens)

        print(f"\nBottom {k} logits:")
        print_ranked(lowest, lowest_tokens)

In [None]:
# Load one SAE per GPT-2 layer from the "gpt2-small-hook-z-kk" release (SAE id
# `blocks.{layer}.hook_z`), keyed by layer index. `[0]` keeps the first element returned by
# `from_pretrained`, i.e. the SAE itself.
attn_saes = {
    layer: SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=str(device),
    )[0]
    for layer in range(gpt2.cfg.n_layers)
}

In [None]:
@dataclass
class AttnSeqDFA:
    """One row of the table rendered by `display_top_seqs_attn`."""

    # Value shown in the "Top Act" column (formatted with 3 decimals).
    act: float
    # Token strings of the destination-token context.
    str_toks_dest: list[str]
    # Token strings of the source-token context.
    str_toks_src: list[str]
    # Index into `str_toks_dest` of the token to highlight.
    dest_pos: int
    # Index into `str_toks_src` of the token to highlight.
    src_pos: int


def display_top_seqs_attn(data: list[AttnSeqDFA]):
    """
    Prints a rich table with one row per `AttnSeqDFA` record.

    Each row shows the activation value (3 decimals), then the source-token sequence with the token at
    `src_pos` highlighted in dark orange, then the destination-token sequence with the token at
    `dest_pos` highlighted in green.
    """

    def render(str_toks: list[str], seq_pos: int, color: str) -> str:
        # Wrap only the token at `seq_pos` in rich markup; every other token is joined unchanged.
        pieces = []
        for i, str_tok in enumerate(str_toks):
            if i == seq_pos:
                pieces.append(f"[b u {color}]{str_tok}[/]")
            else:
                pieces.append(str_tok)
        joined = "".join(pieces)
        # Strip the "�" placeholder character and make newlines visible before taking the repr.
        return repr(joined.replace("�", "").replace("\n", "↵"))

    table = Table(
        "Top Act",
        "Src token DFA (for top dest token)",
        "Dest token",
        title="Max Activating Examples",
        show_lines=True,
    )
    for seq in data:
        src_cell = render(seq.str_toks_src, seq.src_pos, "dark_orange")
        dest_cell = render(seq.str_toks_dest, seq.dest_pos, "green")
        table.add_row(f"{seq.act:.3f}", src_cell, dest_cell)
    rprint(table)


# Small hand-made example to check the table rendering: the destination context is the token list
# shifted by one relative to the source context, with different highlighted positions per row.
str_toks = [" one", " two", " three", " four"]
example_data = [
    AttnSeqDFA(act=0.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=0, src_pos=0),
    AttnSeqDFA(act=1.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=1, src_pos=1),
    AttnSeqDFA(act=2.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=2, src_pos=0),
]
display_top_seqs_attn(example_data)



# Run GPT-2 with the layer-9 attention SAE attached, stopping right after that SAE's hook layer,
# and keep the activation cache.
# NOTE(review): `tokens` is not defined anywhere in the visible part of the notebook — confirm it
# exists before this cell is run.
_, cache = gpt2.run_with_cache_with_saes(
    tokens,
    saes=attn_saes[9], # now running with attn_saes
    stop_at_layer=attn_saes[9].cfg.hook_layer + 1,
)


In [None]:
def fetch_max_activating_examples(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 5,
    k: int = 10,
    buffer: int = 10,
) -> list[tuple[float, list[str], int]]:
    """
    Returns the max activating examples across a number of batches from the activations store.

    Args:
        model:          Model to run with the SAE attached.
        sae:            SAE whose latent is inspected.
        act_store:      Source of token batches (`get_batch_tokens()` is called once per batch).
        latent_idx:     Index of the SAE latent to inspect.
        total_batches:  Number of token batches to scan.
        k:              Number of examples taken per batch, and number of examples returned overall.
        buffer:         Number of context tokens kept on each side of an activating token.

    Returns:
        The k examples with the largest activation, sorted in descending order. Each example is a
        tuple `(activation, context token strings, position of the activating token in the context)`,
        with the activation given as a plain Python float, as the return annotation promises.
    """
    acts_hook = f"{sae.cfg.hook_name}.hook_sae_acts_post"

    examples = []
    for _ in tqdm(range(total_batches)):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[acts_hook],
        )
        acts = cache[acts_hook][..., latent_idx] # for a specific index

        top_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        # Read all top activations in one indexing operation and convert them to Python floats.
        top_acts = acts[top_indices[:, 0], top_indices[:, 1]].tolist()
        top_values_with_context = index_with_buffer(tokens, top_indices, buffer=buffer)

        # The chosen positions are at least `buffer` away from both ends, so the activating token
        # always sits at index `buffer` of its context window.
        examples.extend(
            (act, model.to_str_tokens(context), buffer)
            for act, context in zip(top_acts, top_values_with_context)
        )

    return sorted(examples, key=lambda x: x[0], reverse=True)[:k]
        


In [None]:
def fetch_max_activating_examples_attn(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 250,
    k: int = 10,
    buffer: int = 10,
) -> list[AttnSeqDFA]:
    """
    Returns the max activating examples across a number of batches from the activations store.

    For every batch, the k destination positions with the largest pre-activation of latent
    `latent_idx` are selected. For each of them, the attention-weighted value vectors of all source
    positions are projected onto the latent's encoder direction, and the source position with the
    largest projection is recorded together with a +-`buffer` token context.

    Args:
        model:          Model to run with the attention SAE attached.
        sae:            Attention SAE whose latent is inspected.
        act_store:      Source of token batches (`get_batch_tokens()` is called once per batch).
        latent_idx:     Index of the SAE latent to inspect.
        total_batches:  Number of token batches to scan.
        k:              Number of examples taken per batch, and number of examples returned overall.
        buffer:         Number of context tokens kept on each side of the highlighted tokens.

    Returns:
        The k records with the largest source-token score, sorted in descending order. `act` holds
        that score, `dest_pos` / `src_pos` give the highlighted token's index in its context window.
    """
    layer = sae.cfg.hook_layer
    acts_hook = f"{sae.cfg.hook_name}.hook_sae_acts_pre"
    v_hook = f"blocks.{layer}.attn.hook_v"
    pattern_hook = f"blocks.{layer}.attn.hook_pattern"

    data = []
    for _ in tqdm(range(total_batches)):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=layer + 1,
            names_filter=[acts_hook, v_hook, pattern_hook],
        )
        acts = cache[acts_hook][..., latent_idx]
        # shape: (batch, seq)

        # Find the maximally activating (destination) tokens and their context windows
        top_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        top_values_with_context = index_with_buffer(tokens, top_indices, buffer=buffer)

        # Find which source tokens are most responsible for the activations
        val_vec = cache[v_hook]
        # shape (batch, seq, nh, dh)
        attn_probs = cache[pattern_hook]
        # shape (batch, nh, seq_dest, seq_src)

        # Limit the attention probs and value vectors to the top activations
        attn_probs = einops.rearrange(attn_probs, "batch nh seq_des seq_src -> batch seq_des nh seq_src")[top_indices[:, 0], top_indices[:, 1]]
        # shape (topk, nh, seq_src)
        val_vec = val_vec[top_indices[:, 0], :, :, :]
        # shape (topk, seq_src, nh, dh)

        val_weighted = einops.einsum(
            val_vec, attn_probs,
            "topk seq_src nh dh, topk nh seq_src -> topk seq_src nh dh"
        )
        val_weighted = einops.rearrange(val_weighted, "topk seq_src nh dh -> topk seq_src (nh dh)")

        scores = einops.einsum(
            val_weighted, sae.W_enc[:, latent_idx],
            "topk seq_src d_model, d_model -> topk seq_src"
        )
        # shape (topk, seq_src)

        # Best source position per selected destination token, and its score. The scores are read
        # before any context indexing so that later clamping of positions cannot change them.
        src_idx = scores.argmax(-1)
        top_acts = scores.gather(-1, src_idx.unsqueeze(-1)).squeeze(-1)
        # both of shape (topk,)

        # Pair the source positions with their batch index. t.stack copies, so `src_idx` stays
        # intact even if the indexing helper modifies the tensor it is given.
        max_src_inds = t.stack((top_indices[:, 0], src_idx), dim=1)

        # provide the buffered context around the source positions
        top_src_with_context = index_with_buffer(tokens, max_src_inds, buffer=buffer)
        # shape (topk, 2*buffer+1)

        # Near the ends of a sequence the context window is shifted to stay in bounds, so the
        # source token is not necessarily at the window centre: recover its true offset.
        window_center = src_idx.clamp(min=buffer, max=tokens.size(1) - buffer - 1)
        src_pos = (src_idx - window_center + buffer).tolist()
        # Destination positions are already at least `buffer` away from both ends.
        dest_pos = [buffer] * len(top_indices)

        ### Finally, log the data
        str_toks_dest = [model.to_str_tokens(toks) for toks in top_values_with_context]
        str_toks_src = [model.to_str_tokens(toks) for toks in top_src_with_context]

        for act, des_token_w_ctx, src_token_w_ctx, des_index, src_index in zip(
            top_acts.tolist(), str_toks_dest, str_toks_src, dest_pos, src_pos
        ):
            data.append(
                AttnSeqDFA(act=act, str_toks_dest=des_token_w_ctx, str_toks_src=src_token_w_ctx, dest_pos=des_index, src_pos=src_index)
            )

    return sorted(data, key=lambda x: x.act, reverse=True)[:k]
        
        

# Test your function: compare it to dashboard above (max DFA should come from src toks like " guns", " firearms")
# NOTE(review): `gpt2_act_store` is not defined anywhere in the visible part of the notebook —
# confirm it is created before this cell is run.
layer = 9
data = fetch_max_activating_examples_attn(gpt2, attn_saes[layer], gpt2_act_store, latent_idx=2)
display_top_seqs_attn(data)