##### Imports

In [1]:
import os
import warnings

import numpy as np
import pandas as pd
import plotly.express as px
import torch
from datasets import load_dataset
from IPython.display import IFrame
from jaxtyping import Float, Int
from tabulate import tabulate
from torch import Tensor
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt, tokenize_and_concatenate

from e2e_sae import SAETransformer
from e2e_sae.data import create_data_loader

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
torch.set_grad_enabled(False)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: mps


In [3]:
class DotDict(dict):
    """A dictionary that supports dot notation."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{key}'")

##### Load Model

In [4]:
with warnings.catch_warnings(action="ignore"):
    model = SAETransformer.from_wandb("sparsify/gpt2/tvj2owza")
    model.to(device)

transformer = model.tlens_model
saes_dict = model.saes

sae_pos = model.raw_sae_positions[0]
sae = saes_dict["blocks-6-hook_resid_pre"]
d_sae = sae.encoder[0].out_features

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps


##### Neuronpedia Dashboard

In [5]:
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gpt2-small", sae_id="6-res-jb", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

##### Basic Config

In [6]:
dataset_config = DotDict(
    {
        "dataset_name": "NeelNanda/pile-10k",
        "is_tokenized": False,
        "tokenizer_name": "gpt2",
        "streaming": True,
        "split": "train",
        "n_ctx": 1024,
        "seed": 0,
    }
)

In [17]:
dataloader, _ = create_data_loader(dataset_config=dataset_config, batch_size=16)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



##### Basic Test Prompt

In [8]:
prompt = "The next person in line is singer Johnny"
answer = "Cash"

# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, transformer)

Tokenized prompt: ['<|endoftext|>', 'The', ' next', ' person', ' in', ' line', ' is', ' singer', ' Johnny']
Tokenized answer: [' Cash']


Top 0th token. Logit: 17.33 Prob: 38.65% Token: | Cash|
Top 1th token. Logit: 15.72 Prob:  7.80% Token: | De|
Top 2th token. Logit: 14.47 Prob:  2.21% Token: | B|
Top 3th token. Logit: 14.34 Prob:  1.95% Token: | Mar|
Top 4th token. Logit: 13.67 Prob:  1.00% Token: | "|
Top 5th token. Logit: 13.66 Prob:  0.99% Token: | H|
Top 6th token. Logit: 13.62 Prob:  0.95% Token: | G|
Top 7th token. Logit: 13.43 Prob:  0.79% Token: | R|
Top 8th token. Logit: 13.43 Prob:  0.78% Token: | Carson|
Top 9th token. Logit: 13.40 Prob:  0.77% Token: | Mercer|


##### Max Activating Feature

In [9]:
out, cache = model.forward(prompt, sae_positions=model.raw_sae_positions, cache_positions=None)

In [10]:
px.line(
    cache[sae_pos].output[0, -1, :].cpu().numpy(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
).show()

# let's print the top 5 features and how much they fired
vals, inds = torch.topk(cache[sae_pos].output[0, -1, :], 5)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} fired {val:.2f}")

Feature 447 fired 17.32
Feature 326 fired 9.92
Feature 266 fired 7.84
Feature 288 fired 7.00
Feature 481 fired 6.27


In [11]:
FEATURE_IDX = 20986

In [24]:
def show_top_logits(
    transformer,
    sae,
    feature_idx: int,
    k: int = 10,
) -> None:
    """
    Displays the top & bottom logits for a particular feature.
    """
    logits = sae.decoder.weight[:, feature_idx] @ transformer.W_U

    top_logits, top_token_ids = logits.topk(10)
    top_tokens = transformer.to_str_tokens(top_token_ids)
    bottom_logits, bottom_token_ids = logits.topk(10, largest=False)
    bottom_tokens = transformer.to_str_tokens(bottom_token_ids)

    print(
        tabulate(
            zip(map(repr, bottom_tokens), bottom_logits, map(repr, top_tokens), top_logits),
            headers=["Bottom Tokens", "Logits", "Top Tokens", "Logits"],
            tablefmt="simple_outline",
            stralign="right",
            floatfmt="+.4f",
            showindex=True,
        )
    )


show_top_logits(transformer, sae, feature_idx=FEATURE_IDX)

┌────┬─────────────────┬──────────┬────────────────┬──────────┐
│    │   Bottom Tokens │   Logits │     Top Tokens │   Logits │
├────┼─────────────────┼──────────┼────────────────┼──────────┤
│  0 │        'course' │  -0.8429 │          'ocl' │  +1.0050 │
│  1 │         'cible' │  -0.8277 │ ' prominently' │  +0.9103 │
│  2 │ ' Intermediate' │  -0.7800 │      ' plaque' │  +0.8891 │
│  3 │    'efficients' │  -0.7011 │      ' etched' │  +0.8670 │
│  4 │           'jab' │  -0.6901 │    ' engraved' │  +0.8527 │
│  5 │           'rir' │  -0.6891 │       'stones' │  +0.8374 │
│  6 │           'yon' │  -0.6868 │        ' tarn' │  +0.8200 │
│  7 │         'erate' │  -0.6848 │     ' artwork' │  +0.8182 │
│  8 │          'prov' │  -0.6797 │   ' paintings' │  +0.8088 │
│  9 │         'Princ' │  -0.6567 │     ' tattoos' │  +0.8044 │
└────┴─────────────────┴──────────┴────────────────┴──────────┘


In [23]:
def get_k_largest_indices(
    x: Float[Tensor, "batch seq"], k: int, buffer: int | None = 5
) -> Int[Tensor, "k 2"]:
    """
    The indices of the top k elements in the tensor x. In other words, output[i, :] is the (batch, seqpos) values of the i-th largest element in x.

    Won't choose any elements within `buffer` from the start or end of their sequence.
    """
    if buffer is None:
        buffer = 0
    x = x[:, buffer:-buffer]
    indices = x.flatten().topk(k=k).indices
    rows = indices // x.size(1)
    cols = indices % x.size(1) + buffer
    return torch.stack((rows, cols), dim=1)


def show_max_activating_examples(
    model: SAETransformer,
    feature_idx: int,
    total_batches: int = 100,
    k: int = 10,
) -> None:
    """
    Displays the max activating examples across a number of batches from the activations store.
    """
    buffer = 10

    # Create list to store the top k activations for each batch. Once we're done, we'll filter this
    # to only contain the top k over all batches
    data = []

    dl_iter = iter(dataloader)
    for i in tqdm(range(total_batches)):
        tokens = next(dl_iter)["input_ids"].to(device)
        logits, activations = model.forward(tokens=tokens, sae_positions=model.raw_sae_positions)
        feature_acts = activations[sae_pos].c
        acts = feature_acts[..., feature_idx]

        k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        for b, s in k_largest_indices:
            str_toks_with_context = model.tlens_model.to_str_tokens(
                tokens[b, s - buffer : s + buffer]
            )
            str_toks_with_context = ["\\n" if tok == "\n" else tok for tok in str_toks_with_context]
            formatted_seq = "".join(
                [
                    f"|{str_tok}|" if i == buffer else str_tok
                    for i, str_tok in enumerate(str_toks_with_context)
                ]  # type: ignore
            )
            data.append([acts[b, s], formatted_seq])

    print(
        tabulate(
            sorted(data, key=lambda x: x[0], reverse=True)[:k],
            headers=["Top Activation", "Example"],
            tablefmt="simple_outline",
            floatfmt="+.3f",
        )
    )


show_max_activating_examples(model, feature_idx=FEATURE_IDX, total_batches=1)

100%|██████████| 1/1 [01:21<00:00, 81.92s/it]

┌──────────────────┬───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│   Top Activation │ Example                                                                                                               │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│           +4.174 │ histories, national event reenactments, or| monuments| and statutes — are registered, apologists for                  │
│           +3.039 │ City, said, ��The resolution is a| symbol| with power and meaning in acknowledging a wrong done                       │
│           +2.460 │ bustling metropolis�� status as a living design| museum| that no doubt appeals to most foreigners. For                │
│           +2.371 │ a measurement in InfluxDB, I see time| stamps| such as:\n1491030000000000000\n                                        │
│           +


