# Evaluating SAEs with SAE Lens Evals

## Overview

In this tutorial we will cover the use of our evaluation tools on SAEs. As there is no single metric for SAE quality, we include a number of metrics from multiple categories to help you assess whether an SAE meets your usage requirements. In many cases, you will want to compare multiple SAEs for a given model layer in order to determine which performs best on specific metrics. Below, we explain how to run the evaluations, how each metric is calculated, and how to interpret the values shown.

## Imports & Installs

In [None]:
try:
    import google.colab  # type: ignore
    from google.colab import output

    COLAB = True
    %pip install sae-lens transformer-lens sae-dashboard
except ImportError:
    COLAB = False
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import json
import numpy as np

torch.set_grad_enabled(False);

## Using the Command Line Utility



To see which SAEs are available for a specified model, you can simply run the following code snippet. For running evals, the primary things to pay attention to are the release name and the format used in `saes_map` to specify the target layer. These will be used as input for the evals runner.

In [None]:
import pandas as pd
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T
df.drop(
    columns=[
        "expected_var_explained",
        "expected_l0",
        "config_overrides",
        "conversion_func",
    ],
    inplace=True,
)
df[
    df.release.str.contains("gpt2")
]  # Each row is a "release" which has multiple SAEs which may have different configs / match different hook points in a model.

Let's run evals for 3 SAEs:
- Apollo's End-to-End SAEs with downstream reconstruction loss (`gpt2-small-res_scefr-ajt`)
- OpenAI's TopK SAEs (`gpt2-small-resid-mid-v5-32k`)
- Joseph Bloom's Residual Stream SAEs (`gpt2-small-res-jb`)

These SAEs are not perfectly comparable because they are not all the same width (32K for OpenAI vs. 24K for the others), and they are not all trained at the same part of the layer (resid-mid for OpenAI TopK, resid-pre for Joseph Bloom's and Apollo's), but they will suffice for our tutorial.

Currently, the evals are designed to be run through the CLI. Simply copy and paste the command below. (Note: Omit the `poetry run` portion if you are not using a poetry environment.)

```bash
poetry run python sae_lens/evals.py "gpt2-small-res_scefr-ajt|gpt2-small-resid-mid-v5-32k|gpt2-small-res-jb" "blocks.10.*" \
    --batch_size_prompts 16 \
    --n_eval_sparsity_variance_batches 200 \
    --n_eval_reconstruction_batches 20 \
    --output_dir "demo_eval_results" \
    --verbose
```

`evals.py` accepts a wide range of arguments, most of which are used to control whether to evaluate a particular metric. You can see the full list in the source code, but the most important arguments are the ones we've used here:
- `"gpt2-small-res_scefr-ajt|gpt2-small-resid-mid-v5-32k|gpt2-small-res-jb"` This is a regex pattern used to match SAE release names. The script allows you to specify multiple or just one.
- `"blocks.10.*"` This is another regex pattern for SAE block names. This particular pattern applies to each of the releases we're targeting, so layer 10 (resid-pre or resid-mid) will be evaluated for each of them.
- `--batch_size_prompts 16` Sets the batch size for the dataset of evaluation prompts to 16.
- `--n_eval_sparsity_variance_batches 200` The total number of evaluation prompt batches used when evaluating sparsity and variance metrics.
- `--n_eval_reconstruction_batches 20` The number of prompt batches used for reconstruction evaluation.
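
To make the two regex arguments more concrete, here is a minimal sketch of how such patterns select releases and hook points. The candidate lists are illustrative, and the use of `re.fullmatch` is an assumption about the matching semantics rather than a description of what `evals.py` does internally.

```python
import re

# Illustrative candidates; the real names come from the pretrained SAEs directory.
available_releases = [
    "gpt2-small-res-jb",
    "gpt2-small-res_scefr-ajt",
    "gpt2-small-resid-mid-v5-32k",
    "some-other-release",  # hypothetical name that should not match
]
available_hooks = [
    "blocks.9.hook_resid_pre",
    "blocks.10.hook_resid_pre",
    "blocks.10.hook_resid_mid",
]

release_pattern = (
    "gpt2-small-res_scefr-ajt|gpt2-small-resid-mid-v5-32k|gpt2-small-res-jb"
)
hook_pattern = "blocks.10.*"

# Assumption: a name is selected when the whole name matches the pattern.
selected_releases = [
    name for name in available_releases if re.fullmatch(release_pattern, name)
]
selected_hooks = [name for name in available_hooks if re.fullmatch(hook_pattern, name)]

print(selected_releases)
print(selected_hooks)
```

Note that an unescaped `.` in a regex matches any character, so `blocks.10.*` is slightly more permissive than it looks; `blocks\.10\..*` would be the strict form.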

Once the evals are run, we can open them:

In [None]:
output_dir = "../demo_eval_results"

# list files in the output directory
print(os.listdir(output_dir))

An example of the results is printed below. The evaluation script produces a list of nested dictionaries, with one item per unique SAE (release and SAE ID combination). The dictionary for each individual SAE is saved as its own file, and a file with all results combined is saved as well. Within each nested dictionary, we store metadata about the SAE, the `eval_cfg` that was used to generate the results, a dictionary containing SAE-level metrics (`metrics`), another dictionary containing feature-level metrics (`feature_metrics`), and finally the SAE config, which contains information about the SAE itself, its training, the dataset used for evaluation, and various other details.

In [None]:
# get the first file in the output directory
first_file = os.path.join(output_dir, os.listdir(output_dir)[0])
print(first_file)

# load the results JSON and pretty-print it
with open(first_file, "r") as f:
    results = json.load(f)

print(json.dumps(results["sae_set"], indent=4))
print(json.dumps(results["eval_cfg"], indent=4))
print(json.dumps(results["metrics"], indent=4))
# print(json.dumps(results['feature_metrics'], indent=4)) # uncomment to see feature metrics

In [None]:
print(results["metrics"].keys())
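
When comparing several SAEs, it can be convenient to flatten a list of such result dictionaries into a table. The sketch below uses made-up toy entries (the release names and metric values are placeholders, not real results) and relies only on the `sae_set`, `sae_id`, and `metrics` keys described above.

```python
import pandas as pd

# Toy stand-ins for entries of the combined results file; values are invented.
toy_results = [
    {
        "sae_set": "release-a",
        "sae_id": "blocks.10.hook_resid_pre",
        "metrics": {"l0": 60.0, "mse": 1.5},
    },
    {
        "sae_set": "release-b",
        "sae_id": "blocks.10.hook_resid_mid",
        "metrics": {"l0": 32.0, "mse": 0.9},
    },
]

# One row per SAE, one column per SAE-level metric.
rows = [
    {"sae_set": r["sae_set"], "sae_id": r["sae_id"], **r["metrics"]}
    for r in toy_results
]
comparison = pd.DataFrame(rows).set_index(["sae_set", "sae_id"])
print(comparison)
```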

### Evaluation Configuration

The `eval_cfg` dictionary shows what parameters were used to generate the evaluation. The Boolean items are used to set whether to run particular evaluations; by default, all are set to True. Some other parameters of note:
- `context_size`: The length of the prompts used in the evaluation.
- `dataset`: This is a path to the HuggingFace dataset used for evaluation. Often, this will be the same dataset the SAE was trained on, but this is not always the case--sometimes that dataset is not publicly available.
- `library_version`: The version of SAE Lens used.
- `git_hash`: The hash of the specific commit used.
- `batch_size_prompts`: Number of prompts per batch, for each batch of the evaluation runs.
- `n_eval_reconstruction_batches`: Number of batches used to evaluate reconstruction metrics. This number sometimes needs tuning to ensure that you are getting sufficient data and some level of metric stability.
- `n_eval_sparsity_variance_batches`: Number of batches used to evaluate sparsity and variance metrics. This number needs tuning as well; e.g. if it is too low, the number of dead neurons will appear artificially high, and the feature density histogram will appear less smooth. Generally, this number should be much higher than the number of reconstruction batches.
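
As a rough guide when tuning the batch counts, the number of tokens seen can be estimated from the three settings above. This is a sketch that assumes every prompt contributes exactly `context_size` tokens; the context size of 128 is also an assumption (the results filename used later in this tutorial contains `128`).

```python
def estimated_tokens(n_batches: int, batch_size_prompts: int, context_size: int) -> int:
    """Estimate tokens evaluated, assuming every prompt is exactly context_size long."""
    return n_batches * batch_size_prompts * context_size


context_size = 128  # assumed value

sparsity_tokens = estimated_tokens(200, 16, context_size)
reconstruction_tokens = estimated_tokens(20, 16, context_size)

# A feature that fires exactly once has this density, which is the smallest
# non-zero density the sparsity run can resolve.
min_resolvable_density = 1 / sparsity_tokens

print(sparsity_tokens, reconstruction_tokens, min_resolvable_density)
```

Under these assumptions, any feature whose true density is well below `min_resolvable_density` will often look dead, which is why too few sparsity batches inflates the dead-feature count.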

## Making Sense of Results

### Understanding SAE Metrics

The SAE metrics we compute can be grouped into six distinct categories. Let's explore each category and its associated metrics in detail. We'll then compare our three SAEs on each metric.


#### Model Performance Preservation

These metrics indicate how much the SAE affects the underlying model's performance and output:

- **ce_loss_with_sae**: Cross-entropy loss of the model's output after applying the SAE.
- **ce_loss_without_sae**: Baseline cross-entropy loss of the original model without the SAE.
- **ce_loss_with_ablation**: Cross-entropy loss when the relevant activations are set to zero.
- **ce_loss_score**: A derived metric comparing the cross-entropy losses with and without the SAE.
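
The text above describes `ce_loss_score` only as a derived metric. One common normalization, assumed here rather than taken from the SAE Lens source, rescales the loss with the SAE so that 1 means "as good as the original model" and 0 means "as bad as zero ablation". The loss values below are illustrative, not measured.

```python
def ce_loss_score(
    ce_loss_with_sae: float,
    ce_loss_without_sae: float,
    ce_loss_with_ablation: float,
) -> float:
    """Fraction of the ablation-induced loss increase that the SAE recovers (assumed form)."""
    return (ce_loss_with_ablation - ce_loss_with_sae) / (
        ce_loss_with_ablation - ce_loss_without_sae
    )


# Illustrative numbers only.
score = ce_loss_score(
    ce_loss_with_sae=3.2,
    ce_loss_without_sae=3.0,
    ce_loss_with_ablation=11.0,
)
print(round(score, 3))
```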

#### Model Behavior Preservation

These metrics indicate differences between the distributions of logit predictions with and without the SAE.

- **kl_div_score**: A derived metric comparing the KL divergence with SAE to the KL divergence with ablation.
- **kl_div_with_sae**: Kullback-Leibler divergence between the original model's output distribution and the distribution after applying the SAE. Lower values indicate better preservation of the model's behavior.
- **kl_div_with_ablation**: KL divergence between the original model's output and the output when the relevant activations are set to zero. This serves as a baseline for comparison.
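
A small NumPy sketch of these quantities on random toy logits may help. The KL computation is standard; the normalization used for `kl_div_score` is an assumption (a score of 1 would mean no divergence, 0 would mean as divergent as ablation) and may differ from the library's exact definition.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def kl_divergence(original_logits: np.ndarray, modified_logits: np.ndarray) -> float:
    """Mean over positions of KL(original || modified)."""
    log_p = log_softmax(original_logits)
    log_q = log_softmax(modified_logits)
    return float((np.exp(log_p) * (log_p - log_q)).sum(axis=-1).mean())


rng = np.random.default_rng(0)
original = rng.normal(size=(4, 10))  # toy logits: 4 positions, vocabulary of 10
with_sae = original + 0.1 * rng.normal(size=original.shape)  # small perturbation
with_ablation = rng.normal(size=original.shape)  # unrelated logits

kl_div_with_sae = kl_divergence(original, with_sae)
kl_div_with_ablation = kl_divergence(original, with_ablation)

# Assumed normalization of the derived score.
kl_div_score = (kl_div_with_ablation - kl_div_with_sae) / kl_div_with_ablation
print(kl_div_with_sae, kl_div_with_ablation, kl_div_score)
```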

#### Reconstruction Quality

These metrics assess how well the SAE can reconstruct the original input activations at the target layer:

- **mse**: Mean Squared Error between the input and the reconstruction. Lower values indicate better reconstruction accuracy.
- **explained_variance**: The proportion of variance in the input that is explained by the SAE's reconstruction. Higher values indicate better reconstruction quality.
- **cossim**: Cosine similarity between the input and the reconstruction. Values closer to 1 indicate better preservation of the input's direction.
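
The three reconstruction metrics can be sketched on toy activations as follows. The explained-variance formula used here (one minus residual variance over total variance) is a common definition and is assumed rather than confirmed against the library; note that it can go negative when the reconstruction is worse than predicting the mean.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 16))  # toy "input activations"
x_hat = x + 0.1 * rng.normal(size=x.shape)  # toy "reconstruction" with small noise

# Mean squared error over all elements.
mse = ((x - x_hat) ** 2).mean()

# Explained variance: 1 - residual variance / total variance (assumed definition).
residual_variance = ((x - x_hat) ** 2).sum(axis=-1).mean()
total_variance = ((x - x.mean(axis=0)) ** 2).sum(axis=-1).mean()
explained_variance = 1 - residual_variance / total_variance

# Cosine similarity per token, then averaged.
cossim = (
    (x * x_hat).sum(axis=-1)
    / (np.linalg.norm(x, axis=-1) * np.linalg.norm(x_hat, axis=-1))
).mean()

print(mse, explained_variance, cossim)
```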

#### Shrinkage

These metrics show how well the SAE preserves the overall magnitude of the input:

- **l2_norm_in**: The L2 norm (Euclidean norm) of the input activations.
- **l2_norm_out**: The L2 norm of the output activations after passing through the SAE.
- **l2_ratio**: The ratio of the L2 norm of the SAE output to the L2 norm of the SAE input. A value close to 1 indicates good preservation of the input's magnitude.
- **relative_reconstruction_bias**: Measures the bias in reconstruction. Values closer to 1 indicate less bias.
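
To illustrate shrinkage, the sketch below builds a toy reconstruction that is the input scaled by 0.8 and computes both ratios. The formula for `relative_reconstruction_bias` follows one definition from the literature, the ratio of the mean squared norm of the reconstruction to the mean dot product of reconstruction and input; treating it as the library's definition is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 16))  # toy input activations
x_hat = 0.8 * x  # toy reconstruction, uniformly shrunk by 20%

l2_norm_in = np.linalg.norm(x, axis=-1)
l2_norm_out = np.linalg.norm(x_hat, axis=-1)
l2_ratio = (l2_norm_out / l2_norm_in).mean()

# Assumed definition: E[||x_hat||^2] / E[x_hat . x]
relative_reconstruction_bias = (x_hat**2).sum(axis=-1).mean() / (x_hat * x).sum(
    axis=-1
).mean()

print(l2_ratio, relative_reconstruction_bias)
```

For a pure rescaling like this one, both quantities recover the scale factor, which is why values near 1 indicate little shrinkage.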

#### Sparsity

These metrics measure how sparse the SAE's activations are:

- **l0**: The L0 "norm" of the SAE's activations, which is the number of non-zero elements. It measures how many features are active.
- **l1**: The L1 norm of the SAE's activations, which is the sum of the absolute values. It's another measure of sparsity.
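
As a toy illustration, assuming both quantities are computed per token and then averaged over tokens:

```python
import numpy as np

# Toy SAE activations: 3 tokens, 4 features.
feature_acts = np.array(
    [
        [0.0, 1.5, 0.0, 0.5],
        [2.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ]
)

# L0: number of active (non-zero) features per token, averaged over tokens.
l0 = (feature_acts > 0).sum(axis=-1).mean()

# L1: sum of absolute activations per token, averaged over tokens.
l1 = np.abs(feature_acts).sum(axis=-1).mean()

print(l0, l1)
```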

#### Token Statistics

These metrics provide context about the amount of data used in the evaluation:

- **total_tokens_eval_reconstruction**: The total number of tokens used in evaluating reconstruction metrics.
- **total_tokens_eval_sparsity_variance**: The total number of tokens used in evaluating sparsity and variance metrics.

Now let's take a look at the metrics we've collected and see how our SAEs fare.

In [None]:
import os
import json
from collections import defaultdict

# Define the metric groups
metric_groups = {
    "Reconstruction Quality": [
        "l2_ratio",
        "relative_reconstruction_bias",
        "explained_variance",
        "mse",
        "cossim",
    ],
    "Magnitude Preservation": ["l2_norm_in", "l2_norm_out"],
    "Model Behavior Preservation": [
        "kl_div_with_sae",
        "kl_div_with_ablation",
        "kl_div_score",
        "ce_loss_with_sae",
        "ce_loss_without_sae",
        "ce_loss_with_ablation",
        "ce_loss_score",
    ],
    "Sparsity": ["l0", "l1"],
    "Evaluation Scale": [
        "total_tokens_eval_reconstruction",
        "total_tokens_eval_sparsity_variance",
    ],
}

# Open all_eval_results.json
all_eval_results_file = os.path.join(output_dir, "all_eval_results.json")
with open(all_eval_results_file, "r") as f:
    all_eval_results = json.load(f)

# Sort all_eval_results by sae_set and sae_id
all_eval_results.sort(key=lambda x: (x["sae_set"], x["sae_id"]))

# Print all unique SAE sets in order
print("SAE Sets evaluated:")
unique_sae_sets = []
for result in all_eval_results:
    if result["sae_set"] not in unique_sae_sets:
        unique_sae_sets.append(result["sae_set"])
print(", ".join(unique_sae_sets))
print("\n")


# Function to print metrics for all SAEs
def print_all_sae_metrics(all_results):
    # Get unique combinations of dataset and context size
    configs = set(
        (r["eval_cfg"]["dataset"], r["eval_cfg"]["context_size"]) for r in all_results
    )

    for dataset, context_size in configs:
        print(f"Dataset: {dataset}, Context Size: {context_size}")
        print("-" * 50)

        # Filter results for this dataset and context size
        filtered_results = [
            r
            for r in all_results
            if r["eval_cfg"]["dataset"] == dataset
            and r["eval_cfg"]["context_size"] == context_size
        ]

        # Print SAE identifiers
        sae_ids = [f"{r['sae_set']} - {r['sae_id']}" for r in filtered_results]
        print("SAEs:", ", ".join(sae_ids))
        print()

        for group, metrics in metric_groups.items():
            print(f"\n{group}:")
            for metric in metrics:
                values = [r["metrics"].get(metric, "N/A") for r in filtered_results]
                formatted_values = [
                    f"{v:.4f}" if isinstance(v, float) else str(v) for v in values
                ]
                print(f"  {metric}: {formatted_values}")
        print("\n" + "=" * 50 + "\n")


# Print metrics for all SAEs
print_all_sae_metrics(all_eval_results)

We can see some interesting differences among the SAEs here. For example, the Apollo end-to-end SAE shows a much higher MSE between the SAE input and its reconstruction, and a significant change in the L2 norm of the output relative to the input. It also has anomalous explained variance and much lower input-output cosine similarity. Yet it performs best in terms of KL divergence. These differences are probably due to the very different loss function this SAE was trained with.

We can also see that the OpenAI TopK SAE has much more sparsity than the other two SAEs (which entails lower L0 and L1 values), yet still performs similarly to the JB SAEs in terms of fidelity (CE loss score) and other model behavior preservation metrics. These SAEs were trained with a TopK activation function, which enforces the desired level of sparsity.

Ideally, we would like to have perfect model behavior preservation, a high degree of sparsity, perfect reconstruction quality, and so on. Unfortunately, in reality there are often tradeoffs between these desiderata. Nevertheless, new SAE methods sometimes advance the Pareto frontier, allowing us to find SAEs that are better for our use case without giving up as much elsewhere.

In this case, since the OpenAI TopK SAE performed quite well, we might choose this one for a future project.

## Visualizing Featurewise Metrics

In our evaluation process, we also collect metrics for individual SAE features. Let's go through these category by category, starting with feature density metrics. We collect two of these:
- Feature density itself, which we generally plot as log10 feature density
- Consistent activation heuristic, which helps identify multi-token features along with dense features in general.

Let's plot these first, and we'll then explain them.

### Feature Density Metrics

In [None]:
results_file = os.path.join(
    output_dir,
    "gpt2-small-resid-mid-v5-32k-blocks.10.hook_resid_mid_128_Skylion007_openwebtext.json",
)
with open(results_file, "r") as f:
    results = json.load(f)

feature_stats = results["feature_metrics"]
feature_df = pd.DataFrame.from_dict(feature_stats)

feature_df["feature_name"] = [i for i in range(len(feature_df))]

log10_feature_density = np.log10(feature_df["feature_density"] + 1e-10)
log10_consistent_activation_heuristic = np.log10(
    feature_df["consistent_activation_heuristic"] + 1e-10
)

px.histogram(
    feature_df,
    x=log10_feature_density,
    nbins=100,
    title="Log10 Feature Density",
    labels={"x": "log10_feature_density"},
).show()

px.scatter(
    feature_df,
    x=log10_feature_density,
    y=log10_consistent_activation_heuristic,
    marginal_y="histogram",
    marginal_x="histogram",
    hover_name="feature_name",
    title="Log10 Feature Density vs Consistent Activation Heuristic",
    labels={"x": "log10_feature_density", "y": "log10_consistent_activation_heuristic"},
    height=800,
    width=1200,
    opacity=0.2,
).show()

Let's now discuss how these are calculated and how to interpret them.

#### Feature Density

According to [Anthropic](https://transformer-circuits.pub/2023/monosemantic-features#appendix-feature-density), an important proxy for autoencoder performance is *feature density*. Feature density is defined as the fraction of tokens on which the feature fires (activates). In practice, we plot a histogram of `log10(feature_density)` in order to see the distribution of feature densities.

##### Calculation

To calculate feature density, we count how many tokens a feature fires on over a large dataset and divide by the total number of tokens.

```python
feature_density = total_feature_acts / total_tokens
log10_feature_density = np.log10(feature_density + 1e-10)
```

##### Interpretation

- **dense latents**: Latents that fire on at least one in every hundred tokens can be considered dense and show up between -2 and 0 on the histogram. Dense features are less likely to be interpretable, and too many of them might make summary statistics such as MSE appear deceptively good.
- **dead latents**: Latents that fail to fire on any example in the dataset may be "dead", meaning that they may never fire and are essentially wasted SAE capacity. Because of the `1e-10` offset, they appear as a column at -10 on the histogram.
- **bimodality**: Sometimes the distribution appears bimodal (having two peaks). We're not entirely sure what causes this (if you know, let us know!).
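
The dead and dense categories can be counted directly from the density values. This is a minimal sketch on a handful of made-up densities, using the -2 cutoff mentioned above for "dense".

```python
import numpy as np

# Made-up densities for five features.
feature_density = np.array([0.0, 1e-6, 3e-4, 2e-2, 0.5])
log10_density = np.log10(feature_density + 1e-10)

# Dead: never fired on the evaluation data (lands at -10 after the offset).
is_dead = feature_density == 0

# Dense: fires on at least one in every hundred tokens.
is_dense = log10_density >= -2

print(f"dead: {is_dead.sum()}, dense: {is_dense.sum()}")
```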

#### Consistent Activation Heuristic (CAH)

The Consistent Activation Heuristic is a measure of a latent's "contextiness". High CAH scores correspond to latents that fire on almost every token of any context they fire in (e.g., tracking that a piece of text is about a particular topic). Anecdotally, these latents may be more useful for steering than others. By contrast, low-CAH latents tend to fire once per prompt (perhaps on a specific token).

##### Calculation

```python
consistent_activation_heuristic = total_feature_acts / total_feature_prompts
log10_consistent_activation_heuristic = np.log10(consistent_activation_heuristic + 1e-10)
```

Where:
- `total_feature_acts`: The total number of times each feature was activated across all tokens.
- `total_feature_prompts`: The number of prompts in which each feature was activated at least once.
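
A toy example makes the difference between the two denominators concrete. The boolean array `fired` is a hypothetical stand-in for "feature f is active on token t of prompt p"; it is not a structure produced by the eval script.

```python
import numpy as np

# fired[p, t, f] is True when feature f is active on token t of prompt p.
fired = np.zeros((4, 8, 2), dtype=bool)  # 4 prompts, 8 tokens each, 2 features
fired[0, :, 0] = True  # feature 0: every token of a single prompt
fired[:, 3, 1] = True  # feature 1: one token in each of the four prompts

total_tokens = fired.shape[0] * fired.shape[1]
total_feature_acts = fired.sum(axis=(0, 1))  # activations per feature
total_feature_prompts = fired.any(axis=1).sum(axis=0)  # prompts with >= 1 activation

feature_density = total_feature_acts / total_tokens
# Note: a feature that never fires would give 0 / 0 here and needs a guard.
consistent_activation_heuristic = total_feature_acts / total_feature_prompts

print(feature_density)
print(consistent_activation_heuristic)
```

Feature 0 is the "contexty" one: it has a CAH of 8 (every token of the prompt it appears in), while feature 1 has a CAH of 1 despite appearing in every prompt.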

#### Combining CAH and Log10 Feature Density

When visualizing these metrics together (e.g., in a scatter plot), you can identify:

1. Dense, contexty latents (high log10 density, high CAH)
2. Sparse, contexty latents (low log10 density, high CAH)
3. Dense, few token or single token latents (high log10 density, low CAH)
4. Sparse, few token or single token latents (low log10 density, low CAH)
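
These four groups can be labeled programmatically once thresholds are chosen. The thresholds and the toy values below are hypothetical and would need tuning for a real SAE; the sketch only shows the bookkeeping.

```python
import numpy as np
import pandas as pd

# Toy per-feature values; not real results.
toy = pd.DataFrame(
    {
        "feature_density": [0.2, 1e-4, 0.05, 1e-5],
        "consistent_activation_heuristic": [40.0, 30.0, 1.2, 1.0],
    }
)

DENSITY_THRESHOLD = 1e-2  # hypothetical cutoff for "dense"
CAH_THRESHOLD = 10.0  # hypothetical cutoff for "contexty"

dense = (toy["feature_density"] >= DENSITY_THRESHOLD).to_numpy()
contexty = (toy["consistent_activation_heuristic"] >= CAH_THRESHOLD).to_numpy()

toy["quadrant"] = np.select(
    [dense & contexty, ~dense & contexty, dense & ~contexty],
    ["dense, contexty", "sparse, contexty", "dense, few-token"],
    default="sparse, few-token",
)
print(toy)
```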

### Featurewise Weight Based Metrics

In [None]:
import pandas as pd
import plotly.express as px

# Create a DataFrame from the results
df = pd.DataFrame(
    {
        "feature_id": [
            f"feature_{i}"
            for i in range(
                len(results["feature_metrics"]["encoder_decoder_cosine_sim"])
            )
        ],
        "encoder_bias": results["feature_metrics"]["encoder_bias"],
        "encoder_norm": results["feature_metrics"]["encoder_norm"],
        "encoder_decoder_cosine_sim": results["feature_metrics"][
            "encoder_decoder_cosine_sim"
        ],
        "feature_density": results["feature_metrics"]["feature_density"],
        "consistent_activation_heuristic": results["feature_metrics"][
            "consistent_activation_heuristic"
        ],
    }
)

df["log10_feature_density"] = np.log10(df["feature_density"] + 1e-10)
df["log10_consistent_activation_heuristic"] = np.log10(
    df["consistent_activation_heuristic"] + 1e-10
)

# Plot from the columns computed above so this cell does not depend on
# variables defined in an earlier cell.
px.scatter(
    df,
    x="log10_feature_density",
    y="log10_consistent_activation_heuristic",
    marginal_y="histogram",
    marginal_x="histogram",
    hover_name="feature_id",
    title="Log10 Feature Density vs Consistent Activation Heuristic",
    height=800,
    width=1200,
    opacity=0.2,
).show()


def create_cosine_sim_plot(feature_df, title):
    fig = px.histogram(feature_df, x="encoder_decoder_cosine_sim", title=title)
    fig.update_layout(height=400, width=600)
    return fig


# Create and display the encoder-decoder cosine similarity histogram
cosine_sim_hist = create_cosine_sim_plot(df, "Encoder-Decoder Cosine Similarity")
cosine_sim_hist.show()

#### Weight Based Metrics

Further insights into SAE quality can be derived from weight based analysis of the SAE. Specifically:
- `Encoder Bias`: The encoder bias term tends to roughly correspond to feature density.
- `Encoder Norm`: A larger encoder norm tends to be present in features that activate on a smaller number of vocabulary items ([according to the mechanistic interpretability team at OpenAI](https://arxiv.org/pdf/2406.04093)).
- `Decoder - Encoder Cosine Similarity`: Early SAEs were trained with tied encoder / decoder weights, but performance improvements were found by untying these weights. Checking the similarity between encoder and decoder weights may be informative (for example, differences between encoder and decoder weights may be related to [feature absorption](https://arxiv.org/abs/2409.14507)).

##### Interpretation

- `Encoder Bias` and `Encoder Norm` likely track something important but it's not yet clear from the literature whether there's a standard way to use these to evaluate SAEs.
- On the other hand, it's likely better if `Decoder - Encoder Cosine Similarity` is generally higher.
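
For reference, here is a minimal sketch of how the per-feature encoder norm and encoder-decoder cosine similarity could be computed from weight matrices. The shapes (`W_enc` as `[d_in, d_sae]`, `W_dec` as `[d_sae, d_in]`) and the random toy weights are assumptions for illustration, not values read from a real SAE.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_sae = 8, 32

# Toy weights: a decoder that is nearly, but not exactly, the transposed encoder.
W_enc = rng.normal(size=(d_in, d_sae))  # assumed shape [d_in, d_sae]
W_dec = W_enc.T + 0.1 * rng.normal(size=(d_sae, d_in))  # assumed shape [d_sae, d_in]

# One norm per feature: columns of W_enc, rows of W_dec.
encoder_norm = np.linalg.norm(W_enc, axis=0)
decoder_norm = np.linalg.norm(W_dec, axis=1)

# Cosine similarity between each feature's encoder and decoder directions.
encoder_decoder_cosine_sim = (W_enc.T * W_dec).sum(axis=1) / (
    encoder_norm * decoder_norm
)

print(encoder_decoder_cosine_sim.min(), encoder_decoder_cosine_sim.max())
```

With perfectly tied weights every value would be exactly 1; the further the decoder drifts from the encoder, the lower the similarities fall.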