# Evaluating SAEs with SAE Lens Evals

## Overview

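1. Browse the pretrained SAE releases available in SAE Lens.
2. Run evals on selected SAEs with the command line utility.
3. Load and inspect the saved eval results.
4. Visualize the featurewise metrics.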

## Imports & Installs

In [28]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens sae-dashboard
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import json
import numpy as np 

# Turn off autograd globally for the rest of the notebook; the trailing ";"
# keeps the returned object from being echoed as cell output.
torch.set_grad_enabled(False);

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Using the Command Line Utility



In [15]:
import pandas as pd
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
# Attribute dict of every directory entry, keyed by release name.
release_records = {
    release_name: release_info.__dict__
    for release_name, release_info in get_pretrained_saes_directory().items()
}

# Columns removed from the overview table.
dropped_columns = ["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"]

# from_records puts one release per column, so transpose to get one row per release.
df = pd.DataFrame.from_records(release_records).T.drop(columns=dropped_columns)

# Each row is a "release" which has multiple SAEs which may have different
# configs / match different hook points in a model.
is_gpt2_release = df.release.str.contains("gpt2")
df[is_gpt2_release]

Unnamed: 0,release,repo_id,model,saes_map,neuronpedia_id
gpt2-small-attn-out-v5-128k,gpt2-small-attn-out-v5-128k,jbloom/GPT2-Small-OAI-v5-128k-attn-out-SAEs,gpt2-small,"{'blocks.0.hook_attn_out': 'v5_128k_layer_0', ...",{'blocks.0.hook_attn_out': 'gpt2-small/0-att_1...
gpt2-small-attn-out-v5-32k,gpt2-small-attn-out-v5-32k,jbloom/GPT2-Small-OAI-v5-32k-attn-out-SAEs,gpt2-small,"{'blocks.0.hook_attn_out': 'v5_32k_layer_0', '...",{'blocks.0.hook_attn_out': 'gpt2-small/0-att_3...
gpt2-small-hook-z-kk,gpt2-small-hook-z-kk,ckkissane/attn-saes-gpt2-small-all-layers,gpt2-small,{'blocks.0.hook_z': 'gpt2-small_L0_Hcat_z_lr1....,"{'blocks.0.hook_z': 'gpt2-small/0-att-kk', 'bl..."
gpt2-small-mlp-out-v5-128k,gpt2-small-mlp-out-v5-128k,jbloom/GPT2-Small-OAI-v5-128k-mlp-out-SAEs,gpt2-small,"{'blocks.0.hook_mlp_out': 'v5_128k_layer_0', '...",{'blocks.0.hook_mlp_out': 'gpt2-small/0-mlp_12...
gpt2-small-mlp-out-v5-32k,gpt2-small-mlp-out-v5-32k,jbloom/GPT2-Small-OAI-v5-32k-mlp-out-SAEs,gpt2-small,"{'blocks.0.hook_mlp_out': 'v5_32k_layer_0', 'b...",{'blocks.0.hook_mlp_out': 'gpt2-small/0-mlp_32...
gpt2-small-mlp-tm,gpt2-small-mlp-tm,tommmcgrath/gpt2-small-mlp-out-saes,gpt2-small,{'blocks.0.hook_mlp_out': 'sae_group_gpt2_bloc...,"{'blocks.0.hook_mlp_out': None, 'blocks.1.hook..."
gpt2-small-res-jb,gpt2-small-res-jb,jbloom/GPT2-Small-SAEs-Reformatted,gpt2-small,{'blocks.0.hook_resid_pre': 'blocks.0.hook_res...,{'blocks.0.hook_resid_pre': 'gpt2-small/0-res-...
gpt2-small-res-jb-feature-splitting,gpt2-small-res-jb-feature-splitting,jbloom/GPT2-Small-Feature-Splitting-Experiment...,gpt2-small,{'blocks.8.hook_resid_pre_768': 'blocks.8.hook...,{'blocks.8.hook_resid_pre_768': 'gpt2-small/8-...
gpt2-small-res_sce-ajt,gpt2-small-res_sce-ajt,neuronpedia/gpt2-small__res_sce-ajt,gpt2-small,"{'blocks.2.hook_resid_pre': '2-res_sce-ajt', '...",{'blocks.2.hook_resid_pre': 'gpt2-small/2-res_...
gpt2-small-res_scefr-ajt,gpt2-small-res_scefr-ajt,neuronpedia/gpt2-small__res_scefr-ajt,gpt2-small,"{'blocks.2.hook_resid_pre': '2-res_scefr-ajt',...",{'blocks.2.hook_resid_pre': 'gpt2-small/2-res_...


Let's run evals for 3 SAEs:
- e2e
- OpenAI
- TODO: choose the third SAE

```bash
poetry run python sae_lens/evals.py "gpt2-small-res-jb" "blocks.9.*" \
    --batch_size_prompts 16 \
    --n_eval_sparsity_variance_batches 200 \
    --n_eval_reconstruction_batches 20 \
    --output_dir "demo_eval_results"
```

In [12]:
# Directory holding the eval results, relative to this notebook.
output_dir = "../demo_eval_results"

# Show which result files are present in that directory.
eval_output_files = os.listdir(output_dir)
print(eval_output_files)

['all_eval_results.json', 'gpt2-small-res-jb-blocks.9.hook_resid_pre_128_Skylion007_openwebtext.json', 'all_eval_results.csv']


**TODO:** Explain the eval results schema.


In [13]:

# Select the aggregated results file by name. os.listdir() returns entries in
# arbitrary order, so indexing [0] could just as well pick the CSV (not valid
# JSON) or the per-SAE JSON file on another machine or filesystem.
first_file = os.path.join(output_dir, "all_eval_results.json")
print(first_file)

# Load the results JSON and pretty-print it.
with open(first_file, "r") as f:
    results = json.load(f)

print(json.dumps(results, indent=4))

../demo_eval_results/all_eval_results.json
[
    {
        "unique_id": "gpt2-small-res-jb-blocks.9.hook_resid_pre",
        "sae_set": "gpt2-small-res-jb",
        "sae_id": "blocks.9.hook_resid_pre",
        "eval_cfg": {
            "context_size": 128,
            "dataset": "Skylion007/openwebtext",
            "library_version": "3.21.1",
            "git_hash": "d398ed2",
            "batch_size_prompts": 16,
            "n_eval_reconstruction_batches": 10,
            "compute_kl": true,
            "compute_ce_loss": true,
            "n_eval_sparsity_variance_batches": 10,
            "compute_l2_norms": true,
            "compute_sparsity_metrics": true,
            "compute_variance_metrics": true,
            "compute_featurewise_density_statistics": true,
            "compute_featurewise_weight_based_metrics": true
        },
        "metrics": {
            "kl_div_with_sae": 0.1742093414068222,
            "kl_div_with_ablation": Infinity,
            "ce_loss_with_sae"

## Making Sense of Results

This section will be better if we run evals on multiple SAEs. 

## Visualizing Featurewise Metrics


Feature Density Statistics:
- Feature Density -> Feature Density Histogram
- Consistent Activation Heuristic -> Finding multi-token features.

In [39]:
# Select the aggregated results file by name. os.listdir() returns entries in
# arbitrary order, so indexing [0] could pick the CSV (not valid JSON) or the
# per-SAE JSON file instead of the list that is indexed with results[0] below.
first_file = os.path.join(output_dir, "all_eval_results.json")
print(first_file)

# Load the results JSON.
with open(first_file, "r") as f:
    results = json.load(f)

# Featurewise metrics of the first entry in the results list.
feature_stats = results[0]["feature_metrics"]
feature_df = pd.DataFrame.from_dict(feature_stats)

# Use the row position as the feature identifier (shown on hover in the scatter plot).
feature_df["feature_name"] = list(range(len(feature_df)))

# The 1e-10 offset keeps log10 finite for features whose density is exactly 0.
log10_feature_density = np.log10(feature_df["feature_density"] + 1e-10)
# Store it as a named column so both figures get a meaningful x-axis label.
feature_df["log10_feature_density"] = log10_feature_density

axis_labels = {
    "log10_feature_density": "log10(feature density)",
    "consistent_activation_heuristic": "consistent activation heuristic",
}

px.histogram(
    feature_df,
    x="log10_feature_density",
    nbins=100,
    title="Feature density histogram",
    labels=axis_labels,
).show()

px.scatter(
    feature_df,
    x="log10_feature_density",
    y="consistent_activation_heuristic",
    marginal_y="histogram",
    marginal_x="histogram",
    hover_name="feature_name",
    title="Consistent activation heuristic vs. feature density",
    labels=axis_labels,
).show()

../demo_eval_results/all_eval_results.json
