# Evaluating combo SAEs

In [None]:
from main import *

## Max activating features

## Linearity of SAE sizes

In [22]:
from huggingface_hub import hf_hub_download

# Define parameters
layer = 9
repo_id = 'charlieoneill/regular-sae'
filename = f'sae_layer_{layer}_32.pt'

# Load from HuggingFace
file_path = hf_hub_download(repo_id=repo_id, filename=filename)

big_sae = GatedSAE(768, 32*768, l1_coefficient=2)
big_sae.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))

<All keys matched successfully>

In [15]:
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# let's start with a layer 8 SAE.
hook_point = "blocks.8.hook_resid_pre"

# if the SAEs were stored with precomputed feature sparsities,
#  those will be return in a dictionary as well.
saes, sparsities = get_gpt2_res_jb_saes(hook_point)

sparse_autoencoder = saes[hook_point]
device = 'cpu'
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device

sparse_autoencoder.cfg.hook_point = "blocks.9.attn.hook_z"
sparse_autoencoder.cfg.store_batch_size = 4

print(sparse_autoencoder.cfg.store_batch_size)

loader = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg)

# don't overwrite the sparse autoencoder with the loader's sae (newly initialized)
model, _, activation_store = loader.load_sae_training_group_session()

100%|██████████| 1/1 [00:01<00:00,  1.02s/it]


4
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cpu


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [26]:
import einops

batch = activation_store.get_batch_tokens()

_, cache = model.run_with_cache(batch)

z_acts = cache['z', layer, 'attn']
print(z_acts.shape)
z_acts = einops.rearrange(z_acts, 'b h l d -> (b h) (l d)')
print(z_acts.shape)

sae_out, _, mse_loss = big_sae(z_acts, z_acts)
print(sae_out.shape)
print(mse_loss)

torch.Size([4, 128, 12, 64])
torch.Size([512, 768])
torch.Size([512, 768])
tensor(30.9679, grad_fn=<MeanBackward0>)


In [41]:
# Some sort of evaluation code

# I think we should have individual metric functions, then a function to apply them over a batch, then a function to apply them over a dataset

def mse_loss(x, y):
    """
    L2 loss of reconstruction.
    """
    per_item_loss = torch.nn.functional.mse_loss(x, y, reduction='none')
    return per_item_loss.sum(dim=-1).mean()

def l0_loss(z):
    """
    L0 loss of reconstruction.
    """
    return (z != 0).float().sum(dim=1).mean()

In [42]:
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name
from functools import partial
import torch
import torch.nn.functional as F

def calculate_kl_divergence(clean_logits, patched_logits):
    # Ensure the inputs are log probabilities
    clean_log_probs = F.log_softmax(clean_logits, dim=-1)
    patched_log_probs = F.log_softmax(patched_logits, dim=-1)
    
    # Convert patched_logits to probabilities
    patched_probs = torch.exp(patched_log_probs)
    
    # Calculate KL divergence for each element in the batch and sequence
    kl_div = F.kl_div(clean_log_probs, patched_probs, reduction='none')
    
    # Average over the vocabulary size (last dimension)
    kl_div = kl_div.sum(dim=-1)
    
    # Average over the batch and sequence length
    kl_div = kl_div.mean(dim=0).mean(dim=0)
    
    return kl_div.item()

def attention_head_z_patching_hook(attention_head_z, hook: HookPoint, layer: int, sae: GatedSAE, gated_sae: GatedSAE):
    z_acts = einops.rearrange(attention_head_z, "b s h d -> (b s) (h d)")
    if sae is not None:
        # Get the reconstructions from the SAE
        z_reconstruct, _, _ = sae(z_acts, z_acts)
    else:
        z_reconstruct = torch.zeros_like(z_acts)
    if gated_sae is not None:
        # Get the error
        error = z_acts - z_reconstruct
        # Get the predicted error
        predicted_error, _, _ = gated_sae(z_acts, error)
        # Add the predicted error to the z_reconstruct
        z_reconstruct = z_reconstruct + predicted_error
    # Rearrange back into original shape
    z_reconstruct = einops.rearrange(z_reconstruct, "(b s) (h d) -> b s h d", b=attention_head_z.shape[0], s=attention_head_z.shape[1], h=attention_head_z.shape[2], d=attention_head_z.shape[3])
    attention_head_z = z_reconstruct
    return attention_head_z


def kl_divergence_and_loss_difference(sae, gated_sae, batch, layer):
    clean_logits, clean_loss = model(batch, return_type="both")
    hook_fn = partial(attention_head_z_patching_hook, layer=layer, sae=sae, gated_sae=gated_sae)
    patched_logits, patched_loss = model.run_with_hooks(
        batch,
        fwd_hooks=[(get_act_name("z", layer, "attn"), hook_fn)],
        return_type="both"
    )
    return calculate_kl_divergence(clean_logits, patched_logits), patched_loss - clean_loss

# Apply our metrics over a single batch
kl_divergence, loss_difference = kl_divergence_and_loss_difference(big_sae, None, batch, layer)
z_hat, _, _ = big_sae(z_acts, z_acts)
z = big_sae.encoder(z_acts)
mse = mse_loss(z_hat, z_acts)
l0 = l0_loss(z)
print(f"KL Divergence: {kl_divergence}")
print(f"MSE Loss: {mse}")
print(f"L0 Loss: {l0}")

KL Divergence: 0.033432040363550186
MSE Loss: 30.967899322509766
L0 Loss: 12.9921875


### Training combo SAE

In [None]:
# Train a smaller regular SAE, hidden size 16_000

In [None]:
# Train a smaller error SAE, hidden size 16_000
layer = 9
model_type = 'gated'
n_epochs = 100
l1_coefficient = 3e-4
batch_size = 2048
lr = 0.001
projection_up = 16
repo_name = "error-saes"

error_sae = main(layer, model_type, n_epochs, l1_coefficient, projection_up, batch_size, lr, repo_name, return_model=True, save_model=False)