# Setup

In [None]:
from huggingface_hub import notebook_login
# Prompt for a Hugging Face access token; later cells download google/gemma-2-2b and Gemma Scope SAEs from the Hub.
notebook_login()

In [None]:
!pip install sae_lens

In [None]:
!pip install Datasets

In [None]:
from datasets import load_dataset

# GoEmotions, "simplified" configuration (returned as a dict of train / validation / test splits).
ds = load_dataset("google-research-datasets/go_emotions", "simplified")

In [None]:
from datasets import concatenate_datasets

# Merge all three splits into a single Dataset.
ds = concatenate_datasets([ds['train'], ds['validation'], ds['test']])
ds

In [None]:
# Class names of the multi-label `labels` column, indexed by label id.
label_names = ds.features["labels"].feature.names


In [None]:
# Emotions to keep, mapped to their integer class ids (raises ValueError if a name is unknown).
targeted_emotions = ['joy', 'anger', 'disgust', 'sadness', 'love', 'fear', 'excitement']
labels = [label_names.index(emotion) for emotion in targeted_emotions]

In [None]:
# Keep only examples tagged with at least one targeted emotion id.
ds = ds.filter(lambda example: not set(example['labels']).isdisjoint(labels))

In [None]:
# Drop multi-label examples so each remaining text carries exactly one emotion.
ds = ds.filter(lambda x: len(x['labels']) == 1)

In [None]:
import numpy as np

# Ids of a pre-selected subset of examples, read from disk.
# NOTE(review): presumably produced by an earlier filtering run — confirm provenance of ds_filt.npy.
ids = np.load("ds_filt.npy")

In [None]:
# Keep only the examples whose id appears in the saved id array.
# `x in ndarray` is a linear scan per row, so flatten the array into a set once
# and get O(1) membership tests inside the filter.
_keep_ids = set(ids.ravel().tolist())
filtered_ds = ds.filter(lambda x: x["id"] in _keep_ids)


In [None]:
# From here on, `ds` is the id-filtered subset.
ds = filtered_ds

In [None]:
# Number of examples left after all filters.
len(ds)

In [None]:
# After `concatenate_datasets`, `ds` is a single Dataset, and indexing it with
# "train" would look up a (non-existent) column. Only unwrap a split while `ds`
# is still a split mapping (DatasetDict subclasses dict).
if isinstance(ds, dict) and "train" in ds:
    ds = ds['train']


In [None]:
# Inspect the working dataset (features and row count).
ds

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np

# Hugging Face tokenizer for gemma-2-2b, the same checkpoint later loaded via HookedSAETransformer.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
# NOTE(review): activations are taken from HookedSAETransformer below, so this raw HF model
# appears unused; consider deleting these commented-out lines.
#model = AutoModel.from_pretrained("google/gemma-2-2b", output_hidden_states=True)
#model.eval()


In [None]:
!pip install sae_lens

In [None]:
pip install tabulate

In [None]:
from tabulate import tabulate

In [None]:
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
#from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig

In [None]:
def format_value(value):
    """Return a compact single-line representation of *value* for table display.

    A dict is abbreviated to its first item followed by ``...``
    (e.g. ``{'a': 1, ...}``), and an empty dict is shown as ``{}``. Any other
    value is shown with ``repr``.
    """
    if isinstance(value, dict):
        if not value:
            # next(iter(...)) on an empty dict would raise StopIteration.
            return "{}"
        first_key, first_value = next(iter(value.items()))
        return "{{{0!r}: {1!r}, ...}}".format(first_key, first_value)
    return repr(value)


# Directory entry of the Gemma Scope 2B residual-stream release, printed one attribute per row.
release = get_pretrained_saes_directory()["gemma-scope-2b-pt-res"]

table = tabulate(
    [[field, format_value(value)] for field, value in vars(release).items()],
    headers=["Field", "Value"],
    tablefmt="simple_outline",
)
print(table)

In [None]:
# One row per SAE in the release: its id, Hugging Face path and Neuronpedia id.
# `sae_id` avoids shadowing the builtin `id`; `.get` shows None for SAEs without
# a Neuronpedia entry instead of aborting with KeyError.
neuronpedia_ids = release.neuronpedia_id or {}
data = [
    [sae_id, path, neuronpedia_ids.get(sae_id)]
    for sae_id, path in release.saes_map.items()
]

print(
    tabulate(
        data,
        headers=["SAE id", "SAE path (HuggingFace)", "Neuronpedia ID"],
        tablefmt="simple_outline",
    )
)

# Collecting activations

In [None]:
# Parallel per-example lists: raw text and its list of emotion label ids.
texts = list(map(lambda row: row["text"], ds))
labels = list(map(lambda row: row["labels"], ds))


In [None]:
import torch
from sae_lens import SAE, ActivationsStore

# Device preference: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Inference only: turn autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Gemma Scope residual-stream SAE on layer 20 (16k latents, average_l0_71 variant).
gemma_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_20/width_16k/average_l0_71",
    device=str(device),
)

# Gemma-2-2B wrapped so SAEs can be spliced into its forward pass.
gemma = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device=device)


In [None]:
# Class names of the GoEmotions "simplified" labels, read from the features of the
# dataset already in memory. Re-loading the dataset (under a different identifier
# than the one used at the top of the notebook) just to get these names is unnecessary.
label_names = ds.features["labels"].feature.names

In [None]:
# Narrower target set for the activation analysis, as integer class ids.
targeted_emotions = ['joy', 'anger', 'disgust', 'sadness']
labels = [label_names.index(emotion) for emotion in targeted_emotions]

In [None]:
# Subset containing at least one of the four target emotion ids.
ds_reduced = ds.filter(lambda example: not set(example['labels']).isdisjoint(labels))

In [None]:
# Per-example text and label lists of the reduced subset
# (`labels` now holds per-example label lists rather than the target ids).
texts = list(map(lambda row: row['text'], ds_reduced))
labels = list(map(lambda row: row['labels'], ds_reduced))

In [None]:
# Measure tokenized lengths (tokenizer.encode default settings) to pick a truncation length.
lengths = list(map(lambda text: len(tokenizer.encode(text)), texts))
max_length = max(lengths)
print(f"Max tokenized length across all samples: {max_length}")


In [None]:
from tqdm import tqdm
import torch
import gc

# Run every text through Gemma + SAE and record the SAE latents at each text's last
# real token, plus the text and its labels, so later cells can read them directly.
batch_size = 4
# Truncation length in tokens.
# NOTE(review): presumably taken from the length-measurement cell; longer texts are cut.
max_length = 52
results = []
sae_acts_hook = f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"

for start in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
    batch_texts = texts[start:start + batch_size]
    batch_labels = labels[start:start + batch_size]
    try:
        gc.collect()
        torch.cuda.empty_cache()

        tokenized = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        input_ids = tokenized["input_ids"].to(device)
        attention_mask = tokenized["attention_mask"]

        # Index of the last real (non-pad) token of every row. Reading position -1
        # returns a pad token for shorter texts whenever the tokenizer pads on the
        # right; this expression is correct for left and right padding alike.
        last_pos = attention_mask.shape[1] - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)

        # Forward pass up to the SAE's layer, caching only the SAE latents.
        _, cache = gemma.run_with_cache_with_saes(
            input_ids,
            saes=[gemma_sae],
            stop_at_layer=gemma_sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_hook],
        )

        sae_acts = cache[sae_acts_hook]  # [B, T, F]
        rows = torch.arange(sae_acts.shape[0], device=sae_acts.device)
        final_acts = sae_acts[rows, last_pos.to(sae_acts.device)].detach().cpu()  # [B, F]

        # Number of features with activation > 1 at that token. A separate name keeps
        # the `sparsity` returned by SAE.from_pretrained intact.
        n_active = (final_acts > 1).sum(dim=-1)  # [B]

        decoded_tokens = [tokenizer.convert_ids_to_tokens(row_ids) for row_ids in input_ids]

        # `j` (not the batch counter) indexes rows, so `start` stays valid for the error message.
        for j, text in enumerate(batch_texts):
            results.append({
                "text": text,                                  # str
                "labels": batch_labels[j],                     # list of label ids
                "input_ids": input_ids[j].detach().cpu(),      # torch.Tensor
                "tokens": decoded_tokens[j],                   # list of strings
                "activation": final_acts[j],                   # torch.Tensor
                "sparsity": int(n_active[j]),                  # int
            })

        # Cleanup
        del cache, sae_acts, final_acts, n_active, input_ids, tokenized
        torch.cuda.empty_cache()

    except RuntimeError as e:
        # Usually an out-of-memory error; the whole batch is skipped. Each record keeps
        # its own text and labels, so skipped batches cannot misalign later analyses.
        print(f"⚠️ OOM on batch {start}-{start+batch_size}: {e}")
        torch.cuda.empty_cache()


In [None]:
import pickle

# Persist the collected per-example records for reuse in later sessions.
with open("sae_results_filtered.pkl", "wb") as out_file:
    pickle.dump(results, out_file)


In [None]:
# Mean of the per-example "sparsity" counts (features with activation > 1 at the final token).
sparsity_counts = torch.tensor([record["sparsity"] for record in results])
print("Average sparsity:", sparsity_counts.float().mean())


In [None]:
import matplotlib.pyplot as plt

# Histogram of the per-example "sparsity" counts stored in `results`.
sparsities = list(map(lambda record: record["sparsity"], results))

plt.hist(sparsities, bins=50)
plt.title("Sparsity Distribution")
plt.xlabel("# of active SAE features (>1)")
plt.ylabel("Frequency")
plt.show()


In [None]:
# Ten strongest SAE latents at the final token of the first collected example.
topk = results[0]["activation"].topk(10)

top_neurons = topk.indices.tolist()
top_values = topk.values.tolist()

# Records may lack a "text" key, so fall back to decoding the stored token ids
# (skipping BOS/pad special tokens) instead of raising KeyError.
sample_text = results[0].get("text") or tokenizer.decode(results[0]["input_ids"], skip_special_tokens=True)
print(f"Top firing SAE neurons and their activations for input:{sample_text}")
for idx, val in zip(top_neurons, top_values):
    print(f"Neuron {idx} ➝ Activation: {val:.4f}")


In [None]:
# Index (0-d tensor) of the strongest SAE latent at the final token of every example.
# Only the activation is needed; reading entry["text"] would raise KeyError on
# records that do not store the text.
inds = [entry["activation"].max(-1).indices for entry in results]


In [None]:
import plotly.express as px
import torch

# Line plot of every SAE latent activation at the final token of the first example.
sample = results[0]
activation = sample["activation"]
activation_np = activation.numpy()

(
    px.line(
        y=activation_np,
        title="SAE Activations for Sample 0 — Final Token",
        labels={"index": "Neuron (Latent Feature)", "value": "Activation"},
        width=1000,
    )
    .update_layout(showlegend=False)
    .show()
)


# Drafts

In [None]:
# NOTE(review): draft cell. `decoded_prompts` is not defined in any visible cell, and
# `final_acts` is deleted at the end of each batch of the collection loop, so this
# raises NameError. If it did run, it would overwrite the collected `results`.
results = [
    {
        "text": decoded_prompts[i],
        "activation": final_acts[i].detach().cpu()
    }
    for i in range(len(texts))
]


In [None]:
# Collect final-token SAE activations for every example, keeping labels aligned.
#
# transformer_lens.utils.tokenize_and_concatenate is deliberately not used: it
# concatenates differently labeled texts into one chunk, so an activation could no
# longer be attributed to a single example's label. Each text is tokenized on its
# own (padded within the batch) instead.
# Assumes `texts` / `labels` are the parallel per-example lists built above.
import torch

batch_size = 2
all_acts, all_labels = [], []
sae_acts_hook = f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"

for start in range(0, len(texts), batch_size):
    batch_texts = texts[start:start + batch_size]
    tokenized = tokenizer(
        batch_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=gemma_sae.cfg.context_size,
    )
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"]
    # Last real (non-pad) position of each row; valid for left or right padding.
    last_pos = attention_mask.shape[1] - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)

    try:
        _, cache = gemma.run_with_cache_with_saes(
            input_ids,
            saes=[gemma_sae],
            stop_at_layer=gemma_sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_hook],
        )
    except RuntimeError as e:
        # Skip activations AND labels of this batch so both lists stay aligned.
        print(f"Skipping batch {start}-{start+batch_size} due to OOM: {e}")
        torch.cuda.empty_cache()
        continue

    sae_acts = cache[sae_acts_hook]  # [B, T, F]
    rows = torch.arange(sae_acts.shape[0], device=sae_acts.device)
    all_acts.append(sae_acts[rows, last_pos.to(sae_acts.device)].detach().cpu())
    all_labels.extend(labels[start:start + len(batch_texts)])

    del cache, sae_acts, input_ids
    torch.cuda.empty_cache()



In [None]:
pip install ace_tools

In [None]:
# `defaultdict` must be imported here: in notebook order it was otherwise first
# imported in a later cell, so this cell raised NameError.
from collections import defaultdict

import pandas as pd

# Stack the per-batch activations into one [N, F] matrix; y[k] is the label list of row k.
X = torch.cat(all_acts, dim=0)
y = all_labels

# Group activations by emotion id; a multi-label example contributes to each of its labels.
label_to_acts = defaultdict(list)
for xi, yi in zip(X, y):
    for label in yi:
        label_to_acts[label].append(xi)

# Mean final-token SAE activation vector per emotion id.
label_to_mean = {label: torch.stack(acts).mean(dim=0) for label, acts in label_to_acts.items()}

# One row per emotion id, one column per SAE latent.
df_mean = pd.DataFrame.from_dict({k: v.numpy() for k, v in label_to_mean.items()}, orient="index")

display(df_mean)

In [None]:
# For every emotion id: the top_k latents by mean activation, as (latent, mean) pairs.
top_k = 10
emotion_to_top_neurons = {}

for label, mean_acts in label_to_mean.items():
    top_values, top_indices = mean_acts.topk(top_k)
    emotion_to_top_neurons[label] = [
        (latent, value) for latent, value in zip(top_indices.tolist(), top_values.tolist())
    ]

# One row per emotion id, one column per rank; each cell is a (latent, mean) tuple.
emotion_top_df = pd.DataFrame.from_dict(emotion_to_top_neurons, orient="index")
emotion_top_df.columns = [f"TOP {rank} Neuron" for rank in range(1, top_k + 1)]


In [None]:
# Display the per-emotion table of top latents.
emotion_top_df

In [None]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# Build a Neuronpedia "quick list" for the top latents of the first example
# (`top_neurons`, from the top-k cell) and print the link so it can be clicked in
# any environment. The draft referenced `test_feature_idx_gpt` and `COLAB`, which
# are never defined in the visible cells.
neuronpedia_quick_list = get_neuronpedia_quick_list(gemma_sae, top_neurons)
print(neuronpedia_quick_list)

In [None]:
from collections import defaultdict

# Group final-token activations by emotion id. Each `yi` is the *list* of label ids
# of one example, which cannot itself be a dict key (lists are unhashable), so the
# activation is added under each label in it.
label_to_acts = defaultdict(list)
for xi, yi in zip(X, y):
    for label in yi:
        label_to_acts[label].append(xi)

# Mean activation vector per emotion id.
label_to_mean = {label: torch.stack(acts).mean(dim=0) for label, acts in label_to_acts.items()}


In [None]:
from IPython.display import IFrame

# Inspect one SAE latent through its public Neuronpedia dashboard, for the layer-20,
# 16k-width Gemma Scope SAE loaded in this notebook. The draft used an incomplete
# import (a SyntaxError) and called `neuronpedia.analyze_feature`, a name that is not
# defined in any visible cell.
neuron_id = 1234
interpretation = IFrame(
    f"https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{neuron_id}"
    "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300",
    width=1200,
    height=600,
)
interpretation

In [None]:
# Print the shape of every SAE hook activation held in `cache`.
for hook_name, acts_tensor in cache.items():
    if "hook_sae" not in hook_name:
        continue
    print(f"{hook_name:<43}: {tuple(acts_tensor.shape)}")

In [None]:
import plotly.express as px

In [None]:
# Final-token SAE activations of the first collected example (1-D vector over latents).
sae_acts_post = results[0]["activation"]

# Plot line chart of latent activations
px.line(
    sae_acts_post.cpu().numpy(),
    title=f"Latent activations at the final token position ({sae_acts_post.nonzero().numel()} alive)",
    labels={"index": "Latent", "value": "Activation"},
    width=1000,
).update_layout(showlegend=False).show()

# Print the top 3 latents and embed their Neuronpedia dashboards
# (layer-20, 16k-width Gemma Scope SAE, the one loaded in this notebook).
from IPython.display import IFrame, display

for act, ind in zip(*sae_acts_post.topk(3)):
    print(f"Latent {int(ind)} had activation {act:.2f}")
    display(IFrame(
        f"https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{int(ind)}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300",
        width=1200,
        height=600,
    ))

In [None]:
# Collect every positive activation of one SAE latent over token batches streamed
# from an ActivationsStore. The draft used `model`, `act_store`, `total_batches` and
# `latent_idx` without defining them; they are set up explicitly here.
total_batches = 100  # number of token batches to scan
latent_idx = 15539   # latent to inspect (example id from the dashboard cell; change as needed)

# NOTE(review): from_sae presumably streams the dataset recorded in gemma_sae.cfg — confirm.
act_store = ActivationsStore.from_sae(
    model=gemma,
    sae=gemma_sae,
    streaming=True,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
    device=str(device),
)

sae_acts_post_hook_name = f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"
all_positive_acts = []

for _ in tqdm(range(total_batches)):
    tokens = act_store.get_batch_tokens()
    _, cache = gemma.run_with_cache_with_saes(
        tokens,
        saes=[gemma_sae],
        stop_at_layer=gemma_sae.cfg.hook_layer + 1,
        names_filter=[sae_acts_post_hook_name],
    )
    acts = cache[sae_acts_post_hook_name][..., latent_idx]
    all_positive_acts.extend(acts[acts > 0].cpu().tolist())

### Neuronpedia

In [None]:
# Inspect the first collected record.
results[0]

In [None]:
# Neuronpedia dashboard URL for one latent of the SAE analysed in this notebook
# (Gemma Scope layer 20, 16k width -> Neuronpedia id "20-gemmascope-res-16k").
# The layer-19 id used previously points at a different SAE, whose latent indices
# do not match this one's.
sae_release = "gemma-2-2b"
sae_id = "20-gemmascope-res-16k"
feature_idx = 11882  # example neuron

url = f"https://neuronpedia.org/{sae_release}/{sae_id}/{feature_idx}"
print(url)

In [None]:
# One dashboard link per top latent of the first example. Each URL must contain the
# neuron's own index and the layer-20 SAE that produced `top_neurons`.
for neuron in top_neurons:
    print(f"Neuron {neuron} ➝ https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{neuron}")


In [None]:
from IPython.display import IFrame

# Embed the dashboard of the strongest latent of the first example, taken from the same
# layer-20 16k SAE that produced `top_neurons` (a layer-19 SAE's feature ids do not
# correspond to this SAE's latents).
IFrame(f"https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{top_neurons[0]}", width=1200, height=600)


In [None]:
# Example output of the top-k neuron cell (pasted as text, so kept as comments):
# Top firing SAE neurons and their activations for input:I miss them being alive
# Neuron 15509 ➝ Activation: 47.4543
# Neuron 4326 ➝ Activation: 35.8033
# Neuron 14232 ➝ Activation: 34.3398
# Neuron 204 ➝ Activation: 29.5415
# Neuron 15328 ➝ Activation: 27.9000
# Neuron 11864 ➝ Activation: 26.3788
# Neuron 1692 ➝ Activation: 26.2759
# Neuron 9768 ➝ Activation: 25.3538
# Neuron 15539 ➝ Activation: 23.9061
# Neuron 14084 ➝ Activation: 23.6004

In [None]:
# Embeddable Neuronpedia dashboard for a layer-20 Gemma Scope latent
# (URL pattern taken from the Gemma Scope tutorial).
from IPython.display import IFrame

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Return the embeddable Neuronpedia dashboard URL for one SAE latent.

    Args:
        sae_release: Neuronpedia model id.
        sae_id: Neuronpedia SAE id.
        feature_idx: index of the latent within that SAE.
    """
    url = html_template.format(sae_release, sae_id, feature_idx)
    return url


html = get_dashboard_html(feature_idx=15539)
IFrame(html, width=1200, height=600)

In [None]:
def fetch_max_activating_examples(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 10,
    buffer: int = 10,
    display: bool = False,
) -> list[tuple[float, list[str], int]]:
    """
    Find the top-``k`` activations of SAE latent ``latent_idx`` over
    ``total_batches`` token batches drawn from ``act_store``.

    Returns ``(activation, str_tokens, buffer)`` tuples sorted by activation,
    highest first. ``str_tokens`` is the context window around each peak, built
    with ``buffer``. When ``display`` is true the result is also rendered with
    ``display_top_seqs``.

    NOTE(review): relies on ``get_k_largest_indices``, ``index_with_buffer`` and
    ``display_top_seqs``, which are not defined in any visible cell — TODO define them.
    """
    hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    # Per-batch top-k candidates; trimmed to the overall top-k after the loop.
    candidates = []

    for _ in tqdm(range(total_batches)):
        batch_tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            batch_tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[hook_name],
        )
        latent_acts = cache[hook_name][..., latent_idx]

        # Peak positions, their surrounding token windows, and the peak activation values.
        peak_positions = get_k_largest_indices(latent_acts, k=k, buffer=buffer)
        context_windows = index_with_buffer(batch_tokens, peak_positions, buffer=buffer)
        window_str_toks = [model.to_str_tokens(window) for window in context_windows]
        peak_acts = index_with_buffer(latent_acts, peak_positions).tolist()
        candidates += zip(peak_acts, window_str_toks, [buffer] * len(window_str_toks))

    candidates.sort(key=lambda item: item[0], reverse=True)
    best = candidates[:k]
    if display:
        display_top_seqs(best)
    return best


# Show the top activating contexts for one latent of the Gemma Scope SAE used here.
# The original call used GPT-2 objects (`gpt2`, `gpt2_sae`, `gpt2_act_store`) that are
# not created in any visible cell, and asserted a GPT-2-specific token.
buffer = 10
# NOTE(review): from_sae presumably streams the dataset recorded in gemma_sae.cfg — confirm.
act_store = ActivationsStore.from_sae(
    model=gemma,
    sae=gemma_sae,
    streaming=True,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
    device=str(device),
)
data = fetch_max_activating_examples(gemma, gemma_sae, act_store, latent_idx=15539, buffer=buffer, k=5, display=True)
first_seq_str_tokens = data[0][1]
# Sanity check independent of the model: results come back sorted, highest activation first.
assert all(a[0] >= b[0] for a, b in zip(data, data[1:]))