In [None]:
import json
import sys

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import transformer_lens
from datasets import Dataset
from tqdm import tqdm

from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.evals import run_evals
from tests.unit.helpers import build_sae_cfg

# Record the interpreter version next to the outputs, for reproducibility.
print(sys.version)
# Automatically re-import edited modules (e.g. a local sae_lens checkout)
# before each cell executes, so code changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Pick the best available accelerator rather than hard-wiring "mps", which
# fails on machines without an Apple-silicon GPU (on a Mac this is still "mps").
DEVICE = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
saes, sparsities = get_gpt2_res_jb_saes(device=DEVICE)

In [None]:
# Evaluate every pretrained SAE against one shared activation store and
# persist the per-SAE metrics, so the plotting cell can run from the file alone.

# Pick the best available accelerator rather than hard-wiring "mps", which
# fails on machines without an Apple-silicon GPU (on a Mac this is still "mps").
DEVICE = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = transformer_lens.HookedTransformer.from_pretrained("gpt2-small", device=DEVICE)
cfg = build_sae_cfg(
    checkpoint_path="./checkpoints",
    train_batch_size=1000,
    total_training_tokens=100000,
    context_size=72,
    device=DEVICE,
    d_in=768,  # matches GPT-2 small's model width (d_model = 768)
)
# NOTE(review): every row is the same "hello world" text, so the resulting
# metrics only smoke-test the pipeline; use a real corpus for meaningful numbers.
dataset = Dataset.from_list([{"text": "hello world"}] * 100000)
activation_store = ActivationsStore.from_config(model, cfg, dataset=dataset)
res = {}
# Extra context forwarded to run_evals. NOTE(review): how these keys are
# consumed is defined inside sae_lens; confirm against the installed version.
ctx = {
    "suffix": "",
    "n_training_steps": 100000,
}

for name, sae in tqdm(saes.items(), desc="Evaluating SAEs"):
    metrics = run_evals(sae, activation_store, model, ctx)
    # Reuse the suffix from ctx so the two cannot drift apart.
    res[name] = metrics.as_dict(suffix=ctx["suffix"])

# Save metrics to file. NOTE(review): assumes as_dict() yields JSON-serialisable
# values (plain numbers, not tensors); json.dump raises TypeError otherwise.
with open("gpt2_small_metrics.json", "w", encoding="utf-8") as f:
    json.dump(res, f, indent=2)


In [None]:
# Show the per-SAE metrics as a table (one row per SAE, one column per metric)
# instead of the hard-to-scan nested-dict repr.
pd.DataFrame.from_dict(res, orient="index")

In [None]:
# Render figures as static images inside the notebook (the default in modern
# Jupyter; kept explicit here).
%matplotlib inline

def create_heatmap(metrics_dict, output_path="gpt2_small_metrics_heatmap.png", figsize=(24, 18)):
    """Plot a per-metric normalised heatmap of evaluation results and save it.

    Colour encodes each value min-max scaled within its own metric row, so
    metrics with very different ranges share one colour scale. Each cell's
    annotation shows the raw, un-normalised value.

    Args:
        metrics_dict: Mapping of instance name -> {metric name: value}. Values
            are expected to be numeric, because they are annotated with '.2f'.
        output_path: Path the figure is written to (defaults to the original
            hard-coded file name).
        figsize: Figure size in inches, passed to ``plt.subplots``.
    """
    # Long format (one row per instance/metric pair), then pivot so rows are
    # metrics and columns are instances.
    df = pd.DataFrame.from_dict(metrics_dict, orient='index')
    df_melted = pd.melt(df.reset_index(), id_vars='index', var_name='Metric', value_name='Value')
    df_pivot = df_melted.pivot(index='Metric', columns='index', values='Value')

    def normalize_row(row):
        # A constant row has zero range; map it to 0 rather than divide by zero.
        if row.max() == row.min():
            return pd.Series(0.0, index=row.index)
        return (row - row.min()) / (row.max() - row.min())

    df_normalized = df_pivot.apply(normalize_row, axis=1)

    # Sequential light-to-dark palette: higher normalised values are darker.
    cmap = sns.color_palette("PuBu", as_cmap=True)

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        df_normalized,
        cmap=cmap,
        annot=df_pivot,
        fmt='.2f',
        cbar_kws={'label': 'Normalized Value'},
        annot_kws={"size": 10},
        ax=ax,
    )
    ax.set_title('Metrics Heatmap')
    ax.set_xlabel('Instance')
    ax.set_ylabel('Metric')

    # Save BEFORE showing: the inline backend closes the figure in plt.show(),
    # so a savefig issued afterwards writes an empty canvas.
    fig.savefig(output_path)
    plt.show()


# Reload the saved metrics so the plot can be regenerated without re-running
# the evaluation cell.
with open("gpt2_small_metrics.json", "r", encoding="utf-8") as f:
    metrics = json.load(f)

# Plot outside the `with` block so the file handle is released before rendering.
create_heatmap(metrics)
