# Loading and Analysing Pre-Trained Sparse Autoencoders

## Set Up (Just Run / Not Important)

In [None]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens sae-dashboard
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
from pprint import pp as print

# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False);

In [None]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Pick the compute device in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

# Loading a pretrained Sparse Autoencoder

Below we load a TransformerLens model, a pretrained SAE and a dataset from Hugging Face. To see which SAEs are available, we can use the following function from SAE Lens:


In [None]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
# Gather the attribute dict (`__dict__`) of every entry returned by
# get_pretrained_saes_directory(), keyed by the directory key.
df = pd.DataFrame.from_records({k:v.__dict__ for k,v in get_pretrained_saes_directory().items()})
# Display the transpose — presumably one row per directory entry and one column per
# attribute; confirm against the rendered table.
df.T

In [None]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

# Load GPT-2 small wrapped as a HookedSAETransformer so SAEs can be attached during a run.
model = HookedSAETransformer.from_pretrained("gpt2-small", device = device)

# SAE.from_pretrained returns three things:
#   - the SAE itself,
#   - cfg_dict: whatever config was stored in the HF repo (not the SAE's own config object,
#     but the SAE config can be extracted from it; useful e.g. for building an activation store),
#   - sparsity: the feature sparsities, which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.7.hook_resid_pre", # won't always be a hook point
    device = device
)

The "sae" object is an instance of the SAE (Sparse Autoencoder class). There are many different SAE architectures which may have different weights or activation functions. In order to simplify working with SAEs, SAE Lens handles most of this complexity for you. 

Let's look at the SAE config and understand each of the parameters:
1. `architecture`
2. `d_in`

TODO: Finish writing this

In [None]:
# `print` was rebound to `pprint.pp` in the setup cell, so the config dict is pretty-printed.
print(sae.cfg.__dict__)

Next we need to load in a dataset to work with.

In [None]:
from datasets import load_dataset  
from transformer_lens.utils import tokenize_and_concatenate

# Load the train split of the dataset without streaming.
dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

# Tokenize with the model's tokenizer, using the SAE's context size as the sequence
# length and the SAE's prepend_bos setting to decide whether a BOS token is added.
# NOTE(review): the dataset above is loaded with streaming=False but streaming=True is
# passed here — confirm this mismatch is intended.
token_dataset = tokenize_and_concatenate(
    dataset= dataset,# type: ignore
    tokenizer = model.tokenizer, # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

# Basics: What are SAE Features?

## Opening a feature dashboard on Neuronpedia

In [None]:
from IPython.display import IFrame

# get a random feature from the SAE (uniform over [0, d_sae); no seed is set in the
# visible cells, so a different feature is picked on each run)
feature_idx = torch.randint(0, sae.cfg.d_sae, (1,)).item()

# URL pattern for an embeddable Neuronpedia dashboard. The three positional slots are
# filled, in order, with the values passed to get_html below.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=0):
    """Return the Neuronpedia embed URL for one feature.

    Args:
        sae_release: value for the first path segment of the URL.
        sae_id: value for the second path segment of the URL.
        feature_idx: value for the third path segment (the feature index).

    Returns:
        `html_template` with the three slots filled positionally.
    """
    url_parts = (sae_release, sae_id, feature_idx)
    return html_template.format(*url_parts)

# Build the dashboard URL for the randomly chosen feature and embed it in the notebook.
html = get_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=feature_idx)
IFrame(html, width=1200, height=600)

What if we wanted to search for a feature relating to a specific thing? Then we could use the explanation search api. Let's just download all the explanations for these SAE features and load them in as a pandas dataframe. The Neuronpedia API docs will be useful here: https://www.neuronpedia.org/api-doc#tag/explanations/POST/api/explanation/export.

In [None]:
import requests

# Ask Neuronpedia for every explanation of the layer-7 residual-stream SAE of GPT-2 small.
url = "https://www.neuronpedia.org/api/explanation/export"

payload = {
    "modelId": "gpt2-small",
    "saeId": "7-res-jb",
}
headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)
# Fail here, with the HTTP status, instead of later with a confusing JSON/KeyError
# when the response body is parsed.
response.raise_for_status()


In [None]:


# Convert the exported explanations to a DataFrame in one chain, so re-running the cell
# always starts again from the raw response rather than mutating a frame in place.
explanations_df = (
    pd.DataFrame(response.json()["explanations"])
    # the feature index is delivered in a column called "index"
    .rename(columns={"index": "feature"})
    # lower-case descriptions for case-insensitive search; a missing description
    # becomes "" instead of raising AttributeError on None
    .assign(description=lambda frame: frame["description"].fillna("").str.lower())
)
explanations_df

Let's search for a feature to do with the bible. 

In [None]:
# Keep the explanations whose (lower-cased) description mentions " bible".
mentions_bible = explanations_df["description"].str.contains(" bible")
bible_features = explanations_df.loc[mentions_bible]
bible_features

In [None]:
# Let's get the dashboard for the first feature whose explanation mentions the bible
html = get_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=bible_features.feature.values[0])
IFrame(html, width=1200, height=600)

# Basics: Getting Features Using SAEs

Let's start by finding the features which fire on a specific prompt.

In [None]:
from transformer_lens.utils import test_prompt

# Prompt whose expected continuation is `answer`; test_prompt reports how the model
# does on that answer for this prompt.
prompt = "In the beginning, God created the heavens and the"
answer = "earth"

test_prompt(prompt, answer, model)

## Method 1: Using HookedSAETransformer

In [None]:
# Inspect the flag. If use_error_term is False, running with the SAE attached modifies
# the forward pass, because the SAE's output is used in place of the original activation.
sae.use_error_term

In [None]:
# hooked SAE Transformer will enable us to get the feature activations from the SAE
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])

# list the name and shape of every cached activation produced by the SAE
print([(k, v.shape) for k,v in cache.items() if "sae" in k])

# note there were 11 tokens in our prompt, the residual stream dimension is 768, and the
# number of SAE features is sae.cfg.d_sae (24576 for this SAE) -- 768 is only the SAE's input width

In [None]:
# Feature activations of the SAE at blocks.7.hook_resid_pre for the final token position.
final_token_acts = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :]

# A figure is only rendered automatically when it is the last expression of a cell.
# This one is followed by more code, so it has to be shown explicitly or it is discarded.
px.line(
    final_token_acts.cpu().numpy(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
).show()

# let's print the top 5 features and how much they fired
vals, inds = torch.topk(final_token_acts, 5)

# iterate over plain Python numbers so the URL and the message never contain tensor reprs
for val, ind in zip(vals.tolist(), inds.tolist()):
    print(f"Feature {ind} fired {val:.2f}")
    html = get_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))

### The Contrast Pairs Trick

Sometimes we may be interested in which features fire differently between two prompts.

In [None]:
from transformer_lens.utils import test_prompt

# Contrast prompt: identical to the earlier one except "heavens" is replaced by "cat".
# The answer checked is still "earth".
prompt = "In the beginning, God created the cat and the"
answer = "earth"

test_prompt(prompt, answer, model)

In [None]:
# Run both prompts as one batch: row 0 is the "heavens" prompt, row 1 the "cat" prompt.
# NOTE(review): indexing position -1 presumably relies on both prompts tokenizing to
# the same length (otherwise it could land on padding) — verify.
prompt = ["In the beginning, God created the heavens and the", "In the beginning, God created the cat and the"]
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
print([(k, v.shape) for k,v in cache.items() if "sae" in k])

# SAE feature activations at the last position of each prompt, as a (2, d_sae) array.
last_token_acts = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][:, -1, :].cpu().numpy()

feature_activation_df = pd.DataFrame(
    {
        "heavens_and_the": last_token_acts[0],
        "cat_and_the": last_token_acts[1],
    },
    index=[f"feature_{i}" for i in range(sae.cfg.d_sae)],
)
# positive where the feature fired more strongly on the "heavens" prompt
feature_activation_df["diff"] = feature_activation_df["heavens_and_the"] - feature_activation_df["cat_and_the"]

fig = px.line(
    feature_activation_df,
    title="Feature activations for the prompt",
    labels={"index": "Feature", "value": "Activation"},
)
# one tick label per feature would be unreadable, so hide them
fig.update_xaxes(showticklabels=False)
fig.show()

In [None]:
# Rank features by how much their final-token activation differs between the two
# prompts (row 1 minus row 0), regardless of sign, and open the top five on Neuronpedia.
sae_acts = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'].cpu()
diff = sae_acts[1, -1, :] - sae_acts[0, -1, :]
vals, inds = torch.topk(diff.abs(), 5)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} had a difference of {val:.2f}")
    html = get_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))

# Making Feature Dashboards

Basic feature dashboards have 4 main components:
1. Feature Activation Distribution. We report the proportion of tokens a feature fires on (usually between 1 in every 100 and 1 in every 10,000 tokens), and show the distribution of positive activations.
2. Logit weight distribution. This is the projection of the decoder weight onto the unembed and roughly gives us a sense of the tokens promoted by a feature. It's less useful in big models / middle layers.
3. The top 10 and bottom 10 tokens in the logit weight distribution. 
4. **Max Activating Examples**. These are examples of text where the feature fires and usually provide the most information for helping us work out what a feature means. 

*Neuronpedia* is a website that hosts feature dashboards and which runs servers that can run the model and check feature activations. This makes it very convenient to check that a feature fires on the distribution of text you actually think it should fire on.

## Local: Finding Max Activating Examples

In [None]:
# instantiate an activations store, which supplies the batches of tokens used below
from sae_lens import ActivationsStore

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)


In [None]:
def list_flatten(nested_list):
    """Concatenate the items of every inner iterable into one flat list (one level deep)."""
    flat = []
    for inner in nested_list:
        flat.extend(inner)
    return flat

def make_token_df(tokens, len_prefix=5, len_suffix=1, model = model):
    """Build a DataFrame with one row per token, describing it and its local context.

    Args:
        tokens: 2-D token tensor indexed as [prompt, position].
        len_prefix: number of preceding tokens to include in the context string.
        len_suffix: number of following tokens to include in the context string.
        model: object whose `to_str_tokens` turns one row of `tokens` into a list of
            strings (presumably one string per token — the code relies on that).

    Returns:
        DataFrame with columns:
            str_tokens   - the token as a string
            unique_token - "<token>/<position>"
            context      - "<prefix>|<token>|<suffix>"
            prompt       - row index of the token in `tokens`
            pos          - position of the token within its prompt
            label        - "<prompt>/<pos>"
        Rows are ordered prompt by prompt, position by position.
    """
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p-len_prefix):p])
            # Slices clamp at the end of the list, so no explicit bound is needed and the
            # last position naturally gets an empty suffix. The previous upper bound of
            # `tokens.shape[1]-1` was off by one (slice ends are exclusive) and dropped
            # the final token of the prompt from every suffix.
            suffix = "".join(str_tokens[b][p+1:p+1+len_suffix])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        prompt=prompt,
        pos=pos,
        label=label,
    ))

In [None]:

# Finding max activating examples is a bit harder. To do this we need to calculate feature
# activations over many batches of tokens and keep every token on which at least one of
# the selected features fired.
# NOTE(review): the same loop is repeated later in the notebook for a different
# feature_list — a candidate for a shared helper function.

# 100 feature indices sampled uniformly with replacement (no seed is set in the visible cells)
feature_list = torch.randint(0, sae.cfg.d_sae, (100,))
examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
# number of tokens processed per batch; used later to turn counts into proportions
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
for i in pbar:
    tokens = activation_store.get_batch_tokens()
    tokens_df = make_token_df(tokens)
    tokens_df["batch"] = i
    
    flat_tokens = tokens.flatten()
    
    # only cache the SAE's hook point, and stop the forward pass right after its layer
    _, cache = model.run_with_cache(tokens, stop_at_layer = sae.cfg.hook_layer + 1, names_filter = [sae.cfg.hook_name])
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    # merge the first two dims (presumably batch and position) so rows line up with flat_tokens
    feature_acts = feature_acts.flatten(0,1)
    # a token counts as "fired" when the summed activation of the selected features is positive
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
    # decoder output using only the selected features' activations and decoder rows
    reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

    # rows of tokens_df at the positions where the mask is True
    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
    all_fired_tokens.append(fired_tokens)
    all_reconstructions.append(reconstruction)
    
    examples_found += len(fired_tokens)
    # print(f"Examples found: {examples_found}")
    # update description
    pbar.set_description(f"Examples found: {examples_found}")
    
# flatten the list of lists
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)

## Getting Feature Activation Histogram

In [None]:
# One column per selected feature (named after its index in the SAE), one row per
# token on which at least one selected feature fired.
feature_columns = [f"feature_{i}" for i in feature_list]
feature_acts_df = pd.DataFrame(all_feature_acts.cpu().numpy(), columns=feature_columns)
feature_acts_df.shape

In [None]:
# Here feature_idx is a column position within feature_list, not an SAE feature id.
feature_idx = 0

# keep only the strictly positive activations of that column
selected_acts = all_feature_acts[:, feature_idx]
all_positive_acts = selected_acts[selected_acts > 0]
# percentage of all processed tokens on which the feature was active
prop_positive_activations = 100*len(all_positive_acts) / (total_batches*batch_size_tokens)

px.histogram(
    all_positive_acts.cpu(),
    nbins=50,
    title=f"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive",
    labels={"value": "Activation"},
    width=800,
)

In [None]:
# The ten rows with the largest activation of the first selected feature.
top_10_activations = feature_acts_df.sort_values(by=f"feature_{feature_list[0]}", ascending=False).iloc[:10]
# feature_acts_df has a default 0..N-1 index and its rows are in the same order as
# all_token_dfs, so its index labels can be used as positions here.
all_token_dfs.iloc[top_10_activations.index]

## Getting the Top 10 Logit Weights

In [None]:
# Project every decoder direction onto the unembedding to see which tokens each
# feature pushes up.
print(f"Shape of the decoder weights {sae.W_dec.shape}")
print(f"Shape of the model unembed {model.W_U.shape}")
projection_matrix = sae.W_dec @ model.W_U
print(f"Shape of the projection matrix {projection_matrix.shape}")

# then we take the top_k tokens per feature and decode them
top_k = 10
# let's do this for the features in feature_list
_, top_k_tokens = torch.topk(projection_matrix[feature_list], top_k, dim=1)

# One column per feature, one row per rank. The columns must be labelled from
# feature_list, the features that were actually projected (the previous name,
# `random_features`, is not defined anywhere and raised a NameError).
feature_df = pd.DataFrame(top_k_tokens.cpu().numpy(), index = [f"feature_{i}" for i in feature_list]).T
feature_df.index = [f"token_{i}" for i in range(top_k)]

# DataFrame.applymap was renamed to DataFrame.map in pandas 2.1; use whichever exists.
elementwise = getattr(feature_df, "map", None) or feature_df.applymap
elementwise(lambda x: model.tokenizer.decode(x))

## Putting it all together: Compare against the Neuronpedia Dashboard

In [None]:
# Open the Neuronpedia dashboard for the first feature in feature_list, the same
# feature whose top activations were tabulated above, so the two can be compared.
html = get_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=feature_list[0])
IFrame(html, width=1200, height=600)

# Advanced: Co-occurrence Networks and Irreducible Subspaces

Since we just wrote code very similar to the code we need for reproducing some of the analysis from "Not All Language Model Features Are Linear", we show below how to regenerate their awesome circular representation.

In [None]:
# only valid for res-jb resid_pre 7. 
day_of_the_week_features = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]
# months_of_the_year = [3977, 4140, 5993, 7299, 9104, 9401, 10449, 11196, 12661, 14715, 17068, 17528, 19589, 21033, 22043, 23304]
# years_of_10th_century = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]

# From here on this cell repeats the max-activating-examples loop from earlier in the
# notebook, restricted to the hand-picked features above.
# NOTE(review): this is a copy of that loop — a candidate for a shared helper function.
feature_list = day_of_the_week_features

examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
for i in pbar:
    tokens = activation_store.get_batch_tokens()
    tokens_df = make_token_df(tokens)
    tokens_df["batch"] = i
    
    flat_tokens = tokens.flatten()
    
    # only cache the SAE's hook point, and stop the forward pass right after its layer
    _, cache = model.run_with_cache(tokens, stop_at_layer = sae.cfg.hook_layer + 1, names_filter = [sae.cfg.hook_name])
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    # merge the first two dims (presumably batch and position) so rows line up with flat_tokens
    feature_acts = feature_acts.flatten(0,1)
    # a token counts as "fired" when the summed activation of the selected features is positive
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
    # decoder output restricted to the selected features: the subspace analysed with PCA afterwards
    reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
    all_fired_tokens.append(fired_tokens)
    all_reconstructions.append(reconstruction)
    
    examples_found += len(fired_tokens)
    # print(f"Examples found: {examples_found}")
    # update description
    pbar.set_description(f"Examples found: {examples_found}")
    
# flatten the list of lists
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)

In [None]:
# do PCA on reconstructions
from sklearn.decomposition import PCA
import plotly.express as px 

# Reduce the restricted reconstructions to three principal components.
pca = PCA(n_components=3)
pca_embedding = pca.fit_transform(all_reconstructions.cpu().numpy())

# Attach the fired token and its surrounding context to every embedded point.
pca_df = pd.DataFrame(pca_embedding, columns=["PC1", "PC2", "PC3"]).assign(
    tokens=all_fired_tokens,
    context=all_token_dfs.context.values,
)

# Plot the second against the third component, coloured by token.
scatter_options = dict(
    x="PC2",
    y="PC3",
    hover_data=["context"],
    hover_name="tokens",
    height=800,
    width=1200,
    color="tokens",
    title="PCA Subspace Reconstructions",
)
px.scatter(pca_df, **scatter_options).show()



# Basics: Intervening on SAE Features

## Feature Ablation

## Feature Steering

In [None]:
import requests
import numpy as np

# Ask the Neuronpedia steering endpoint to continue a prompt while one SAE feature is
# boosted.
url = "https://www.neuronpedia.org/api/steer"

payload = {
    "prompt": "A knight in shining",
    # other prompts to try:
    # "prompt": "He had to fight back in self-", 
    # "prompt": "In the middle of the universe is the galactic",
    # "prompt": "Oh no. We're running on empty. Its time to fill up the car with",
    # "prompt": "Sure, I'm happy to pay. I don't have any cash on me but let me write you a",
    "modelId": "gpt2-small",
    # the feature(s) to steer with
    "features": [
        {
            "modelId": "gpt2-small",
            "layer": "7-res-jb",
            "index": 6770,
            "strength": 8
        }
    ],
    "temperature": 0.2,
    "n_tokens": 2,
    "freq_penalty": 1,
    # a fresh random seed per run, so repeated runs can give different completions
    "seed": np.random.randint(100),
    "strength_multiplier": 4
}
headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)
# Surface HTTP errors here rather than as a confusing failure when parsing the body.
response.raise_for_status()

print(response.json())

In [None]:
import requests
import numpy as np
# Same steering request as before, with a different feature, a longer completion and a
# higher sampling temperature.
url = "https://www.neuronpedia.org/api/steer"

payload = {
    # typo "girlfiend" fixed: the prompt is fed to the model verbatim
    "prompt": "I wrote a letter to my girlfriend. It said \"",
    "modelId": "gpt2-small",
    # the feature(s) to steer with
    "features": [
        {
            "modelId": "gpt2-small",
            "layer": "7-res-jb",
            "index": 20115,
            "strength": 4
        }
    ],
    "temperature": 0.7,
    "n_tokens": 120,
    "freq_penalty": 1,
    # a fresh random seed per run, so repeated runs can give different completions
    "seed": np.random.randint(100),
    "strength_multiplier": 4
}
headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)
# Surface HTTP errors here rather than as a confusing failure when parsing the body.
response.raise_for_status()

print(response.json())

# Feature Attribution

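Let's (re)load the SAE we will use for attribution. The cell below loads the same layer-7 residual-stream SAE as before; change `sae_id` if you want to run attribution on an earlier layer instead.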

In [None]:
# Load the SAE used for attribution. Release and sae_id are the same as the SAE loaded
# at the top of the notebook; edit sae_id to attribute through a different hook point.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.7.hook_resid_pre", # won't always be a hook point
    device = device
)

In [None]:
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple, Callable

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint


class SaeReconstructionCache(NamedTuple):
    """Tensors recorded for one SAE at one hook point during a forward pass."""

    # activation that was passed into the SAE at the hook point
    sae_in: torch.Tensor
    # result of sae.encode(sae_in)
    feature_acts: torch.Tensor
    # result of sae.decode(feature_acts)
    sae_out: torch.Tensor
    # detached copy of (sae_in - sae_out)
    sae_error: torch.Tensor


def track_grad(tensor: torch.Tensor) -> None:
    """Mark `tensor` as requiring grad and keep its `.grad` after backward, in place."""
    # requires_grad_ returns the tensor itself, so the two calls can be chained
    tensor.requires_grad_(True).retain_grad()


@dataclass
class ApplySaesAndRunOutput:
    """Result of running the model with SAEs spliced in."""

    # what the model returned for the requested return type
    model_output: torch.Tensor
    # tracked model activations, keyed by hook point
    model_activations: dict[str, torch.Tensor]
    # recorded SAE tensors, keyed by hook point
    sae_activations: dict[str, SaeReconstructionCache]

    def zero_grad(self) -> None:
        """Reset `.grad` to None on every tensor held by this object."""
        self.model_output.grad = None
        for activation in self.model_activations.values():
            activation.grad = None
        for sae_cache in self.sae_activations.values():
            # a NamedTuple iterates over its fields in declaration order:
            # sae_in, feature_acts, sae_out, sae_error
            for recorded in sae_cache:
                recorded.grad = None


def apply_saes_and_run(
    model: HookedTransformer,
    saes: dict[str, SAE],
    input: Any,
    include_error_term: bool = True,
    track_model_hooks: list[str] | None = None,
    return_type: Literal["logits", "loss"] = "logits",
    track_grads: bool = False,
) -> ApplySaesAndRunOutput:
    """
    Apply the SAEs to the model at the specific hook points, and run the model.
    By default, this will include a SAE error term which guarantees that the SAE
    will not affect model output. This function is designed to work correctly with
    backprop as well, so it can be used for gradient-based feature attribution.

    Args:
        model: the model to run
        saes: the SAEs to apply, keyed by the hook point each one is applied at
        input: the input to the model
        include_error_term: whether to include the SAE error term to ensure the SAE doesn't affect model output. Default True
        track_model_hooks: a list of hook points to record the activations and gradients. Default None
        return_type: this is passed to the model call. Default "logits"
        track_grads: whether to track gradients. Default False

    Returns:
        An ApplySaesAndRunOutput holding the model output, the activations recorded at
        `track_model_hooks`, and one SaeReconstructionCache per SAE hook point.
    """

    fwd_hooks = []
    bwd_hooks = []

    # filled in by the hooks below while the model runs
    sae_activations: dict[str, SaeReconstructionCache] = {}
    model_activations: dict[str, torch.Tensor] = {}

    # this hook just track the SAE input, output, features, and error. If `track_grads=True`, it also ensures
    # that requires_grad is set to True and retain_grad is called for intermediate values.
    def reconstruction_hook(sae_in: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        sae = saes[hook_point]
        feature_acts = sae.encode(sae_in)
        sae_out = sae.decode(feature_acts)
        # the error is detached, so it enters the graph as a constant rather than as a function of sae_in
        sae_error = (sae_in - sae_out).detach().clone()
        if track_grads:
            track_grad(sae_error)
            track_grad(sae_out)
            track_grad(feature_acts)
            track_grad(sae_in)
        sae_activations[hook_point] = SaeReconstructionCache(
            sae_in=sae_in,
            feature_acts=feature_acts,
            sae_out=sae_out,
            sae_error=sae_error,
        )

        # the value returned here replaces the activation at the hook point
        if include_error_term:
            return sae_out + sae_error
        return sae_out

    def sae_bwd_hook(output_grads: torch.Tensor, hook: HookPoint):  # noqa: ARG001
        # this just passes the output grads to the input, so the SAE gets the same grads despite the error term hackery
        return (output_grads,)

    # this hook just records model activations, and ensures that intermediate activations have gradient tracking turned on if needed
    def tracking_hook(hook_input: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        model_activations[hook_point] = hook_input
        if track_grads:
            track_grad(hook_input)
        return hook_input

    # one forward (reconstruction) hook and one backward hook per SAE hook point
    for hook_point in saes.keys():
        fwd_hooks.append(
            (hook_point, partial(reconstruction_hook, hook_point=hook_point))
        )
        bwd_hooks.append((hook_point, sae_bwd_hook))
    # plus a recording hook for every extra model hook point requested
    for hook_point in track_model_hooks or []:
        fwd_hooks.append((hook_point, partial(tracking_hook, hook_point=hook_point)))

    # now, just run the model while applying the hooks
    with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        model_output = model(input, return_type=return_type)

    return ApplySaesAndRunOutput(
        model_output=model_output,
        model_activations=model_activations,
        sae_activations=sae_activations,
    )

In [None]:
from dataclasses import dataclass
from transformer_lens.hook_points import HookPoint
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint

# Small constant added to the denominator of the error-proportion ratio so it is never zero.
EPS = 1e-8

# Gradients were switched off globally in the setup cell; attribution needs backprop,
# so they are switched back on here (and stay on for the cells that follow).
torch.set_grad_enabled(True)
@dataclass
class AttributionGrads:
    """Output of a model run after `metric.backward()` has been called."""

    # value returned by metric_fn on the model output
    metric: torch.Tensor
    # what the model returned
    model_output: torch.Tensor
    # tracked model activations, keyed by hook point
    model_activations: dict[str, torch.Tensor]
    # recorded SAE tensors, keyed by hook point
    sae_activations: dict[str, SaeReconstructionCache]


@dataclass
class Attribution:
    """Attribution results; every dict is keyed by hook point name."""

    # grad * activation for tracked model activations
    model_attributions: dict[str, torch.Tensor]
    # detached copies of the tracked model activations
    model_activations: dict[str, torch.Tensor]
    # detached copies of the gradients of the tracked model activations
    model_grads: dict[str, torch.Tensor]
    # grad * activation for SAE feature activations
    sae_feature_attributions: dict[str, torch.Tensor]
    # detached copies of the SAE feature activations
    sae_feature_activations: dict[str, torch.Tensor]
    # detached copies of the gradients of the SAE feature activations
    sae_feature_grads: dict[str, torch.Tensor]
    # ||grad of sae_error|| / (||grad of sae_out|| + ||grad of sae_error|| + EPS)
    sae_errors_attribution_proportion: dict[str, float]


def calculate_attribution_grads(
    model: HookedSAETransformer,
    prompt: str,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> AttributionGrads:
    """
    Run the model through apply_saes_and_run with gradient tracking on, evaluate
    `metric_fn` on the model output and backpropagate from it.

    After this returns, the tracked model activations and the recorded SAE tensors
    carry the gradients of the metric in their `.grad` attributes.

    Args:
        model: the model to run.
        prompt: the input handed to the model.
        metric_fn: maps the model output to the value that is backpropagated.
        track_hook_points: extra model hook points to record. Default None.
        include_saes: SAEs to splice in, keyed by hook point. Default None (no SAEs).
        return_logits: request logits if True, loss otherwise.
        include_error_term: whether the SAE error term is added back.

    Returns:
        AttributionGrads with the metric, the model output and the recorded activations.
    """
    requested_output = "logits" if return_logits else "loss"
    saes_to_apply = include_saes if include_saes else {}

    run = apply_saes_and_run(
        model,
        saes=saes_to_apply,
        input=prompt,
        return_type=requested_output,
        track_model_hooks=track_hook_points,
        include_error_term=include_error_term,
        track_grads=True,
    )

    metric = metric_fn(run.model_output)
    # clear any existing gradients before backpropagating from the metric
    run.zero_grad()
    metric.backward()

    return AttributionGrads(
        metric=metric,
        model_output=run.model_output,
        model_activations=run.model_activations,
        sae_activations=run.sae_activations,
    )


def calculate_feature_attribution(
    model: HookedSAETransformer,
    input: Any,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> Attribution:
    """
    Compute gradient-times-activation attributions for SAE features and model
    activations, following the procedure in
    https://transformer-circuits.pub/2024/march-update/index.html#feature-heads.
    The SAE error term is included by default, so inserting the SAE does not change
    the model output; pass `include_error_term=False` to disable that.

    Args:
        model: The model to calculate feature attribution for.
        input: The input to the model.
        metric_fn: A function that takes the model output and returns a scalar metric.
        track_hook_points: A list of model hook points to track activations for, if desired.
        include_saes: SAEs to include, keyed by the hook point each is applied to.
        return_logits: Whether the model returns logits or loss; should match what metric_fn expects (probably logits).
        include_error_term: Whether to include the SAE error term. Recommended, as it keeps the model output unaffected by the SAE.

    Returns:
        An Attribution whose dicts are keyed by hook point name.
    """
    # Step 1: forward + backward pass, leaving the gradients of the metric on the recorded tensors.
    grads = calculate_attribution_grads(
        model,
        input,
        metric_fn,
        track_hook_points,
        include_saes=include_saes,
        return_logits=return_logits,
        include_error_term=include_error_term,
    )

    # Step 2: fill the result dictionaries. Everything stored is a detached copy.
    result = Attribution(
        model_attributions={},
        model_activations={},
        model_grads={},
        sae_feature_attributions={},
        sae_feature_activations={},
        sae_feature_grads={},
        sae_errors_attribution_proportion={},
    )
    with torch.no_grad():
        for hook_name, activation in grads.model_activations.items():
            assert activation.grad is not None
            activation_value = activation.detach().clone()
            result.model_attributions[hook_name] = (
                (activation.grad * activation_value).detach().clone()
            )
            result.model_activations[hook_name] = activation_value
            result.model_grads[hook_name] = activation.grad.detach().clone()

        for hook_name, sae_cache in grads.sae_activations.items():
            features = sae_cache.feature_acts
            assert features.grad is not None
            assert sae_cache.sae_out.grad is not None
            feature_value = features.detach().clone()
            result.sae_feature_attributions[hook_name] = (
                (features.grad * feature_value).detach().clone()
            )
            result.sae_feature_activations[hook_name] = feature_value
            result.sae_feature_grads[hook_name] = features.grad.detach().clone()

            # Share of the gradient norm that flows through the error term rather than
            # through the SAE output; zero by construction when the error term is left out.
            error_grad_norm = 0
            if include_error_term:
                assert sae_cache.sae_error.grad is not None
                error_grad_norm = sae_cache.sae_error.grad.norm().item()
            sae_out_norm = sae_cache.sae_out.grad.norm().item()
            result.sae_errors_attribution_proportion[hook_name] = error_grad_norm / (
                sae_out_norm + error_grad_norm + EPS
            )
    return result
        
        
# prompt = " Tiger Woods plays the sport of"
# pos_token = model.tokenizer.encode(" golf")[0]
prompt = "In the beginning, God created the heavens and the"
# presumably each encode() call returns a list of token ids — here expected to hold a
# single id each; verify if the answer strings are changed
pos_token = model.tokenizer.encode(" earth")
neg_token = model.tokenizer.encode(" sky")
def metric_fn(logits: torch.Tensor, pos_token: list[int] =pos_token, neg_token: list[int]=neg_token) -> torch.Tensor:
    """Logit difference (pos_token minus neg_token) at the last position of the first batch item."""
    return logits[0,-1,pos_token] - logits[0,-1,neg_token]

# NOTE: despite the `_df` suffix this is an Attribution object, not a DataFrame.
feature_attribution_df = calculate_feature_attribution(
    input = prompt,
    model = model,
    metric_fn = metric_fn,
    include_saes={sae.cfg.hook_name: sae},
    include_error_term=True,
    return_logits=True,
)


In [None]:
from transformer_lens.utils import test_prompt
# Check how the model does on the positive answer token for the attribution prompt.
test_prompt(prompt, model.to_string(pos_token), model)

In [None]:
# Label each prompt token as "<position>/<token>" so repeated tokens stay distinct.
tokens = model.to_str_tokens(prompt)
unique_tokens = [f"{i}/{t}" for i, t in enumerate(tokens)]

# Total attribution at each position: sum over the feature axis for the first batch item.
per_token_attribution = (
    feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0]
    .sum(-1)
    .detach()
    .cpu()
    .numpy()
)

px.bar(x = unique_tokens, y = per_token_attribution)

In [None]:
def convert_sparse_feature_to_long_df(sparse_tensor: torch.Tensor) -> pd.DataFrame:
    """
    Turn a 2-D, mostly-zero tensor into a long-format DataFrame of its non-zero entries.

    Rows of the tensor become `position`, columns become `feature`, and each non-zero
    value becomes `attribution`. Entries are listed column by column.
    """
    wide = pd.DataFrame(sparse_tensor.detach().cpu().numpy())
    # ignore_index=False keeps the row number, which becomes the position below
    long_df = wide.melt(ignore_index=False, var_name="feature", value_name="attribution")
    nonzero_rows = long_df.loc[long_df["attribution"] != 0]
    return nonzero_rows.reset_index().rename(columns={"index": "position"})

# Long-format table of the non-zero feature attributions for the first batch item,
# displayed with the largest attribution first.
df_long_nonzero = convert_sparse_feature_to_long_df(feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0])
df_long_nonzero.sort_values("attribution", ascending=False)

In [None]:
# Top five features by summed attribution at token position 8 only.
# NOTE(review): 8 is hard-coded — presumably the position of " heavens" in the prompt; confirm.
top_at_position_8 = (
    df_long_nonzero.query("position==8")
    .groupby("feature")["attribution"]
    .sum()
    .sort_values(ascending=False)
    .head(5)
)
for feature_id, total_attribution in top_at_position_8.items():
    print(f"Feature {feature_id} had a total attribution of {total_attribution:.2f}")
    html = get_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=int(feature_id))
    display(IFrame(html, width=1200, height=300))

In [None]:
# Top five features by attribution summed over every token position.
top_overall = (
    df_long_nonzero.groupby("feature")["attribution"]
    .sum()
    .sort_values(ascending=False)
    .head(5)
)
for feature_id, total_attribution in top_overall.items():
    print(f"Feature {feature_id} had a total attribution of {total_attribution:.2f}")
    html = get_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=int(feature_id))
    display(IFrame(html, width=1200, height=300))

# Advanced: Making U-Maps