# Evaluating SAEs with SAE Lens Evals

## Overview

In this tutorial we will cover the use of our evaluation tools on SAEs. As there is no single metric for SAE quality, we include a number of metrics from multiple categories to help you assess whether an SAE meets your usage requirements. In many cases, you will want to compare multiple SAEs for a given model layer in order to determine which performs best on specific metrics. Below, we explain how to run the evaluations, how each metric is calculated, and how to interpret the values shown.

## Imports & Installs

In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens sae-dashboard
except ImportError:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import json
import numpy as np 

torch.set_grad_enabled(False);

## Using the Command Line Utility



To see which SAEs are available for a specified model, you can simply run the following code snippet. For running evals, the primary things to pay attention to are the release name and the format used in `saes_map` to specify the target layer. These will be used as input for the evals runner.

In [77]:
import pandas as pd
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
df = pd.DataFrame.from_records({k:v.__dict__ for k,v in get_pretrained_saes_directory().items()}).T
df.drop(columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"], inplace=True)
df[df.release.str.contains("gemma-scope")] # Each row is a "release" which has multiple SAEs which may have different configs / match different hook points in a model. 

| release | repo_id | model | saes_map | neuronpedia_id |
|---|---|---|---|---|
| gemma-scope-27b-pt-res | google/gemma-scope-27b-pt-res | gemma-2-27b | `{'layer_10/width_131k/average_l0_106': 'layer_...` | `{'layer_10/width_131k/average_l0_106': None, '...` |
| gemma-scope-27b-pt-res-canonical | google/gemma-scope-27b-pt-res | gemma-2-27b | `{'layer_10/width_131k/canonical': 'layer_10/wi...` | `{'layer_10/width_131k/canonical': 'gemma-2-27b...` |
| gemma-scope-2b-pt-att | google/gemma-scope-2b-pt-att | gemma-2-2b | `{'layer_0/width_16k/average_l0_104': 'layer_0/...` | `{'layer_0/width_16k/average_l0_104': None, 'la...` |
| gemma-scope-2b-pt-att-canonical | google/gemma-scope-2b-pt-att | gemma-2-2b | `{'layer_0/width_16k/canonical': 'layer_0/width...` | `{'layer_0/width_16k/canonical': 'gemma-2-2b/0-...` |
| gemma-scope-2b-pt-mlp | google/gemma-scope-2b-pt-mlp | gemma-2-2b | `{'layer_0/width_16k/average_l0_119': 'layer_0/...` | `{'layer_0/width_16k/average_l0_119': None, 'la...` |
| gemma-scope-2b-pt-mlp-canonical | google/gemma-scope-2b-pt-mlp | gemma-2-2b | `{'layer_0/width_16k/canonical': 'layer_0/width...` | `{'layer_0/width_16k/canonical': 'gemma-2-2b/0-...` |
| gemma-scope-2b-pt-res | google/gemma-scope-2b-pt-res | gemma-2-2b | `{'embedding/width_4k/average_l0_6': 'embedding...` | `{'embedding/width_4k/average_l0_6': None, 'emb...` |
| gemma-scope-2b-pt-res-canonical | google/gemma-scope-2b-pt-res | gemma-2-2b | `{'layer_0/width_16k/canonical': 'layer_0/width...` | `{'layer_0/width_16k/canonical': 'gemma-2-2b/0-...` |
| gemma-scope-9b-it-res | google/gemma-scope-9b-it-res | gemma-2-9b | `{'layer_20/width_131k/average_l0_13': 'layer_2...` | `{'layer_20/width_131k/average_l0_13': None, 'l...` |
| gemma-scope-9b-it-res-canonical | google/gemma-scope-9b-it-res | gemma-2-9b-it | `{'layer_9/width_16k/canonical': 'layer_9/width...` | `{'layer_9/width_16k/canonical': 'gemma-2-9b-it...` |


Let's run evals for 3 SAEs:
- Apollo's End-to-End SAEs with downstream reconstruction loss (`gpt2-small-res_scefr-ajt`)
- OpenAI's TopK SAEs (`gpt2-small-resid-mid-v5-32k`)
- Joseph Bloom's Residual Stream SAEs (`gpt2-small-res-jb`)

These SAEs are not perfectly comparable because they are not all the same width (32K for OpenAI vs. 24K for the others), and they are not all trained at the same part of the layer (resid-mid for OpenAI TopK, resid-pre for Joseph Bloom's and Apollo's), but they will suffice for our tutorial.

Currently, the evals are designed to be run through the CLI. Simply copy and paste the command below. (Note: Omit the `poetry run` portion if you are not using a poetry environment.)

```bash
poetry run python sae_lens/evals.py "gpt2-small-res_scefr-ajt|gpt2-small-resid-mid-v5-32k|gpt2-small-res-jb" "blocks.10.*" \
    --batch_size_prompts 16 \
    --n_eval_sparsity_variance_batches 200 \
    --n_eval_reconstruction_batches 20 \
    --output_dir "demo_eval_results"
```

`evals.py` accepts a wide range of arguments, most of which are used to control whether to evaluate a particular metric. You can see the full list in the source code, but the most important arguments are the ones we've used here:
- `"gpt2-small-res_scefr-ajt|gpt2-small-resid-mid-v5-32k|gpt2-small-res-jb"` This is a regex pattern used to match SAE release names. It can match a single release or, as here, several.
- `"blocks.10.*"` This is another regex pattern, matched against the SAE IDs within each release. This particular pattern applies to each of the releases we're targeting, so layer 10 (resid-pre or resid-mid) will be evaluated for each of them.
- `--batch_size_prompts 16` Here, we're setting the batch size for the dataset of evaluation prompts to 16.
- `--n_eval_sparsity_variance_batches 200` The total number of evaluation prompt batches to use when evaluating sparsity and variance metrics.
- `--n_eval_reconstruction_batches 20` The number of prompt batches to use for reconstruction evaluation.
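
To make the two patterns concrete, here is a minimal sketch of how a release pattern and an SAE ID pattern might jointly select SAEs. It assumes both patterns must match the whole name (`re.fullmatch`); the actual matching logic in `evals.py` may differ. The `catalogue` list and the `select_saes` helper are illustrative and not part of SAE Lens.

```python
import re

# Illustrative (release, sae_id) pairs; not the real SAE Lens directory.
catalogue = [
    ("gpt2-small-res-jb", "blocks.9.hook_resid_pre"),
    ("gpt2-small-res-jb", "blocks.10.hook_resid_pre"),
    ("gpt2-small-res_scefr-ajt", "blocks.10.hook_resid_pre"),
    ("gpt2-small-resid-mid-v5-32k", "blocks.10.hook_resid_mid"),
    ("gemma-scope-2b-pt-res-canonical", "layer_0/width_16k/canonical"),
]


def select_saes(release_pattern, sae_id_pattern, pairs):
    """Keep the pairs whose release and SAE ID both fully match their pattern."""
    release_re = re.compile(release_pattern)
    sae_id_re = re.compile(sae_id_pattern)
    return [
        (release, sae_id)
        for release, sae_id in pairs
        if release_re.fullmatch(release) and sae_id_re.fullmatch(sae_id)
    ]


selected = select_saes(
    "gpt2-small-res_scefr-ajt|gpt2-small-resid-mid-v5-32k|gpt2-small-res-jb",
    "blocks.10.*",
    catalogue,
)
for release, sae_id in selected:
    print(release, sae_id)
```

Keep in mind that `.` in a regex matches any character, so escape it (for example `blocks\.10\..*`) if you need a stricter match.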

Once the evals are run, we can open them:

In [43]:
output_dir = "../demo_eval_results"

# list files in the output directory
print(os.listdir(output_dir))

['gpt2-small-res-jb-blocks.10.hook_resid_pre_128_Skylion007_openwebtext.json', 'all_eval_results.json', 'gpt2-small-res_scefr-ajt-blocks.10.hook_resid_pre_128_Skylion007_openwebtext.json', 'gpt2-small-resid-mid-v5-32k-blocks.10.hook_resid_mid_128_Skylion007_openwebtext.json', 'all_eval_results.csv']


In [26]:
type(results)

dict

An example of the results is printed below. The evaluation script produces a list of nested dictionaries, with one item per unique SAE (release and SAE ID combination). The dictionary for each individual SAE is saved as its own file, and a file with all results combined is saved as well. Each nested dictionary stores metadata about the SAE, the `eval_cfg` that was used to generate the results, a dictionary containing SAE-level metrics (`metrics`), another dictionary containing feature-level metrics (`feature_metrics`), and finally the SAE config, which contains information about the SAE itself, its training, the dataset used for evaluation, and various other details.
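
As a rough illustration of this structure, the sketch below flattens one such nested dictionary into a single flat row, which is convenient when building a comparison table. The record is abbreviated to a handful of keys using values shown in this tutorial, and `flatten_result` is a hypothetical helper, not part of SAE Lens.

```python
# Abbreviated record using values shown in this tutorial; real files contain many more keys.
example_result = {
    "sae_set": "gpt2-small-res-jb",
    "sae_id": "blocks.10.hook_resid_pre",
    "eval_cfg": {"context_size": 128, "dataset": "Skylion007/openwebtext"},
    "metrics": {"kl_div_with_sae": 0.2152, "ce_loss_with_sae": 3.7403},
}


def flatten_result(result):
    """Return a flat dict with 'eval_cfg.' and 'metrics.' prefixed keys."""
    row = {"sae_set": result["sae_set"], "sae_id": result["sae_id"]}
    for section in ("eval_cfg", "metrics"):
        for key, value in result[section].items():
            row[f"{section}.{key}"] = value
    return row


row = flatten_result(example_result)
for key, value in row.items():
    print(f"{key}: {value}")
```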

In [36]:

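# get the first per-SAE results file (os.listdir order is arbitrary, so sort and skip the combined files)
per_sae_files = sorted(
    f for f in os.listdir(output_dir)
    if f.endswith(".json") and not f.startswith("all_eval_results")
)
first_file = os.path.join(output_dir, per_sae_files[0])
print(first_file)

# load the results JSON and pretty-print it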
with open(first_file, "r") as f:
    results = json.load(f)

print(json.dumps(results['sae_set'], indent=4))
print(json.dumps(results['eval_cfg'], indent=4))
print(json.dumps(results['metrics'], indent=4))
#print(json.dumps(results['feature_metrics'], indent=4)) # uncomment to see feature metrics

../demo_eval_results/gpt2-small-res-jb-blocks.10.hook_resid_pre_128_Skylion007_openwebtext.json
"gpt2-small-res-jb"
{
    "context_size": 128,
    "dataset": "Skylion007/openwebtext",
    "library_version": "3.17.1",
    "git_hash": "81b6e04",
    "batch_size_prompts": 16,
    "n_eval_reconstruction_batches": 20,
    "compute_kl": true,
    "compute_ce_loss": true,
    "n_eval_sparsity_variance_batches": 200,
    "compute_l2_norms": true,
    "compute_sparsity_metrics": true,
    "compute_variance_metrics": true,
    "compute_featurewise_density_statistics": true,
    "compute_featurewise_weight_based_metrics": true
}
{
    "kl_div_with_sae": 0.21521934866905212,
    "kl_div_with_ablation": 8.432144165039062,
    "ce_loss_with_sae": 3.740305185317993,
    "ce_loss_without_sae": 3.550839900970459,
    "ce_loss_with_ablation": 11.913066864013672,
    "kl_div_score": 0.9744763200845896,
    "ce_loss_score": 0.9773427239914828,
    "l2_norm_in": 153.0411834716797,
    "l2_norm_out": 139.36

In [48]:
print(results['metrics'].keys())

dict_keys(['kl_div_with_sae', 'kl_div_with_ablation', 'ce_loss_with_sae', 'ce_loss_without_sae', 'ce_loss_with_ablation', 'kl_div_score', 'ce_loss_score', 'l2_norm_in', 'l2_norm_out', 'l2_ratio', 'relative_reconstruction_bias', 'l0', 'l1', 'explained_variance', 'mse', 'cossim', 'total_tokens_eval_reconstruction', 'total_tokens_eval_sparsity_variance'])


## Making Sense of Results

### Understanding SAE Metrics

The SAE metrics we compute can be grouped into five distinct categories. Let's explore each category and its associated metrics in detail. We'll then compare our three SAEs on each metric.

#### Reconstruction Quality

These metrics assess how well the SAE can reconstruct the original input:

- **l2_ratio**: The ratio of the L2 norm of the SAE output to the L2 norm of the SAE input. A value close to 1 indicates good preservation of the input's magnitude.
- **relative_reconstruction_bias**: Measures whether the reconstruction is systematically mis-scaled relative to the input (for example, shrunk). Values closer to 1 indicate less bias.
- **explained_variance**: The proportion of variance in the input that is explained by the SAE's reconstruction. Higher values indicate better reconstruction quality.
- **mse**: Mean Squared Error between the input and the reconstruction. Lower values indicate better reconstruction accuracy.
- **cossim**: Cosine similarity between the input and the reconstruction. Values closer to 1 indicate better preservation of the input's direction.
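
The sketch below shows one plausible way to compute several of these quantities on a toy example, using plain Python lists so it runs without extra dependencies. The exact definitions and reductions used by the evaluation code are assumptions here (in particular, averaging per-token ratios and computing MSE per element), and `reconstruction_metrics` is a hypothetical helper.

```python
import math


def reconstruction_metrics(x, x_hat):
    """x and x_hat are lists of equal-length activation vectors (one per token)."""
    n_tokens, d = len(x), len(x[0])

    def norm(v):
        return math.sqrt(sum(a * a for a in v))

    # Per-token quantities, averaged over tokens.
    l2_ratio = sum(norm(h) / norm(v) for v, h in zip(x, x_hat)) / n_tokens
    cossim = sum(
        sum(a * b for a, b in zip(v, h)) / (norm(v) * norm(h))
        for v, h in zip(x, x_hat)
    ) / n_tokens

    # Squared error summed over all elements.
    sq_err = sum((a - b) ** 2 for v, h in zip(x, x_hat) for a, b in zip(v, h))
    mse = sq_err / (n_tokens * d)

    # Total variance of the input around its per-dimension mean.
    means = [sum(v[j] for v in x) / n_tokens for j in range(d)]
    total_var = sum((v[j] - means[j]) ** 2 for v in x for j in range(d))
    explained_variance = 1.0 - sq_err / total_var

    return {"l2_ratio": l2_ratio, "cossim": cossim, "mse": mse,
            "explained_variance": explained_variance}


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
x_half = [[0.5 * a for a in v] for v in x]  # same direction, half the magnitude

print(reconstruction_metrics(x, x))       # perfect reconstruction
print(reconstruction_metrics(x, x_half))  # shrunken reconstruction
```

In this toy case the shrunken reconstruction keeps a cosine similarity of 1 while its explained variance goes negative, which is why direction and magnitude metrics are worth reading together.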

#### Magnitude Preservation

These metrics show how well the SAE preserves the overall magnitude of the input:

- **l2_norm_in**: The L2 norm (Euclidean norm) of the input activations.
- **l2_norm_out**: The L2 norm of the output activations after passing through the SAE.

#### Model Behavior Preservation

These metrics indicate how much the SAE affects the underlying model's performance and output distributions:

- **kl_div_with_sae**: Kullback-Leibler divergence between the original model's output distribution and the distribution after applying the SAE. Lower values indicate better preservation of the model's behavior.
- **kl_div_with_ablation**: KL divergence between the original model's output and the output when the relevant activations are set to zero. This serves as a baseline for comparison.
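- **kl_div_score**: A derived metric comparing the KL divergence with the SAE to the KL divergence with ablation. Values closer to 1 indicate that splicing in the SAE changes the output distribution very little relative to the ablation baseline.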
- **ce_loss_with_sae**: Cross-entropy loss of the model's output after applying the SAE.
- **ce_loss_without_sae**: Baseline cross-entropy loss of the original model without the SAE.
- **ce_loss_with_ablation**: Cross-entropy loss when the relevant activations are set to zero.
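- **ce_loss_score**: A derived metric comparing the cross-entropy loss with the SAE against both the original loss and the ablation loss. Values closer to 1 indicate that the loss with the SAE is close to the original model's loss.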
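
As a worked example, the two derived scores can be recomputed from the raw values reported earlier for `gpt2-small-res-jb`. The formulas below are an assumption about how the scores are defined, but they reproduce the reported `kl_div_score` and `ce_loss_score` to several decimal places.

```python
# Values reported above for gpt2-small-res-jb at blocks.10.hook_resid_pre.
kl_div_with_sae = 0.21521934866905212
kl_div_with_ablation = 8.432144165039062
ce_loss_with_sae = 3.740305185317993
ce_loss_without_sae = 3.550839900970459
ce_loss_with_ablation = 11.913066864013672

# Fraction of the ablation-induced KL divergence that the SAE avoids.
kl_div_score = (kl_div_with_ablation - kl_div_with_sae) / kl_div_with_ablation

# Fraction of the ablation-induced CE loss increase that the SAE recovers.
ce_loss_score = (ce_loss_with_ablation - ce_loss_with_sae) / (
    ce_loss_with_ablation - ce_loss_without_sae
)

print(f"kl_div_score:  {kl_div_score:.4f}")
print(f"ce_loss_score: {ce_loss_score:.4f}")
```

Under these definitions, a score of 1 means the SAE leaves the model's behavior unchanged, and a score of 0 means it does no better than zero ablation.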

#### Sparsity

These metrics measure how sparse the SAE's activations are:

- **l0**: The L0 "norm" of the SAE's activations, which is the number of non-zero elements. It measures how many features are active.
- **l1**: The L1 norm of the SAE's activations, which is the sum of the absolute values. It's another measure of sparsity.
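
A small sketch of both quantities on toy data follows. It assumes the reported values are per-token norms averaged over tokens, which may not match the evaluation code exactly.

```python
# Toy feature activations: one row per token, one column per SAE feature.
feature_acts = [
    [0.0, 1.5, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [2.0, 0.5, 0.0, 1.0],
]

# L0 per token: number of non-zero features; L1 per token: sum of absolute values.
l0_per_token = [sum(1 for a in row if a != 0) for row in feature_acts]
l1_per_token = [sum(abs(a) for a in row) for row in feature_acts]

# Average over tokens to get a single number per SAE.
l0 = sum(l0_per_token) / len(feature_acts)
l1 = sum(l1_per_token) / len(feature_acts)

print(l0_per_token, l0)
print(l1_per_token, l1)
```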

#### Evaluation Scale

These metrics provide context about the amount of data used in the evaluation:

- **total_tokens_eval_reconstruction**: The total number of tokens used in evaluating reconstruction metrics.
- **total_tokens_eval_sparsity_variance**: The total number of tokens used in evaluating sparsity and variance metrics.
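
For a rough sense of scale, the arithmetic below estimates how many tokens the settings used in this tutorial could cover. It assumes every position of every prompt is counted, so treat the results as upper bounds; the reported totals may be lower if some tokens are excluded.

```python
# Settings from the CLI command and eval_cfg shown earlier.
batch_size_prompts = 16
context_size = 128
n_eval_reconstruction_batches = 20
n_eval_sparsity_variance_batches = 200

tokens_per_batch = batch_size_prompts * context_size

# Upper bounds, assuming every position of every prompt is counted.
max_tokens_reconstruction = n_eval_reconstruction_batches * tokens_per_batch
max_tokens_sparsity_variance = n_eval_sparsity_variance_batches * tokens_per_batch

print(max_tokens_reconstruction)
print(max_tokens_sparsity_variance)
```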

Now let's take a look at the metrics we've collected and see how our SAEs fare.

In [70]:
import os
import json
from collections import defaultdict

# Define the metric groups
metric_groups = {
    "Reconstruction Quality": ["l2_ratio", "relative_reconstruction_bias", "explained_variance", "mse", "cossim"],
    "Magnitude Preservation": ["l2_norm_in", "l2_norm_out"],
    "Model Behavior Preservation": [
        "kl_div_with_sae", 
        "kl_div_with_ablation", 
        "kl_div_score",
        "ce_loss_with_sae", 
        "ce_loss_without_sae", 
        "ce_loss_with_ablation",
        "ce_loss_score"
    ],
    "Sparsity": ["l0", "l1"],
    "Evaluation Scale": ["total_tokens_eval_reconstruction", "total_tokens_eval_sparsity_variance"]
}

# Open all_eval_results.json
all_eval_results_file = os.path.join(output_dir, "all_eval_results.json")
with open(all_eval_results_file, "r") as f:
    all_eval_results = json.load(f)

# Sort all_eval_results by sae_set and sae_id
all_eval_results.sort(key=lambda x: (x['sae_set'], x['sae_id']))

# Print all unique SAE sets in order
print("SAE Sets evaluated:")
unique_sae_sets = []
for result in all_eval_results:
    if result['sae_set'] not in unique_sae_sets:
        unique_sae_sets.append(result['sae_set'])
print(", ".join(unique_sae_sets))
print("\n")

# Function to print metrics for all SAEs
def print_all_sae_metrics(all_results):
    # Get unique combinations of dataset and context size
    configs = set((r['eval_cfg']['dataset'], r['eval_cfg']['context_size']) for r in all_results)
    
    for dataset, context_size in configs:
        print(f"Dataset: {dataset}, Context Size: {context_size}")
        print("-" * 50)
        
        # Filter results for this dataset and context size
        filtered_results = [r for r in all_results if r['eval_cfg']['dataset'] == dataset and r['eval_cfg']['context_size'] == context_size]
        
        # Print SAE identifiers
        sae_ids = [f"{r['sae_set']} - {r['sae_id']}" for r in filtered_results]
        print("SAEs:", ", ".join(sae_ids))
        print()
        
        for group, metrics in metric_groups.items():
            print(f"\n{group}:")
            for metric in metrics:
                values = [r['metrics'].get(metric, 'N/A') for r in filtered_results]
                formatted_values = [f"{v:.4f}" if isinstance(v, float) else str(v) for v in values]
                print(f"  {metric}: {formatted_values}")
        print("\n" + "=" * 50 + "\n")

# Print metrics for all SAEs
print_all_sae_metrics(all_eval_results)

SAE Sets evaluated:
gpt2-small-res-jb, gpt2-small-res_scefr-ajt, gpt2-small-resid-mid-v5-32k


Dataset: Skylion007/openwebtext, Context Size: 128
--------------------------------------------------
SAEs: gpt2-small-res-jb - blocks.10.hook_resid_pre, gpt2-small-res_scefr-ajt - blocks.10.hook_resid_pre, gpt2-small-resid-mid-v5-32k - blocks.10.hook_resid_mid


Reconstruction Quality:
  l2_ratio: ['0.9101', '0.3319', '0.9604']
  relative_reconstruction_bias: ['0.9588', '1.2878', '1.0006']
  explained_variance: ['0.7925', '-0.4288', '0.8160']
  mse: ['1.1256', '11.7036', '1.2315']
  cossim: ['0.9496', '0.6807', '0.9582']

Magnitude Preservation:
  l2_norm_in: ['153.0412', '176.5965', '200.1225']
  l2_norm_out: ['139.3622', '81.3349', '193.1755']

Model Behavior Preservation:
  kl_div_with_sae: ['0.2152', '0.1594', '0.1821']
  kl_div_with_ablation: ['8.4321', '8.4164', '6.4943']
  kl_div_score: ['0.9745', '0.9811', '0.9720']
  ce_loss_with_sae: ['3.7403', '3.6995', '3.7435']
  ce_loss_without

We can see some interesting differences among the SAEs. For example, the Apollo end-to-end SAE has a much higher MSE between the SAE input and its reconstruction, and its output L2 norm is far smaller than its input L2 norm. It also has a negative explained variance and a much lower input-output cosine similarity. Yet it performs best in terms of KL divergence. These differences are probably due to the very different loss function this SAE was trained with: as noted earlier, it is trained end-to-end with a downstream reconstruction loss.

We can also see that the OpenAI TopK SAE is much sparser than the other two SAEs (that is, it has lower L0 and L1 values), yet it still performs similarly to the JB SAE in terms of fidelity (CE loss score) and the other model behavior preservation metrics. It was trained with a TopK activation function, which enforces the desired level of sparsity directly.

Ideally, we would like to have perfect model behavior preservation, a high degree of sparsity, perfect reconstruction quality, and so on. Unfortunately, in reality there are often tradeoffs between these desiderata. Nevertheless, new SAE methods sometimes advance the Pareto frontier, allowing us to find SAEs that are better on several metrics at once for our use case.

In this case, since the OpenAI TopK SAE performed quite well, we might choose this one for a future project.
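
To illustrate the Pareto-frontier idea mentioned above, the sketch below checks which SAEs are non-dominated on two of the reported metrics, `mse` and `kl_div_with_sae` (lower is better for both). The choice of these two metrics is arbitrary, the last entry is a made-up point added purely for illustration, and `dominates` is a hypothetical helper.

```python
# (mse, kl_div_with_sae) from the output above; lower is better for both.
candidates = {
    "gpt2-small-res-jb": (1.1256, 0.2152),
    "gpt2-small-res_scefr-ajt": (11.7036, 0.1594),
    "gpt2-small-resid-mid-v5-32k": (1.2315, 0.1821),
    "hypothetical-dominated-sae": (12.0, 0.25),  # made-up point for illustration
}


def dominates(a, b):
    """True if a is no worse than b on every metric and strictly better on one."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))


pareto_front = [
    name
    for name, scores in candidates.items()
    if not any(dominates(other, scores) for other in candidates.values())
]
print(pareto_front)
```

On these two metrics none of the three real SAEs dominates another, which is one way to see the tradeoff described above; only the made-up entry is excluded.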

## Visualizing Featurewise Metrics

In our evaluation process, we also collect metrics for individual SAE features. Let's go through these category by category, starting with feature density metrics. We collect two of these:
- Feature density itself, which we generally plot as log10 feature density
- Consistent activation heuristic, which helps identify multi-token features along with dense features in general.

Let's plot these first, and we'll then explain them.

### Feature Density Metrics

In [88]:
results_file = os.path.join(output_dir, "gpt2-small-res-jb-blocks.10.hook_resid_pre_128_Skylion007_openwebtext.json")
with open(results_file, "r") as f:
    results = json.load(f)

feature_stats = results["feature_metrics"]
feature_df = pd.DataFrame.from_dict(feature_stats)

feature_df["feature_name"] = [i for i in range(len(feature_df))]

log10_feature_density = np.log10(feature_df["feature_density"] + 1e-10)

px.histogram(
    feature_df, 
    x=log10_feature_density, 
    nbins=100, 
    title="Log10 Feature Density",
    labels={"x": "log10_feature_density"}
).show()

px.scatter(feature_df, 
           x=log10_feature_density,
           y = "consistent_activation_heuristic",
           marginal_y="histogram",
           marginal_x="histogram",
           hover_name="feature_name",
           title="Log10 Feature Density vs Consistent Activation Heuristic",
           labels={"x": "log10_feature_density", "consistent_activation_heuristic": "Consistent Activation Heuristic"}
           ).show() 

Let's now discuss how these are calculated and how to interpret them.

#### Log10 Feature Density

Feature density is initially calculated as:

```python
feature_density = total_feature_acts / total_tokens
```

The log10 transformation is then applied:

```python
log10_feature_density = np.log10(feature_density + 1e-10)
```

##### Significance

- **Compression of Range**: The log transformation compresses the range of values, making it easier to visualize and compare features with very different activation rates.
- **Pattern Identification**: It helps in identifying patterns across orders of magnitude, which is useful when dealing with sparse activations.
- **Error Prevention**: The small constant (1e-10) is added to avoid log(0) errors for features that never activate.

##### Interpretation

- Values close to 0: Features that activate for nearly every token. These are highly dense features and can be difficult to interpret (but this is not necessarily the case). Context features tend to be more dense than other features.
- More negative values: Increasingly sparse features. Dead features show up as a single bin at the far left of the histogram (at -10, because of the added 1e-10 constant). However, keep in mind that these are merely features that did not activate over the course of our evaluation; if too few prompts were included, more features will appear artificially "dead".
- The distribution of log10 feature density reveals the overall sparsity pattern of your SAE.
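
The interpretation above can be sketched on toy counts as follows. The `DENSE_THRESHOLD` cutoff is a hypothetical value chosen for illustration, not a recommendation from the library.

```python
import math

total_tokens = 1000
# Toy counts of how many tokens each feature fired on.
total_feature_acts = [0, 1, 10, 500, 1000]

feature_density = [acts / total_tokens for acts in total_feature_acts]
log10_feature_density = [math.log10(d + 1e-10) for d in feature_density]

# Hypothetical cutoff for "dense"; choose one that suits your SAE.
DENSE_THRESHOLD = -1.0

labels = []
for i, (d, log_d) in enumerate(zip(feature_density, log10_feature_density)):
    if d == 0:
        label = "dead (in this sample)"
    elif log_d > DENSE_THRESHOLD:
        label = "dense"
    else:
        label = "sparse"
    labels.append(label)
    print(f"feature {i}: log10 density = {log_d:.2f} -> {label}")

n_dead = sum(1 for d in feature_density if d == 0)
print("dead features:", n_dead)
```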

#### Consistent Activation Heuristic (CAH)

The Consistent Activation Heuristic is calculated as:

```python
consistent_activation_heuristic = total_feature_acts / total_feature_prompts
```

Where:
- `total_feature_acts`: The total number of times each feature was activated across all tokens.
- `total_feature_prompts`: The number of prompts in which each feature was activated at least once.
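
Here is a minimal sketch of the calculation for a single feature on toy data. The zero-division guard for features that never fire is an assumption; the evaluation code may handle that case differently.

```python
# Toy activations of ONE feature: one inner list per prompt, one value per token.
prompts = [
    [0.0, 0.0, 0.0, 0.0],  # never fires
    [0.0, 1.2, 0.0, 0.0],  # fires on one token
    [0.7, 0.9, 1.1, 0.0],  # fires on three tokens
]

# Number of (prompt, token) positions where the feature is active.
total_feature_acts = sum(1 for prompt in prompts for act in prompt if act > 0)
# Number of prompts where the feature is active at least once.
total_feature_prompts = sum(1 for prompt in prompts if any(act > 0 for act in prompt))

# Guard against division by zero for features that never fire (an assumption here).
if total_feature_prompts > 0:
    consistent_activation_heuristic = total_feature_acts / total_feature_prompts
else:
    consistent_activation_heuristic = 0.0

print(consistent_activation_heuristic)
```

Read this way, CAH is the average number of tokens on which a feature fires within the prompts where it fires at all, so a value near 1 means roughly one activating token per such prompt.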

##### Significance

- **Measure of Consistency**: CAH indicates how consistently a feature activates when it's present in a prompt.
- **Feature Importance**: Higher values suggest features that are consistently important when relevant. For example, multi-token features and context features (e.g., the "French" feature) will have higher CAH scores.
- **Pattern Identification**: It can help identify features that capture specific, consistent patterns or concepts in the data.
- **Feature Categorization**: Helps distinguish between "specialist" features (high CAH, activated consistently when relevant) and more general or noise-like features (low CAH, activated inconsistently).

#### Combining CAH and Log10 Feature Density

When visualizing these metrics together (e.g., in a scatter plot), you can identify:

1. Dense, consistent features (high log10 density, high CAH)
2. Sparse, consistent features (low log10 density, high CAH)
3. Dense, inconsistent features (high log10 density, low CAH)
4. Sparse, inconsistent features (low log10 density, low CAH)
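
The four groups above can be sketched as a simple thresholding rule. Both cutoffs below are hypothetical and only serve to illustrate the idea; sensible values depend on your SAE and evaluation data.

```python
# Hypothetical cutoffs; pick values that suit your own SAE and dataset.
DENSITY_CUTOFF = -2.0  # on log10 feature density
CAH_CUTOFF = 2.0       # on the consistent activation heuristic


def categorize(log10_density, cah):
    """Assign a feature to one of the four density/consistency groups."""
    density_label = "dense" if log10_density >= DENSITY_CUTOFF else "sparse"
    cah_label = "consistent" if cah >= CAH_CUTOFF else "inconsistent"
    return f"{density_label}, {cah_label}"


# Toy (log10 density, CAH) pairs, one per feature.
toy_features = [(-0.5, 40.0), (-4.0, 6.0), (-1.0, 1.1), (-5.0, 1.0)]
labels = [categorize(d, c) for d, c in toy_features]
for (d, c), label in zip(toy_features, labels):
    print(f"log10 density {d:5.1f}, CAH {c:5.1f} -> {label}")
```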

In [89]:
import pandas as pd
import plotly.express as px


# Create a DataFrame from the results
df = pd.DataFrame({
    "feature_id": [f"feature_{i}" for i in range(len(results["feature_metrics"]["encoder_decoder_cosine_sim"]))],
    "encoder_bias": results["feature_metrics"]["encoder_bias"],
    "encoder_norm": results["feature_metrics"]["encoder_norm"],
    "encoder_decoder_cosine_sim": results["feature_metrics"]["encoder_decoder_cosine_sim"]
})

# Function to create scatter matrix plot
def create_scatter_matrix(feature_df, title):
    fig = px.scatter_matrix(
        feature_df,
        dimensions=["encoder_bias", "encoder_norm", "encoder_decoder_cosine_sim"],
        hover_name="feature_id",
        title=title,
    )
    fig.update_layout(height=600, width=800)
    return fig

# Function to create encoder-decoder cosine similarity histogram
def create_cosine_sim_plot(feature_df, title):
    fig = px.histogram(feature_df, x="encoder_decoder_cosine_sim", title=title)
    fig.update_layout(height=400, width=600)
    return fig

# Create and display the scatter matrix plot
scatter_matrix = create_scatter_matrix(df, "Feature Scatter Matrix")
scatter_matrix.show()

# Create and display the encoder-decoder cosine similarity histogram
cosine_sim_hist = create_cosine_sim_plot(df, "Encoder-Decoder Cosine Similarity")
cosine_sim_hist.show()