# Sparse Autoencoders Evaluation

This notebook evaluates the **Sparse Autoencoders (SAEs)** trained on layers 1 to 8 of the GPT & Prejudice model.  
Each SAE learns sparse latent features from the hidden representations of one transformer layer. These features are intended to be interpretable and help us analyze what information each layer encodes.

The goal of this evaluation is to:
- load the trained GPT model and its layer-wise SAEs,
- apply each SAE to hidden activations extracted from the model,
- measure **reconstruction performance** (how well the SAE preserves the original activations), latent sparsity, and a proxy for the downstream impact of the reconstruction.

These evaluations form a baseline for later interpretability analyses, where the discovered sparse features are linked to human-understandable concepts (e.g., gender, emotion, class, or social role).

---

In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tiktoken
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from utils.model import load_GPT_model
from utils.tokenization import text_to_token_ids
from utils.embeddings import get_token_embeddings_from_sentence
from sparse_auto_encoder import SparseAutoencoder

### 1. Setup and Model Loading

In [2]:
device = "cpu"

print(f"Using {device} device.")

Using cpu device.


In [3]:
model = load_GPT_model(path="model_896_14_8_256.pth", device=device)

In [4]:
tokenizer = tiktoken.get_encoding("gpt2")

### 2. Load Sparse Autoencoders

In [5]:
def load_sae(layer, hidden_dim, input_dim=896):
    """Build an SAE of the given latent width and load the trained weights for `layer`."""
    sae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim).to(device)
    state_dict = torch.load(f"sae_models/sae_layer{layer}.pth", map_location=device)
    sae.load_state_dict(state_dict)
    sae.eval()  # inference mode for evaluation
    return sae

# Latent width is 3x (layers 1-2), 4x (layers 3-5), and 5x (layers 6-8) the 896-dim hidden size
sae_1 = load_sae(1, hidden_dim=2688)
sae_2 = load_sae(2, hidden_dim=2688)
sae_3 = load_sae(3, hidden_dim=3584)
sae_4 = load_sae(4, hidden_dim=3584)
sae_5 = load_sae(5, hidden_dim=3584)
sae_6 = load_sae(6, hidden_dim=4480)
sae_7 = load_sae(7, hidden_dim=4480)
sae_8 = load_sae(8, hidden_dim=4480)
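The loading code relies only on the constructor signature `SparseAutoencoder(input_dim, hidden_dim)`; the class itself lives in `sparse_auto_encoder.py`, which is not shown in this notebook. As a hedged illustration, the sketch below shows one common SAE layout: a ReLU encoder followed by a linear decoder. The class name `ToySparseAutoencoder`, its `encode` method, and the `(x_hat, z)` return value are assumptions for this example, not the project's API.

```python
import torch
import torch.nn as nn


class ToySparseAutoencoder(nn.Module):
    """Hypothetical single-layer SAE: ReLU encoder, linear decoder.

    Not the project's `SparseAutoencoder`; only the constructor arguments
    (input_dim, hidden_dim) mirror the calls in the loading cell.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Non-negative latent code; sparsity must come from training, not from ReLU alone
        return torch.relu(self.encoder(x))

    def forward(self, x: torch.Tensor):
        z = self.encode(x)
        x_hat = self.decoder(z)  # reconstruction in the original 896-dim space
        return x_hat, z


sae = ToySparseAutoencoder(input_dim=896, hidden_dim=2688).eval()
with torch.no_grad():
    x = torch.randn(4, 896)  # 4 fake hidden-state vectors
    x_hat, z = sae(x)
print(x_hat.shape, z.shape)  # torch.Size([4, 896]) torch.Size([4, 2688])
```

With `hidden_dim` set to 3, 4, or 5 times `input_dim`, as in the checkpoints loaded above, the latent space is overcomplete, so sparsity has to be imposed by the training objective (for example an L1 penalty or a top-k constraint), which this sketch does not include.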

### 3. Layer-wise Evaluation

The helper function `evaluate_trained_sae()` runs a quantitative evaluation for a given SAE and GPT layer.  
It performs the following steps:

- Extracts hidden activations from the specified layer of the GPT model  
- Passes them through the corresponding SAE for reconstruction  
- Computes metrics such as:
    - **Reconstruction MSE:** average squared error between the original and reconstructed hidden states. Lower is better; it is sensitive to errors in both magnitude and direction. Because MSE scales with activation norm, raw values are not directly comparable across layers whose activations differ in scale.
    - **Average cosine similarity:** angular alignment between original and reconstructed vectors. Values close to 1 indicate that the SAE preserves representational direction.
    - **Average L0 sparsity (active latents):** mean number of latent units per sample whose magnitude exceeds a small threshold (|z| > 1e-5). Lower counts mean sparser, more selective codes.
    - **Cross-entropy and KL proxy:** a lightweight stress test in which a randomly initialized linear “next-token head” is applied to both the original and the reconstructed states. The increase in cross-entropy (using the argmax of the original logits as targets) and the **KL divergence** between the two output distributions approximate how much the reconstruction could perturb a downstream classifier head.

Together, these metrics show how faithfully each SAE reconstructs its target layer and how sparse its latent code is.
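The internals of `evaluate_trained_sae()` are not shown in this notebook. The following is a minimal sketch of how the four metrics described above could be computed from a batch of original activations `x`, reconstructions `x_hat`, and latent codes `z`. The function names, the GPT-2 vocabulary size default, and the scaled Gaussian initialization of the random head are assumptions, not necessarily what the helper does.

```python
import torch
import torch.nn.functional as F


def reconstruction_metrics(x, x_hat, z, threshold=1e-5):
    """Return (MSE, mean cosine similarity, mean L0) for a batch of row vectors."""
    mse = F.mse_loss(x_hat, x).item()
    cos = F.cosine_similarity(x, x_hat, dim=-1).mean().item()
    # Count latents with |z| above the threshold in each row, then average over rows
    l0 = (z.abs() > threshold).float().sum(dim=-1).mean().item()
    return mse, cos, l0


def random_head_proxy(x, x_hat, vocab_size=50257, seed=0):
    """Return (CE original, CE reconstructed, KL(rec || orig)) under a random linear head."""
    d = x.shape[-1]
    gen = torch.Generator().manual_seed(seed)
    # Hypothetical random "next-token head": vocab_size x d weights scaled by 1/sqrt(d)
    weight = torch.randn(vocab_size, d, generator=gen) / d ** 0.5
    logits_orig = x @ weight.T
    logits_rec = x_hat @ weight.T
    targets = logits_orig.argmax(dim=-1)  # original argmax used as pseudo-labels
    ce_orig = F.cross_entropy(logits_orig, targets).item()
    ce_rec = F.cross_entropy(logits_rec, targets).item()
    log_p_orig = F.log_softmax(logits_orig, dim=-1)
    log_p_rec = F.log_softmax(logits_rec, dim=-1)
    # F.kl_div(input, target) computes KL(target || input), so this is KL(P_rec || P_orig)
    kl = F.kl_div(log_p_orig, log_p_rec, reduction="batchmean", log_target=True).item()
    return ce_orig, ce_rec, kl


torch.manual_seed(0)
x = torch.randn(8, 16)                # 8 fake hidden-state vectors
x_hat = x + 0.1 * torch.randn(8, 16)  # a slightly perturbed "reconstruction"
z = torch.relu(torch.randn(8, 64))    # fake non-negative latent codes
print(reconstruction_metrics(x, x_hat, z))
print(random_head_proxy(x, x_hat, vocab_size=100))
```

A perfect reconstruction gives an MSE of 0, a cosine similarity of 1, equal cross-entropies, and zero KL; any gap from those values measures the SAE's distortion.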

In [6]:
from evaluate_sae import evaluate_trained_sae
import re
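The first step listed above, extracting hidden activations from a given layer, is often implemented with a PyTorch forward hook. The sketch below assumes you can get a handle to the block module (for example by indexing the model's list of transformer blocks, whose attribute name is not shown in this notebook) and that `layer=k` means the output of the k-th block counting from 1; both are assumptions. The helper name `capture_block_output` and the toy model are hypothetical.

```python
import torch
import torch.nn as nn


def capture_block_output(model: nn.Module, block: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Run a full forward pass of `model` and return the tensor that `block` produced."""
    captured = {}

    def hook(module, args, output):
        # Store a detached copy of the block's output
        captured["acts"] = output.detach()

    handle = block.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        handle.remove()  # detach the hook so later forward passes are unaffected
    return captured["acts"]


# Toy stand-in for a stack of transformer blocks: 4 blocks of width 16
torch.manual_seed(0)
toy_model = nn.Sequential(*[nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(4)])
inputs = torch.randn(2, 5, 16)  # (batch, tokens, dim)
layer = 3
acts = capture_block_output(toy_model, toy_model[layer - 1], inputs)  # 1-indexed layer
print(acts.shape)  # torch.Size([2, 5, 16])
```

Before being passed to an SAE, such `(batch, tokens, dim)` activations are typically flattened to one row per token position, e.g. `acts.reshape(-1, acts.shape[-1])`.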

In [7]:
evaluate_trained_sae(sae_1, model, layer=1)

Total sentences after filtering: 2636
Reconstruction MSE: 0.128097
Average Cosine Similarity: 0.704948
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 3.928008
Cross-Entropy Loss (reconstructed): 4.322168
KL Divergence (Reconstructed || Original): 0.022545


In [8]:
evaluate_trained_sae(sae_2, model, layer=2)

Total sentences after filtering: 2636
Reconstruction MSE: 0.200597
Average Cosine Similarity: 0.851549
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 3.547400
Cross-Entropy Loss (reconstructed): 3.929964
KL Divergence (Reconstructed || Original): 0.034600


In [9]:
evaluate_trained_sae(sae_3, model, layer=3)

Total sentences after filtering: 2636
Reconstruction MSE: 0.356255
Average Cosine Similarity: 0.877400
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 3.329621
Cross-Entropy Loss (reconstructed): 3.854797
KL Divergence (Reconstructed || Original): 0.061886


In [10]:
evaluate_trained_sae(sae_4, model, layer=4)

Total sentences after filtering: 2636
Reconstruction MSE: 0.505206
Average Cosine Similarity: 0.896534
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 2.807891
Cross-Entropy Loss (reconstructed): 3.289988
KL Divergence (Reconstructed || Original): 0.088757


In [11]:
evaluate_trained_sae(sae_5, model, layer=5)

Total sentences after filtering: 2636
Reconstruction MSE: 0.753457
Average Cosine Similarity: 0.877576
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 2.580509
Cross-Entropy Loss (reconstructed): 3.300773
KL Divergence (Reconstructed || Original): 0.123802


In [12]:
evaluate_trained_sae(sae_6, model, layer=6)

Total sentences after filtering: 2636
Reconstruction MSE: 1.077450
Average Cosine Similarity: 0.855100
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 2.332970
Cross-Entropy Loss (reconstructed): 3.226375
KL Divergence (Reconstructed || Original): 0.186150


In [13]:
evaluate_trained_sae(sae_7, model, layer=7)

Total sentences after filtering: 2636
Reconstruction MSE: 1.384787
Average Cosine Similarity: 0.829606
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 2.303382
Cross-Entropy Loss (reconstructed): 3.171908
KL Divergence (Reconstructed || Original): 0.207862


In [14]:
evaluate_trained_sae(sae_8, model, layer=8)

Total sentences after filtering: 2636
Reconstruction MSE: 1.673100
Average Cosine Similarity: 0.828214
Average L0 Sparsity (active latents): 50.00
Cross-Entropy Loss (original): 2.207709
Cross-Entropy Loss (reconstructed): 3.158854
KL Divergence (Reconstructed || Original): 0.265773

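For side-by-side comparison, the logged values can be collected into one table. The snippet below only re-types the numbers printed above and computes the cross-entropy increase (reconstructed minus original) per layer; it does not rerun the evaluation.

```python
# Values copied from the evaluate_trained_sae() logs above:
# layer -> (CE original, CE reconstructed, KL(reconstructed || original))
results = {
    1: (3.928008, 4.322168, 0.022545),
    2: (3.547400, 3.929964, 0.034600),
    3: (3.329621, 3.854797, 0.061886),
    4: (2.807891, 3.289988, 0.088757),
    5: (2.580509, 3.300773, 0.123802),
    6: (2.332970, 3.226375, 0.186150),
    7: (2.303382, 3.171908, 0.207862),
    8: (2.207709, 3.158854, 0.265773),
}

# Extra cross-entropy incurred by substituting the SAE reconstruction
ce_delta = {layer: ce_rec - ce_orig for layer, (ce_orig, ce_rec, _) in results.items()}

print(f"{'layer':>5} {'CE orig':>8} {'CE rec':>8} {'dCE':>7} {'KL':>8}")
for layer, (ce_orig, ce_rec, kl) in results.items():
    print(f"{layer:>5} {ce_orig:>8.4f} {ce_rec:>8.4f} {ce_delta[layer]:>7.4f} {kl:>8.4f}")
```

On these numbers, KL rises monotonically with depth, while the cross-entropy gap grows from about 0.39 at layer 1 to about 0.95 at layer 8, with small dips at layers 2, 4, and 7.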