# Evaluating combo SAEs

In [None]:
from main import *

## Max activating features

## Linearity of SAE sizes

In [None]:
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
import einops
import torch
from huggingface_hub import hf_hub_download
from gated_sae import GatedSAE
from tqdm import tqdm

# Turn autograd off globally for the code that runs after this line, so
# forward passes do not build a computation graph.
torch.set_grad_enabled(False)

def get_activation_store_and_model():
    """Build the language model and a token/activation store for it.

    A pretrained GPT-2 residual-stream SAE is loaded only so that its config
    object can be reused: the config is moved to CPU, pointed at
    ``blocks.9.attn.hook_z`` and given a store batch size of 4, then handed to
    ``LMSparseAutoencoderSessionloader``.

    Returns:
        A ``(model, activation_store)`` tuple taken from the first and third
        items returned by ``load_sae_training_group_session()``.
    """
    # Start from the SAE published for the layer-8 residual stream.
    resid_hook = "blocks.8.hook_resid_pre"

    # If the SAEs were stored with precomputed feature sparsities, those are
    # returned as a second dictionary; it is not used here.
    sae_by_hook, _sparsities = get_gpt2_res_jb_saes(resid_hook)
    pretrained_sae = sae_by_hook[resid_hook]

    target_device = 'cpu'
    pretrained_sae.to(target_device)

    # Re-purpose the pretrained SAE's config for the attention-output hook.
    # NOTE(review): the hook is hard-coded to layer 9 — keep in sync with the
    # `layer` used elsewhere in the notebook.
    cfg = pretrained_sae.cfg
    cfg.device = target_device
    cfg.hook_point = "blocks.9.attn.hook_z"
    cfg.store_batch_size = 4

    print(cfg.store_batch_size)

    session_loader = LMSparseAutoencoderSessionloader(cfg)

    # The middle return value is the loader's own (newly initialized) SAE;
    # it is deliberately discarded.
    model, _, activation_store = session_loader.load_sae_training_group_session()

    return model, activation_store

In [None]:
# Define parameters: the transformer layer under study and the Hugging Face
# repo that all three checkpoints below are downloaded from.
layer = 9
repo_id = 'charlieoneill/regular-sae'


# Load big SAE (checkpoint suffix `_32`, second constructor argument 32*768).
# NOTE(review): presumably the GatedSAE arguments are (input width, hidden
# width) — confirm against gated_sae.GatedSAE.
# NOTE(security): torch.load unpickles the downloaded file; only use
# checkpoints from a trusted repo (or consider `weights_only=True` if the
# installed torch version supports it).
filename = f'sae_layer_{layer}_32.pt'
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
big_sae = GatedSAE(768, 32*768, l1_coefficient=2)
big_sae.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))

# Load little SAE (checkpoint suffix `_16`, second constructor argument 16*768).
filename = f'sae_layer_{layer}_16.pt'
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
little_sae = GatedSAE(768, 16*768, l1_coefficient=2)
little_sae.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))

# Load error SAE (checkpoint with no size suffix, same shape as the little SAE).
# NOTE(review): this is fetched from the same `regular-sae` repo as the two
# SAEs above — verify that this file really is the error SAE.
filename = f'sae_layer_{layer}.pt'
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
error_sae = GatedSAE(768, 16*768, l1_coefficient=2)
error_sae.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))

In [None]:
# Build the global `model` and `activation_store`; the functions defined in
# later cells read `model` from module scope.
model, activation_store = get_activation_store_and_model()

In [None]:
# Sanity check: push one batch of tokens through the model and the big SAE.
batch = activation_store.get_batch_tokens()

_, cache = model.run_with_cache(batch)

# Attention `z` activations at `layer`.
# Assumed layout is (batch, position, head, d_head), matching the axis names
# used by `attention_head_z_patching_hook` for the same hook — TODO confirm.
z_acts = cache['z', layer, 'attn']
print(z_acts.shape)
# Merge the first two axes into rows and the last two into columns so each
# row is one SAE input vector.
z_acts = einops.rearrange(z_acts, 'b s h d -> (b s) (h d)')
print(z_acts.shape)

# The SAE's third return value is stored as `big_sae_mse`, not `mse_loss`:
# a later cell defines a function called `mse_loss`, and binding a tensor to
# that name here would clobber the function whenever this cell is re-run.
sae_out, _, big_sae_mse = big_sae(z_acts, z_acts)
print(sae_out.shape)
print(big_sae_mse)

In [None]:
# Some sort of evaluation code

# I think we should have individual metric functions, then a function to apply them over a batch, then a function to apply them over a dataset

def mse_loss(x, y):
    """
    Reconstruction error: squared differences between `x` and `y`, summed
    over the last dimension and then averaged over the remaining ones.
    """
    squared_error = torch.nn.functional.mse_loss(x, y, reduction='none')
    error_per_row = torch.sum(squared_error, dim=-1)
    return torch.mean(error_per_row)

def l0_loss(z):
    """
    Average number of non-zero entries per row of `z`: non-zeros are counted
    along dimension 1 and the counts are then averaged.
    """
    is_active = z.ne(0).to(torch.float32)
    active_per_row = is_active.sum(dim=1)
    return active_per_row.mean()

In [None]:
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name
from functools import partial
import torch
import torch.nn.functional as F

def calculate_kl_divergence(clean_logits, patched_logits):
    """Return the mean KL(clean || patched) between two sets of logits.

    Both inputs are normalised with a softmax over the last dimension. The
    divergence is summed over that dimension and averaged over every
    remaining position.

    Args:
        clean_logits: Logits of the unmodified model; last dim is the
            distribution axis.
        patched_logits: Logits of the patched model, same shape.

    Returns:
        The averaged divergence as a Python float.
    """
    clean_log_probs = F.log_softmax(clean_logits, dim=-1)
    patched_log_probs = F.log_softmax(patched_logits, dim=-1)

    # F.kl_div(input, target) evaluates target * (log target - input), i.e.
    # KL(target || input). The clean distribution is the reference, so it must
    # be the *target*; passing it as `input` yields the reverse divergence.
    # log_target=True lets both sides stay in log space.
    kl_per_position = F.kl_div(
        patched_log_probs,
        clean_log_probs,
        reduction='none',
        log_target=True,
    ).sum(dim=-1)  # sum (not average) over the distribution axis

    # Average over all remaining positions (batch and sequence).
    return kl_per_position.mean().item()

def attention_head_z_patching_hook(attention_head_z, hook: HookPoint, layer: int, sae: GatedSAE, gated_sae: GatedSAE):
    """Forward hook that swaps attention `z` for an SAE-based reconstruction.

    Args:
        attention_head_z: 4-D activation, rearranged as ``b s h d``.
        hook: Hook point supplied by the caller; not read here.
        layer: Not read here; accepted so the hook can be bound with
            ``functools.partial`` alongside the other arguments.
        sae: Primary SAE, or ``None`` to start from an all-zero reconstruction.
        gated_sae: Optional second SAE called with the flattened activations
            and the residual left by the first stage; its first output is
            added to the reconstruction. ``None`` skips this stage.

    Returns:
        The reconstruction, reshaped back to the shape of `attention_head_z`.
    """
    flat_z = einops.rearrange(attention_head_z, "b s h d -> (b s) (h d)")
    n_batch, n_pos, n_heads, d_head = attention_head_z.shape

    # Stage 1: reconstruction from the primary SAE (first element of its
    # return tuple), or zeros when no primary SAE is given.
    if sae is None:
        reconstruction = torch.zeros_like(flat_z)
    else:
        reconstruction, _, _ = sae(flat_z, flat_z)

    # Stage 2: let the second SAE predict what stage 1 missed and add it back.
    if gated_sae is not None:
        residual = flat_z - reconstruction
        predicted_residual, _, _ = gated_sae(flat_z, residual)
        reconstruction = reconstruction + predicted_residual

    # Undo the flattening so the result can replace the hooked activation.
    return einops.rearrange(
        reconstruction,
        "(b s) (h d) -> b s h d",
        b=n_batch,
        s=n_pos,
        h=n_heads,
        d=d_head,
    )


def kl_divergence_and_loss_difference(sae, gated_sae, batch, layer):
    """Compare a clean forward pass with one whose attention `z` is patched.

    Runs the module-level `model` on `batch` twice: once untouched and once
    with `attention_head_z_patching_hook` installed on the `z` activation of
    `layer`.

    Returns:
        ``(kl, loss_delta)`` where ``kl`` is
        ``calculate_kl_divergence(clean_logits, patched_logits)`` and
        ``loss_delta`` is ``patched_loss - clean_loss``.
    """
    patch_z = partial(attention_head_z_patching_hook, layer=layer, sae=sae, gated_sae=gated_sae)
    hook_name = get_act_name("z", layer, "attn")

    clean_logits, clean_loss = model(batch, return_type="both")
    patched_logits, patched_loss = model.run_with_hooks(
        batch,
        fwd_hooks=[(hook_name, patch_z)],
        return_type="both",
    )

    kl = calculate_kl_divergence(clean_logits, patched_logits)
    loss_delta = patched_loss - clean_loss
    return kl, loss_delta

# Apply our metrics over a single batch, using the `batch` and flattened
# `z_acts` left in scope by the earlier cell.
# `loss_difference` is computed here but not printed below.
kl_divergence, loss_difference = kl_divergence_and_loss_difference(big_sae, None, batch, layer)
z_hat, _, _ = big_sae(z_acts, z_acts)
# NOTE(review): L0 is measured on whatever `big_sae.encoder` returns; this is
# only a sparsity count if that output is the post-activation feature vector
# — confirm against GatedSAE.
z = big_sae.encoder(z_acts)
mse = mse_loss(z_hat, z_acts)
l0 = l0_loss(z)
print(f"KL Divergence: {kl_divergence}")
print(f"MSE Loss: {mse}")
print(f"L0 Loss: {l0}")

In [None]:
# Now we need a function that gets the metrics for each model in a single batch

def metrics_from_batch(big_sae, little_sae, error_sae, batch, layer):
    """Compute evaluation metrics for three SAE setups on one token batch.

    The setups are the big SAE alone, the little SAE alone, and the "combo"
    (little SAE plus the error SAE's prediction of the little SAE's residual).

    Args:
        big_sae, little_sae, error_sae: SAEs callable as ``sae(x, target)``
            and exposing ``.encoder``.
        batch: Token batch passed to the module-level `model`.
        layer: Layer whose attention `z` activation is evaluated.

    Returns:
        ``{"big_sae": {...}, "little_sae": {...}, "combo_sae": {...}}``; each
        inner dict holds ``kl_divergence``, ``loss_difference``, ``mse`` and
        ``l0`` as Python numbers.
    """
    # Get z_acts from batch and flatten to one SAE input vector per row.
    # Axis names follow `attention_head_z_patching_hook` ("b s h d").
    _, cache = model.run_with_cache(batch)
    z_acts = einops.rearrange(cache['z', layer, 'attn'], 'b s h d -> (b s) (h d)')

    # Just Big SAE
    kl_divergence, loss_difference = kl_divergence_and_loss_difference(big_sae, None, batch, layer)
    z_hat, _, _ = big_sae(z_acts, z_acts)
    z = big_sae.encoder(z_acts)
    mse = mse_loss(z_hat, z_acts)
    l0 = l0_loss(z)

    # Just little SAE
    kl_divergence_little, loss_difference_little = kl_divergence_and_loss_difference(little_sae, None, batch, layer)
    z_hat_little, _, _ = little_sae(z_acts, z_acts)
    z_little = little_sae.encoder(z_acts)
    mse_little = mse_loss(z_hat_little, z_acts)
    l0_little = l0_loss(z_little)

    # Combo SAE = Little SAE + Error SAE.
    # The little SAE's outputs from above are reused instead of recomputed.
    kl_divergence_combo, loss_difference_combo = kl_divergence_and_loss_difference(little_sae, error_sae, batch, layer)
    predicted_error, _, _ = error_sae(z_acts, z_acts - z_hat_little)
    error_z = error_sae.encoder(z_acts)
    z_hat_combo = z_hat_little + predicted_error
    # The two SAEs have separate dictionaries, so the combo's feature vector is
    # their concatenation. Adding them element-wise would merge unrelated
    # features that share an index and would break for unequal hidden sizes.
    z_combo = torch.cat([z_little, error_z], dim=-1)
    mse_combo = mse_loss(z_hat_combo, z_acts)
    l0_combo = l0_loss(z_combo)

    # "loss_difference" is the difference in loss between the patched and
    # clean models.
    batch_dict = {
        "big_sae": {
            "kl_divergence": kl_divergence,
            "loss_difference": loss_difference.item(),
            "mse": mse.item(),
            "l0": l0.item(),
        },
        "little_sae": {
            "kl_divergence": kl_divergence_little,
            "loss_difference": loss_difference_little.item(),
            "mse": mse_little.item(),
            "l0": l0_little.item(),
        },
        "combo_sae": {
            "kl_divergence": kl_divergence_combo,
            "loss_difference": loss_difference_combo.item(),
            "mse": mse_combo.item(),
            "l0": l0_combo.item(),
        },
    }

    return batch_dict

# Draw a fresh token batch and compute the metrics for all three setups on it.
batch = activation_store.get_batch_tokens()

batch_dict = metrics_from_batch(big_sae, little_sae, error_sae, batch, layer)

# Print nicely: one header line per setup, then its metrics tab-indented.
for model_name, model_dict in batch_dict.items():
    print(f"{model_name}:")
    for metric_name, metric_value in model_dict.items():
        print(f"\t{metric_name}: {metric_value}")


In [None]:
import numpy as np

# Now a function to apply over n batches and return the average
def metrics_from_batches(activation_store, big_sae, little_sae, error_sae, batch, layer, n_batches):
    """Average the per-batch metrics over `n_batches` freshly drawn batches.

    Args:
        activation_store: Source of token batches (``get_batch_tokens()``).
        big_sae, little_sae, error_sae: Forwarded to `metrics_from_batch`.
        batch: Unused; a new batch is drawn on every iteration. Kept so
            existing callers keep working.
        layer: Forwarded to `metrics_from_batch`.
        n_batches: Number of batches to evaluate; must be at least 1.

    Returns:
        ``(average_dict, std_dict)``, both keyed ``[model_name][metric_name]``.
        ``kl_divergence`` and ``loss_difference`` entries are multiplied by
        1000 in both dicts.

    Raises:
        ValueError: If `n_batches` is less than 1.
    """
    if n_batches < 1:
        raise ValueError(f"n_batches must be a positive integer, got {n_batches}")

    batch_dicts = [
        metrics_from_batch(big_sae, little_sae, error_sae, activation_store.get_batch_tokens(), layer)
        for _ in tqdm(range(n_batches))
    ]

    # Metrics that are rescaled by 1000 (they are small in raw units).
    scaled_metrics = ("kl_divergence", "loss_difference")

    average_dict = {}
    std_dict = {}
    for model_name, first_metrics in batch_dicts[0].items():
        average_dict[model_name] = {}
        std_dict[model_name] = {}
        for metric_name in first_metrics:
            metric_values = [batch_dict[model_name][metric_name] for batch_dict in batch_dicts]
            average = sum(metric_values) / n_batches
            # NOTE(review): the spread is reported as half a standard
            # deviation, as in the original code — confirm this is intended
            # (a standard error would divide by sqrt(n_batches) instead).
            spread = np.std(metric_values) / 2

            if metric_name in scaled_metrics:
                average *= 1000
                spread *= 1000

            average_dict[model_name][metric_name] = average
            std_dict[model_name][metric_name] = spread

    return average_dict, std_dict

# Average the metrics over 50 batches. The `batch` argument is overwritten
# inside the function's loop, so its value here does not matter.
average_dict, std_dict = metrics_from_batches(activation_store, big_sae, little_sae, error_sae, batch, layer, 50)

# Print nicely: only the averages are printed; `std_dict` is used for the
# error bars in the plotting cell.
for model_name, model_dict in average_dict.items():
    print(f"{model_name}:")
    for metric_name, metric_value in model_dict.items():
        print(f"\t{metric_name}: {metric_value}")

In [None]:
import plotly.graph_objects as go
import plotly.express as px

# Grouped bar chart: each group is a metric, each bar within it is a model.
# One trace per model, with error bars taken from `std_dict`.
traces = [
    go.Bar(
        name=model_name,
        x=list(model_dict),
        y=list(model_dict.values()),
        error_y=dict(
            type='data',
            array=[std_dict[model_name][metric] for metric in model_dict],
            visible=True,
        ),
    )
    for model_name, model_dict in average_dict.items()
]

# Figure layout
layout = go.Layout(
    title="Big vs Combo metrics (same L0)",
    xaxis=dict(title="Metrics"),
    barmode="group",
    width=800,
)

# Assemble the figure and render it
fig = go.Figure(data=traces, layout=layout)
fig.show()

In [None]:
# Pareto curve data (from HuggingFace): L0 values on the x-axis and the
# matching reconstruction errors on the y-axis.
l0s = [10, 11, 12, 13, 14, 15, 16]
l2s = [30.01, 29.02, 28.28, 27.99, 26.05, 25.80, 25.43]
# (L0, MSE) coordinates of the big SAE. Stored under its own name so the
# loaded `big_sae` model is not overwritten by a tuple, which would break the
# metric cells if they are re-run after this one.
big_sae_point = (12.38, 28.24)

# Create the line plot
fig = px.line(x=l0s, y=l2s, title="Combo SAE Pareto Curve (L0 vs MSE)", width=800)

# Add a red cross marker for the big SAE
fig.add_trace(go.Scatter(
    x=[big_sae_point[0]],
    y=[big_sae_point[1]],
    mode='markers',
    marker=dict(
        color='red',
        size=15,
        symbol='cross'
    ),
    name='Big SAE'
))

# Update layout
fig.update_layout(
    xaxis_title='L0',
    yaxis_title='MSE',
    legend_title='Models',
    font=dict(size=14)
)

# Show the plot
fig.show()

### Training combo SAE

In [None]:
# Train a smaller regular SAE, hidden size 16_000

In [None]:
# Train a smaller error SAE, hidden size 16_000
# NOTE(review): with projection_up = 16 and the 768-wide inputs used by the
# loaded SAEs, the hidden size would be 16 * 768 = 12_288 rather than 16_000
# — confirm how `main` interprets `projection_up`.
layer = 9
model_type = 'gated'
n_epochs = 100
l1_coefficient = 3e-4
batch_size = 2048
lr = 0.001
projection_up = 16
repo_name = "error-saes"

# `main` is presumably provided by `from main import *` at the top of the
# notebook. This rebinds the global `error_sae`, replacing the checkpoint
# loaded earlier; the model is returned but not saved (save_model=False).
error_sae = main(layer, model_type, n_epochs, l1_coefficient, projection_up, batch_size, lr, repo_name, return_model=True, save_model=False)