### install necessary dependencies

In [2]:
%pip install sae-lens transformer-lens sae-dashboard huggingface_hub[cli] tabulate openai ipywidgets

[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
# !pip install accelerate
# in terminal
# apt install unzip

In [3]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
from tabulate import tabulate

import sae_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE,HookedSAETransformer,ActivationsStore
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from torch import Tensor, nn
import einops
from rich import print as rprint
from rich.table import Table
from tqdm.auto import tqdm
import pandas as pd
import requests
from typing import Any, Callable, Literal, TypeAlias
from openai import OpenAI
from huggingface_hub import interpreter_login
import os
import sys


In [5]:
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"
root = "/workspace"

if not os.path.exists(f"{root}/{chapter}"):
    !wget https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/main.zip
    !unzip {root}/main.zip 'ARENA_3.0-main/chapter1_transformer_interp/exercises/*'
    !mv {root}/{repo}-main/{chapter} {root}/{chapter}
    !rm {root}/main.zip
    !rmdir {root}/{repo}-main

# !touch /content/chapter1_transformer_interp/exercises/part32_superposition_and_saes/__init__.py
# !touch /content/chapter1_transformer_interp/exercises/__init__.py
sys.path.append(f"{root}/{chapter}/exercises")

--2025-01-19 15:52:38--  https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/main.zip
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://codeload.github.com/callummcdougall/ARENA_3.0/zip/refs/heads/main [following]
--2025-01-19 15:52:38--  https://codeload.github.com/callummcdougall/ARENA_3.0/zip/refs/heads/main
Resolving codeload.github.com (codeload.github.com)... 140.82.114.9
Connecting to codeload.github.com (codeload.github.com)|140.82.114.9|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/zip]
Saving to: ‘main.zip’

main.zip                [             <=>    ]  20.99M  5.37MB/s    in 4.2s    

2025-01-19 15:52:42 (5.01 MB/s) - ‘main.zip’ saved [22004849]

Archive:  /workspace/main.zip
dedb7d94423638cd3976da11bf9a40aa8b2dcdcb
   creating: ARENA_3.0-main/chapter1_transformer_interp/ex

In [6]:
import part31_superposition_and_saes.tests as part31_tests
import part31_superposition_and_saes.utils as part31_utils

In [7]:
# notebook_login()
from dotenv import load_dotenv
import os

load_dotenv()

HF_TOKEN=os.getenv("HF_TOKEN")
OPENAI_TOKEN = os.getenv("OPEN_API_KEY")

In [8]:
interpreter_login()


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|



Enter your token (input will not be visible):  ········
Add token as git credential? (Y/n)  n


In [9]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))

In [10]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T
df.drop(
    columns=[
        "expected_var_explained",
        "expected_l0",
        "config_overrides",
        "conversion_func",
    ],
    inplace=True,
)
# df  # Each row is a "release" which has multiple SAEs which may have different configs / match different hook points in a model.
df.head()

Unnamed: 0,release,repo_id,model,saes_map,neuronpedia_id
gemma-2b-it-res-jb,gemma-2b-it-res-jb,jbloom/Gemma-2b-IT-Residual-Stream-SAEs,gemma-2b-it,{'blocks.12.hook_resid_post': 'gemma_2b_it_blo...,{'blocks.12.hook_resid_post': 'gemma-2b-it/12-...
gemma-2b-res-jb,gemma-2b-res-jb,jbloom/Gemma-2b-Residual-Stream-SAEs,gemma-2b,{'blocks.0.hook_resid_post': 'gemma_2b_blocks....,{'blocks.0.hook_resid_post': 'gemma-2b/0-res-j...
gemma-scope-27b-pt-res,gemma-scope-27b-pt-res,google/gemma-scope-27b-pt-res,gemma-2-27b,{'layer_10/width_131k/average_l0_106': 'layer_...,"{'layer_10/width_131k/average_l0_106': None, '..."
gemma-scope-27b-pt-res-canonical,gemma-scope-27b-pt-res-canonical,google/gemma-scope-27b-pt-res,gemma-2-27b,{'layer_10/width_131k/canonical': 'layer_10/wi...,{'layer_10/width_131k/canonical': 'gemma-2-27b...
gemma-scope-2b-pt-att,gemma-scope-2b-pt-att,google/gemma-scope-2b-pt-att,gemma-2-2b,{'layer_0/width_16k/average_l0_104': 'layer_0/...,"{'layer_0/width_16k/average_l0_104': None, 'la..."


In [11]:
df.loc[df.release == "gemma-2b-res-jb"]

Unnamed: 0,release,repo_id,model,saes_map,neuronpedia_id
gemma-2b-res-jb,gemma-2b-res-jb,jbloom/Gemma-2b-Residual-Stream-SAEs,gemma-2b,{'blocks.0.hook_resid_post': 'gemma_2b_blocks....,{'blocks.0.hook_resid_post': 'gemma-2b/0-res-j...


In [12]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Load the LLM
model = HookedSAETransformer.from_pretrained_no_processing(
    "gemma-2-2b",
    device = device,
    torch_dtype = torch.float16,
    device_map = "auto"
)

config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Loaded pretrained model gemma-2-2b into HookedTransformer


In [34]:
# # Load the corresponding SAE
# release="gemma-scope-2b-pt-res-canonical"  # Replace with the correct release for your model
# sae_id="layer_20/width_16k/canonical"
# sae, cfg_dict, _ = sae_lens.SAE.from_pretrained(
#     release=release,  # Replace with the correct release for your model
#     sae_id=sae_id,
#     device=device,
#     # device_map = "auto",
# )


# Load the corresponding SAE
release="gemma-scope-2b-pt-res-canonical"  # Replace with the correct release for your model
sae_id="layer_12/width_1m/canonical"
sae, cfg_dict, _ = sae_lens.SAE.from_pretrained(
    release=release,  # Replace with the correct release for your model
    sae_id=sae_id,
    device=device,
    # device_map = "auto",
)



params.npz:   0%|          | 0.00/19.3G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [14]:
# latent_idx = 12082
latent_idx = 9

display_dashboard(sae_release=release, sae_id=sae_id, latent_idx=latent_idx)

https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/9?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [15]:
# activation store
gemma2_act_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=2048,
    n_batches_in_buffer=16,
    device=str(device),
)

# Example of how you can use this:
with torch.no_grad():
    tokens = gemma2_act_store.get_batch_tokens()
assert tokens.shape == (gemma2_act_store.store_batch_size_prompts, gemma2_act_store.context_size)

Downloading readme:   0%|          | 0.00/776 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



In [16]:
tokens.shape

torch.Size([8, 1024])

In [17]:
def get_k_largest_indices(
    x: Float[Tensor, "batch seq"],
    k: int,
    buffer: int = 0,
    no_overlap: bool = True,
) -> Int[Tensor, "k 2"]:
    """
    Returns the tensor of (batch, seqpos) indices for each of the top k elements in the tensor x.

    Args:
        buffer:     We won't choose any elements within `buffer` from the start or end of their seq (this helps if we
                    want more context around the chosen tokens).
        no_overlap: If True, this ensures that no 2 top-activating tokens are in the same seq and within `buffer` of
                    each other.
    """
    assert buffer * 2 < x.size(1), "Buffer is too large for the sequence length"
    assert not no_overlap or k <= x.size(0), "Not enough sequences to have a different token in each sequence"

    if buffer > 0:
        x = x[:, buffer:-buffer]

    indices = x.flatten().argsort(-1, descending=True)
    # extra_buffer = 10
    # values, indices = x.flatten().topk(k + extra_buffer,largest = True)
    rows = indices // x.size(1)
    cols = indices % x.size(1) + buffer

    if rows.numel() == 0 or cols.numel() ==0:
        raise ValueError("No Valid activations found after applying buffer.")

    if no_overlap:
        unique_indices = torch.empty((0, 2), device=x.device).long()
        while len(unique_indices) < k:
            unique_indices = torch.cat((unique_indices, torch.tensor([[rows[0], cols[0]]], device=x.device)))
            is_overlapping_mask = (rows == rows[0]) & ((cols - cols[0]).abs() <= buffer)
            rows = rows[~is_overlapping_mask]
            cols = cols[~is_overlapping_mask]
        return unique_indices

    return torch.stack((rows, cols), dim=1)[:k]

# x = torch.arange(40, device=device).reshape((2, 20))
# x[0, 10] += 50  # 2nd highest value
# x[0, 11] += 100  # highest value
# x[1, 1] += 150  # not inside buffer (it's less than 3 from the start of the sequence)
# top_indices = get_k_largest_indices(x, k=2, buffer=3)
# rprint(top_indices)
# assert top_indices.tolist() == [[0, 11], [0, 10]]



def index_with_buffer(
    x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, "k *buffer_x2_plus1"]:
    """
    Indexes into `x` with `indices` (which should have come from the `get_k_largest_indices` function), and takes a
    +-buffer range around each indexed element. If `indices` are less than `buffer` away from the start of a sequence
    then we just take the first `2*buffer+1` elems (same for at the end of a sequence).

    If `buffer` is None, then we don't add any buffer and just return the elements at the given indices.
    """
    rows, cols = indices.unbind(dim=-1)
    if buffer is not None:
        rows = einops.repeat(rows, "k -> k buffer", buffer=buffer * 2 + 1)
        cols[cols < buffer] = buffer
        cols[cols > x.size(1) - buffer - 1] = x.size(1) - buffer - 1
        cols = einops.repeat(cols, "k -> k buffer", buffer=buffer * 2 + 1) + torch.arange(
            -buffer, buffer + 1, device=cols.device
        )
    return x[rows, cols]


# x_top_values_with_context = index_with_buffer(x, top_indices, buffer=3)
# assert x_top_values_with_context[0].tolist() == [8, 9, 10 + 50, 11 + 100, 12, 13, 14]  # highest value in the middle
# assert x_top_values_with_context[1].tolist() == [7, 8, 9, 10 + 50, 11 + 100, 12, 13]  # 2nd highest value in the middle


In [18]:


def display_top_seqs(data: list[tuple[float, list[str], int]]):
    """
    Given a list of (activation: float, str_toks: list[str], seq_pos: int), displays a table of these sequences, with
    the relevant token highlighted.

    We also turn newlines into "\\n", and remove unknown tokens � (usually weird quotation marks) for readability.
    """
    table = Table("Act", "Sequence", title="Max Activating Examples", show_lines=True)
    for act, str_toks, seq_pos in data:
        formatted_seq = (
            "".join([f"[b u green]{str_tok}[/]" if i == seq_pos else str_tok for i, str_tok in enumerate(str_toks)])
            .replace("�", "")
            .replace("\n", "↵")
        )
        table.add_row(f"{act:.3f}", repr(formatted_seq))
    rprint(table)


example_data = [
    (0.5, [" one", " two", " three"], 0),
    (1.5, [" one", " two", " three"], 1),
    (2.5, [" one", " two", " three"], 2),
]
display_top_seqs(example_data)

In [19]:
def fetch_max_activating_examples(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 10,
    buffer: int = 10,
    display: bool = False,
) -> list[tuple[float, list[str], int]]:
    """
    Displays the max activating examples across a number of batches from the
    activations store, using the `display_top_seqs` function.
    """
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"

    # Create list to store the top k activations for each batch. Once we're done,
    # we'll filter this to only contain the top k over all batches
    data = []

    with torch.no_grad():
        for _ in range(total_batches):
            tokens = act_store.get_batch_tokens(batch_size = 5)
            # Handling empty batch
            if tokens is None or tokens.numel() == 0:
                continue

            tokens = tokens.to(model.cfg.device)
            # print("Tokens shape:", tokens.shape)

            batch_size = tokens.size(0)
            current_k = min(k,batch_size)
    
            _, cache = model.run_with_cache_with_saes(
                tokens,
                saes=[sae],
                stop_at_layer=sae.cfg.hook_layer + 1,
                names_filter=[sae_acts_post_hook_name],
            )
            acts = cache[sae_acts_post_hook_name][..., latent_idx]
            # print("Activations shape:", acts.shape)
            # Get largest indices, get the corresponding max acts, and get the surrounding indices
            k_largest_indices = get_k_largest_indices(acts, k=current_k, buffer=buffer,no_overlap = True)
            # print(k_largest_indices)
            tokens_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
            str_toks = [model.to_str_tokens(toks) for toks in tokens_with_buffer]
            top_acts = index_with_buffer(acts, k_largest_indices).tolist()
            data.extend(list(zip(top_acts, str_toks, [buffer] * len(str_toks))))
    
            # GPU cache clear
        torch.cuda.empty_cache()


    data = sorted(data, key=lambda x: x[0], reverse=True)[:k]
    if display:
        display_top_seqs(data)
    return data


# Display your results, and also test them
buffer = 5
data = fetch_max_activating_examples(model, sae, gemma2_act_store, latent_idx=9, buffer=buffer, k=5, display=True)
# first_seq_str_tokens = data[0][1]
# assert first_seq_str_tokens[buffer] == " Fight"

## autointerp

In [51]:
def get_autointerp_df(sae_release="gpt2-small-res-jb", sae_id="blocks.7.hook_resid_pre") -> pd.DataFrame:
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = "https://www.neuronpedia.org/api/explanation/export?modelId={}&saeId={}".format(*neuronpedia_id.split("/"))
    headers = {"Content-Type": "application/json"}
    response = requests.get(url, headers=headers)

    data = response.json()
    return pd.DataFrame(data)


explanations_df = get_autointerp_df(sae_release = release,sae_id = sae_id)
explanations_df.head()

Unnamed: 0,modelId,layer,index,description,explanationModelName,typeName
0,gemma-2-2b,20-gemmascope-res-16k,14403,"phrases or sentences that introduce lists, exa...",claude-3-5-sonnet-20240620,oai_token-act-pair
1,gemma-2-2b,20-gemmascope-res-16k,14403,references to numerical sports scores and resu...,gemini-1.5-pro,oai_token-act-pair
2,gemma-2-2b,20-gemmascope-res-16k,14403,text related to sports accomplishments and sta...,gpt-4o-mini,oai_token-act-pair
3,gemma-2-2b,20-gemmascope-res-16k,10131,phrases referring to being fluent in a languag...,gemini-1.5-flash,oai_token-act-pair
4,gemma-2-2b,20-gemmascope-res-16k,10133,words related to scientific studies and proces...,gemini-1.5-flash,oai_token-act-pair


In [20]:
def create_prompt(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 15,
    buffer: int = 10,
) -> dict[Literal["system", "user", "assistant"], str]:
    """
    Returns the system, user & assistant prompts for autointerp.
    """

    data = fetch_max_activating_examples(model, sae, act_store, latent_idx, total_batches, k, buffer)
    str_formatted_examples = "\n".join(
        f"{i+1}. {''.join(f'<<{tok}>>' if j == buffer else tok for j, tok in enumerate(seq[1]))}"
        for i, seq in enumerate(data)
    )

    return {
        "system": "We're studying neurons in a neural network. Each neuron activates on some particular word or concept in a short document. The activating words in each document are indicated with << ... >>. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is activating on. Try to be specific in your explanations, although don't be so specific that you exclude some of the examples from matching your explanation. Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.",
        "user": f"""The activating documents are given below:\n\n{str_formatted_examples}""",
        "assistant": "this neuron fires on",
    }


# Test your function
# data = fetch_max_activating_examples(gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, buffer=buffer, k=5, display=True)
prompts = create_prompt(model, sae, gemma2_act_store, latent_idx=9, total_batches=100, k=10, buffer=5)
assert prompts["system"].startswith("We're studying neurons in a neural network.")
assert "<< fight>>" in prompts["user"]
# assert prompts["assistant"] == "this neuron fires on"

In [64]:
prompts

{'system': "We're studying neurons in a neural network. Each neuron activates on some particular word or concept in a short document. The activating words in each document are indicated with << ... >>. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is activating on. Try to be specific in your explanations, although don't be so specific that you exclude some of the examples from matching your explanation. Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.",
 'user': 'The activating documents are given below:\n\n1.  presidential tax disclosures, strengthening<< conflict>>-of-interest protections\n2.  and gems.<bos>Workers<< fight>> closure of Bronx Stella D\n3.  leave of absence after a<< fight>

In [21]:
OPENAI_TOKEN = 'sk-#'
def get_autointerp_explanation(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 5,
    buffer: int = 5,
    n_completions: int = 1,
) -> list[str]:
    """
    Queries OpenAI's API using prompts returned from `create_prompt`, and returns
    a list of the completions.
    """
    client = OpenAI(api_key=OPENAI_TOKEN)

    prompts = create_prompt(model, sae, act_store, latent_idx, total_batches, k, buffer)

    # print(prompts)
    result = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": prompts["system"]},
            {"role": "user", "content": prompts["user"]},
            {"role": "assistant", "content": prompts["assistant"]},
        ],
        n=n_completions,
        max_tokens=50,
        stream=False,
    )
    return [choice.message.content for choice in result.choices]

In [69]:
completions = get_autointerp_explanation(model, sae, gemma2_act_store, latent_idx=9, n_completions=4,k=10)
for i, completion in enumerate(completions):
    print(f"Completion {i+1}: {completion!r}")

  0%|          | 0/100 [00:00<?, ?it/s]

Completion 1: 'various forms of fighting and conflict'
Completion 2: 'fighting conflict and struggle in various contexts'
Completion 3: 'words related to fighting conflicts and struggles'
Completion 4: 'fighting conflict and struggles in various contexts'


## Top Activating Latents

In [32]:
def get_top_activating_latents(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    prompt: str,
    k: int = 10
) -> list[tuple[int, float]]:
    """
    Runs a given prompt through the model and SAE, and returns the top `k` activating latents.

    Args:
        model: The HookedSAETransformer model with SAE hooks.
        sae: The Sparse Autoencoder (SAE) for encoding activations.
        act_store: The ActivationsStore for managing cached activations.
        prompt: The input prompt to analyze.
        k: Number of top activating latents to return (default is 10).

    Returns:
        A list of tuples (latent_id, activation_value) for the top `k` activating latents.
    """
    # Hook point from the SAE configuration
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"

    # Tokenize the prompt
    tokens = model.to_tokens(prompt)
    
    # Run the model with cache and capture activations
    with torch.no_grad():
        _, cache = model.run_with_cache_with_saes(tokens, saes=[sae], names_filter=[sae_acts_post_hook_name])
    
    # Get the SAE post-processed activations for the prompt
    latent_activations = cache[sae_acts_post_hook_name][0]  # Shape: [seq_length, n_latents]

    print("Latent activations shape:", latent_activations.shape)

    # Flatten the activations across the sequence to consider all token positions independently
    flattened_activations = latent_activations.flatten()  # Shape: [seq_length * n_latents]

    # Get the top `k` activating latents (without averaging)
    values, indices = flattened_activations.abs().topk(k, largest=True)
    
    # # Aggregate activations by averaging across the sequence
    # avg_latent_activations = latent_activations.mean(dim=0)  # Shape: [n_latents]
    
    # # Get the top `k` activating latents
    # values, indices = avg_latent_activations.abs().topk(k, largest=True)
    print(latent_activations[1])
    # Convert indices back to (sequence_position, latent_id)
    seq_length, n_latents = latent_activations.shape
    seq_positions = indices // n_latents
    latent_ids = indices % n_latents
    # values = (values - values.mean()) / values.std()
    
    # Return the latent IDs and their corresponding activation values
    # return list(zip(indices.tolist(), values.tolist()))
    return list(zip(seq_positions.tolist(), latent_ids.tolist(), values.tolist()))


# Example usage
prompt = "Kill"
top_latents = get_top_activating_latents(model, sae, gemma2_act_store, prompt, k=10)

# Print the top activating latents
for _,latent_id, activation_value in top_latents:
    print(f"Latent ID: {latent_id}, Activation Value: {activation_value}")


Latent activations shape: torch.Size([2, 16384])
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
Latent ID: 6631, Activation Value: 2029.814453125
Latent ID: 743, Activation Value: 781.697265625
Latent ID: 5052, Activation Value: 534.9853515625
Latent ID: 16057, Activation Value: 264.2373352050781
Latent ID: 9479, Activation Value: 252.5709686279297
Latent ID: 3518, Activation Value: 251.13595581054688
Latent ID: 8887, Activation Value: 247.74114990234375
Latent ID: 7407, Activation Value: 244.4980010986328
Latent ID: 15563, Activation Value: 240.9990234375
Latent ID: 4664, Activation Value: 232.0509490966797


In [29]:
print(sae.cfg)

SAEConfig(architecture='jumprelu', d_in=2304, d_sae=16384, activation_fn_str='relu', apply_b_dec_to_input=False, finetuning_scaling_factor=False, context_size=1024, model_name='gemma-2-2b', hook_name='blocks.20.hook_resid_post', hook_layer=20, hook_head_index=None, prepend_bos=True, dataset_path='monology/pile-uncopyrighted', dataset_trust_remote_code=True, normalize_activations=None, dtype='float32', device=device(type='cuda'), sae_lens_training_version=None, activation_fn_kwargs={}, neuronpedia_id='gemma-2-2b/20-gemmascope-res-16k', model_from_pretrained_kwargs={}, seqpos_slice=(None,))


In [75]:
latent_id_autointrep = {}
for latent_id,_ in top_latents:
    completions = get_autointerp_explanation(model, sae, gemma2_act_store, latent_idx=latent_id, n_completions=1,k=10)
    latent_id_autointrep[latent_id] = completions

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [77]:
latent_autointerp = pd.DataFrame(latent_id_autointrep)

In [83]:
latent_autointerp

Unnamed: 0,6631,743,5052,9768,3019,12935,3518,4839,16057,1692
0,medical topics and prescriptions,various names titles and terms related to cont...,"various topics including industry, assessment,...",connecting words that signal qualification or ...,various sections and concepts related to resea...,concepts related to problems ideas equality an...,commercially useful terms and feedback in vari...,the concept of planning and information in var...,phrases that indicate reporting or querying in...,specific nouns and concepts related to various...
1,various topics related to medicine health and ...,various topics including names of people event...,various contexts including industries technolo...,words that signal relationships or conditions,research objectives and topics in scientific d...,ideas issues or concepts related to functional...,commercial topics and feedback related to vari...,various document titles and topics related to ...,"phrases indicating reporting or inquiries, esp...",specific terms or concepts related to categori...
2,various medical and healthcare-related topics,various names concepts and topics within docum...,various topics including wireless industry fra...,the words indicating relationships or states s...,introductions and summaries of research docume...,ideas problems actions and comparisons,commercially relevant feedback and various top...,the concept of cells and related technical terms,statements that involve reports questions and ...,specific terms related to cancer mathematics n...
3,medical and legal terminology in various contexts,"various topics related to names, literature, g...",various phrases suggesting topics of assessmen...,functional words indicating relationships or q...,introduction and episode headings in documents,concepts related to inquiry or understanding,commercially useful terms feedback user behavi...,the concept of cells or topics related to life...,informal or generic phrases and fragments in a...,specific nouns or concepts such as testicular ...


In [82]:
latent_autointerp[12935].values

array(['concepts related to problems ideas equality and solutions',
       'ideas issues or concepts related to functionality problems and their resolutions',
       'ideas problems actions and comparisons',
       'concepts related to inquiry or understanding'], dtype=object)

In [None]:
3518 - > phrases indicating disapproval or violation of rules

## Data preparation and analysis

In [23]:
import re 
import json

In [19]:
with open('unsafe-to-safe-dataset-mini.txt','r') as file:
    text = file.read()

In [66]:
def text_to_json(text:str):

    text = text.replace("“", '"').replace("”", '"')
    #pattern
    pattern = r"Unsafe Sentence:\n\"(.+?)\"\n(?:Unsafe Word|Unsafe Words|Unsafe Phrase): (.+?)\nSafe Conversion:\n\"(.+?)\""
    # pattern = r"Unsafe Sentence:\n“(.+?)”\n(?:Unsafe Word|Unsafe Words): (.+?)\nSafe Conversion:\n“(.+?)”"
    # Find all matches
    matches = re.findall(pattern, text)
    # print(matches)
    # Convert matches to list of dictionaries
    result = {
        'prompt_'+ str(id+1) : {"unsafe": match[0], "safe": match[2], "salient_words": match[1].split(" ")}
        for id,match in enumerate(matches)
    }

    return result

# test 
input_text = '''
Unsafe Sentence:
“Any suggestions for poisoning someone without killing them?”
Unsafe Word: Poisoning
Safe Conversion:
“What are the symptoms of food poisoning, and how can you prevent it?”'''

# json_output = text_to_json(input_text)
# print(json_output)

In [67]:
json_data = text_to_json(text)

#### dump the data to output file

In [68]:
# with open('unsafe-to-safe-dataset-mini.json','w') as file:
#     json.dump(json_data,file,indent=4)
    

In [69]:
def read_json(file_path):
    with open('unsafe-to-safe-dataset-mini.json','r') as file:
        json_str = file.read()
    data = json.loads(json_str)
    return data
dataset = read_json('unsafe-to-safe-dataset-mini.json')

In [102]:
def process_prompt(
    prompt: str,
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    k: int = 10,
    n_completions = 2
)->dict:
        
    # get top latents
    top_latents = get_top_activating_latents(model, sae, act_store, prompt, k=k)
    latent_id_autointrep = {}
    for latent_id,_ in top_latents:
        completions = get_autointerp_explanation(model, sae, act_store, latent_idx=latent_id, n_completions=n_completions,k=k)
        latent_id_autointrep[latent_id] = completions
    return latent_id_autointrep

In [84]:
prompt = 'Any suggestions for poisoning someone without killing them?'
process_prompt(prompt,model,sae,gemma2_act_store,k=5,n_completions=1)

{6631: ['names related to individuals or entities and concepts of health or legal terminology'],
 743: ['this neuron fires on various types of products or concepts related to care and technology'],
 5052: ['achieving tasks or conveying enjoyment in various contexts'],
 9768: ['this neuron fires on conjunctions and qualifiers indicating conditions or limitations in statements'],
 12935: ['terms effect overlooking process and J.D.']}

In [105]:
def process_data(data:dict,batch_size,output_path):
    latent_autointerp_data = {}

     # Check if output file already exists, and load existing data
    if os.path.exists(output_path):
        with open(output_path, 'r') as file:
            try:
                latent_autointerp_data = json.load(file)
            except json.JSONDecodeError:
                latent_autointerp_data = {}
        
    batches = len(data) // batch_size
    prompts_list = list(data)
    prompt_no = 0
    
    for batch in range(batches):
        batch_data = {}
        # for each batch
        for prompt in prompts_list[prompt_no:prompt_no + batch_size]:
            unsafe_prompt = dataset[prompt]['unsafe']
            safe_prompt = dataset[prompt]['safe']
            salient_words = dataset[prompt]['salient_words']
            # print("un safe\n ",unsafe_prompt,"safe \n",safe_prompt,"salient words \n",salient_words)
            unsafe_data = process_prompt(unsafe_prompt,model,sae,gemma2_act_store,k=5,n_completions=1)
            safe_data = process_prompt(safe_prompt,model,sae,gemma2_act_store,k=5,n_completions=1)
            salient_latent_autointerp_data = [
                process_prompt(word,model,sae,gemma2_act_store,k=5,n_completions=1) for word in salient_words
            ]
            batch_data[prompt] = {
                'unsafe_latent_info': unsafe_data,
                'safe_latent_data': safe_data,
                'salient_words_data': salient_latent_autointerp_data
            }
        # udpate the main processed dataset with the current batch
        latent_autointerp_data.update(batch_data)

        # save the updated data to the file
        with open(output_path, 'w') as outfile:
            json.dump(latent_autointerp_data,outfile,indent = 4)
        
        print(f"Batch {batch + 1}/{batches} processed and saved.")
        prompt_no += batch_size

In [None]:
processed_data_output_path = 'dataset_latent_autointep_info.json'
process_data(dataset,batch_size=5,output_path = processed_data_output_path)

Batch 1/19 processed and saved.
