### Project Proposal Summary: How good are SAEs, compared to other methods in their reference class?

Is a reconstruction MSE of 0.4 an impressive number? How does it compare to the top L0 PCA directions and to K-means clustering with K = d_vocab clusters? Are SAEs unique in achieving a good reconstruction? We may also answer the same question for CE loss recovered.
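
A PCA baseline of this kind could be computed roughly as follows. This is a minimal sketch on synthetic data, not the project's method: the function name `pca_reconstruction_mse` and the random stand-in for activations are assumptions made for illustration, and `k` would be set to the SAE's average L0 to make the comparison fair.

```python
import numpy as np

def pca_reconstruction_mse(x: np.ndarray, k: int) -> float:
    """MSE after projecting centered data onto its top-k principal directions."""
    mean = x.mean(axis=0, keepdims=True)
    centered = x - mean
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    top = vt[:k]                                  # shape [k, d]
    recon = centered @ top.T @ top + mean         # project, then map back
    return float(np.mean((x - recon) ** 2))

rng = np.random.default_rng(0)
# Synthetic stand-in for activations: 1000 samples, 32 dims, mostly rank-8 structure.
x = rng.normal(size=(1000, 8)) @ rng.normal(size=(8, 32)) \
    + 0.1 * rng.normal(size=(1000, 32))

for k in (1, 4, 8, 16):
    print(f"k={k:2d}  PCA reconstruction MSE = {pca_reconstruction_mse(x, k):.4f}")
```

An SAE whose MSE at a given L0 is no better than PCA with the same number of directions would not be showing an advantage in reconstruction alone.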

Is the monosemanticity of SAE features unique? How do baselines (clustering, PCA, probing/steering directions) compare? Test this for max-activating and uniformly sampled activating dataset examples. This is partially inspired by Szegedy et al. (2013), who find that random MNIST directions appear interpretable. We might also compare auto-interpretability scores (if we can afford the time and API keys).

Is the activation distribution & sparsity of SAE latents of the form that we expect from theory & toy models? Maybe also look at things like cosine similarity distribution.
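
The statistics mentioned here could be gathered with something like the sketch below. It assumes the latent activations and decoder directions are already available as arrays; the function `latent_stats` and the synthetic inputs are hypothetical placeholders, not part of the notebook.

```python
import numpy as np

def latent_stats(latents: np.ndarray, decoder: np.ndarray, eps: float = 1e-8):
    """latents: [n_samples, n_latents] non-negative activations.
    decoder: [n_latents, d_model], one direction per latent."""
    active = latents > 0
    l0_per_sample = active.sum(axis=1)        # active latents per input
    firing_freq = active.mean(axis=0)         # fraction of inputs each latent fires on
    unit = decoder / (np.linalg.norm(decoder, axis=1, keepdims=True) + eps)
    cos = unit @ unit.T                       # pairwise cosine similarities
    off_diag = cos[~np.eye(len(decoder), dtype=bool)]  # drop self-pairs
    return l0_per_sample, firing_freq, off_diag

rng = np.random.default_rng(0)
# Synthetic stand-ins: shifted-ReLU activations and random decoder directions.
latents = np.maximum(rng.normal(size=(500, 64)) - 1.5, 0.0)
decoder = rng.normal(size=(64, 16))

l0, freq, cos = latent_stats(latents, decoder)
print(f"mean L0 per input: {l0.mean():.2f}")
print(f"median firing frequency: {np.median(freq):.3f}")
print(f"max |cosine| between distinct latents: {np.abs(cos).max():.3f}")
```

Histograms of `freq` and `cos` are what one would compare against the distributions predicted by theory and toy models.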



#### Application tasks: 

Consider the properties described in the summary. Which of these do you expect to differ between SAE methodologies, how, and why? Pick 2-3 different SAE methods you’re familiar with, or consider e.g. Braun et al. 2024, Gao et al. 2024, Rajamanoharan et al. 2024.

Calculate and compare the MSE of two SAEs on the pile-10k dataset. Choose SAEs from different families (e.g. Gao et al. 2024 and Rajamanoharan et al. 2024).
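
One axis along which SAE families differ is how sparsity is enforced: a penalty on ReLU activations versus a hard top-k constraint that fixes L0 exactly. The sketch below is an illustrative assumption about what a top-k activation looks like; it is not code from any of the cited papers, and `topk_activation` is a hypothetical name.

```python
import numpy as np

def topk_activation(pre_acts: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest pre-activations in each row and zero the rest."""
    # argpartition places the indices of the k largest entries in the last k columns.
    idx = np.argpartition(pre_acts, -k, axis=1)[:, -k:]
    out = np.zeros_like(pre_acts)
    np.put_along_axis(out, idx, np.take_along_axis(pre_acts, idx, axis=1), axis=1)
    return np.maximum(out, 0.0)   # clamp any negative survivors, as a ReLU would

rng = np.random.default_rng(0)
pre = rng.normal(size=(4, 10))            # synthetic encoder pre-activations
acts = topk_activation(pre, k=3)
print(np.count_nonzero(acts, axis=1))     # at most 3 active latents per row
```

With a top-k activation the L0 is a hyperparameter rather than a side effect of a penalty weight, which is one reason to expect the reconstruction/sparsity trade-off to differ between families.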


1. Method 1: Sparse Autoencoder (Fully Connected)
    A simple, fully connected sparse autoencoder with L1 regularization for sparsity.
    Encourages sparsity directly in the latent space.
2. Method 2: Sparse Autoencoder with K-Means Latents
    A hybrid SAE that applies K-Means clustering on the latent space during training.
    Enforces cluster-like representations.

#### Define Metrics
FC-SAE has a slightly lower MSE (0.4370) than K-Means-SAE (0.4750). FC-SAE minimizes reconstruction error subject only to an L1 sparsity penalty. K-Means-SAE additionally replaces each latent vector with its cluster centroid before decoding, which limits how finely the decoder can distinguish inputs and may explain the higher MSE.

The FC-SAE latent space is unconstrained except for L1 sparsity, resulting in a continuous, non-clustered representation.
K-Means-SAE enforces discrete clustering of latent vectors, which creates structured but less flexible representations. K-Means aligns latent features with cluster centroids, sacrificing reconstruction accuracy for interpretability.
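
The effect of snapping latents to centroids can be seen in a small worked example. This is a sketch with hand-picked numbers; `quantize_to_centroids` is a hypothetical helper rather than a function from the notebook, and the centroids are fixed here instead of being fitted.

```python
import numpy as np

def quantize_to_centroids(latents: np.ndarray, centroids: np.ndarray):
    """Replace each latent vector with its nearest centroid (squared Euclidean distance)."""
    # Pairwise squared distances, shape [n_latents_vectors, n_centroids].
    d2 = ((latents[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    assign = d2.argmin(axis=1)
    return centroids[assign], assign

latents = np.array([[0.0, 0.0], [1.0, 1.0], [10.0, 10.0]])
centroids = np.array([[0.5, 0.5], [10.0, 10.0]])

quantized, assign = quantize_to_centroids(latents, centroids)
quant_mse = float(np.mean((latents - quantized) ** 2))
print("assignments:", assign.tolist())
print(f"quantization MSE in latent space: {quant_mse:.4f}")
```

After quantization the decoder sees at most as many distinct inputs as there are centroids, so it can produce at most that many distinct reconstructions per batch.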


FC-SAE explicitly enforces sparsity via L1 loss, leading to many latent dimensions being zero or near-zero.
K-Means-SAE may have less sparsity because its clustering objective doesn't inherently enforce sparsity. L1 regularization in FC-SAE penalizes non-zero activations, whereas K-Means focuses on clustering without sparsity constraints.
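
For reference, the objective that the notebook's `SparseAutoencoderFC` and `train_sae` appear to implement can be written as below. The notation is an assumption introduced here: $B$ is the batch size, $d$ the input dimension, $m$ the latent dimension, $\sigma$ the sigmoid, and $\lambda$ corresponds to `sparsity_weight` (default `1e-3`).

```latex
\mathcal{L}
  = \underbrace{\frac{1}{B\,d}\sum_{i=1}^{B}\lVert x_i - \hat{x}_i \rVert_2^2}_{\text{reconstruction MSE}}
  + \lambda \cdot \underbrace{\frac{1}{B\,m}\sum_{i=1}^{B}\sum_{j=1}^{m} \lvert z_{ij} \rvert}_{\text{mean absolute activation}},
\qquad
z_i = \mathrm{ReLU}(W_e x_i + b_e),\quad
\hat{x}_i = \sigma(W_d z_i + b_d)
```

`SparseAutoencoderKMeans` defines no `sparsity_loss`, so in `train_sae` it is trained on the first term only.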

In [12]:
!pip install datasets transformers torch scikit-learn



In [20]:
from datasets import load_dataset
from transformers import GPT2Tokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans

# Load pile-10k dataset
pile_10k = load_dataset("NeelNanda/pile-10k")

# Tokenize the dataset
# Load the GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Add a padding token
tokenizer.pad_token = tokenizer.eos_token  # Use the end-of-sequence token as the padding token


# Extract text samples from the dataset
text_samples = [sample["text"] for sample in pile_10k["train"]]

# Tokenize the dataset in a single batch
tokenized_batch = tokenizer(
    text_samples,
    return_tensors="pt",  # Return PyTorch tensors
    padding=True,         # Pad to the length of the longest sequence in the batch
    truncation=True,      # Truncate sequences longer than the model's max length
    max_length=512        # Set a fixed maximum length (optional)
)

# Convert tokenized inputs to input tensors
input_data = tokenized_batch["input_ids"].float()  # Shape: [batch_size, max_length]

# Normalize input data
input_data = (input_data - input_data.mean()) / input_data.std()




In [17]:
class SparseAutoencoderFC(nn.Module):
    """Fully connected autoencoder with an L1 penalty on the latent activations."""

    def __init__(self, input_dim, latent_dim, sparsity_weight=1e-3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.ReLU()
        )
        # Sigmoid bounds every reconstructed value to the range (0, 1)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, input_dim),
            nn.Sigmoid()
        )
        self.sparsity_weight = sparsity_weight

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

    def sparsity_loss(self, encoded):
        # L1 penalty: mean absolute latent activation, scaled by sparsity_weight
        sparsity = torch.mean(torch.abs(encoded))
        return self.sparsity_weight * sparsity


class SparseAutoencoderKMeans(nn.Module):
    """Autoencoder whose latents are snapped to K-Means centroids before decoding."""

    def __init__(self, input_dim, latent_dim, num_clusters):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, input_dim),
            nn.Sigmoid()
        )
        self.num_clusters = num_clusters
        self.kmeans = KMeans(n_clusters=num_clusters)

    def forward(self, x):
        encoded = self.encoder(x)
        # Refit K-Means on this batch's latents and assign each latent to a cluster
        cluster_assignments = self.kmeans.fit_predict(encoded.detach().cpu().numpy())
        cluster_centers = torch.tensor(
            self.kmeans.cluster_centers_, dtype=encoded.dtype, device=x.device
        )
        # Index with a long tensor on the same device as the centroids
        assignments = torch.as_tensor(cluster_assignments, dtype=torch.long, device=x.device)
        # The centroids are constants, so no gradient flows back to the encoder here
        clustered_latents = cluster_centers[assignments]
        decoded = self.decoder(clustered_latents)
        return encoded, decoded


In [21]:
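def train_sae(model, data, epochs=100, batch_size=32, lr=1e-3):
    """Train with MSE reconstruction loss plus the model's sparsity penalty, if it defines one."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        total_loss = 0
        # Iterate over the data in order, one mini-batch at a time (no shuffling)
        for i in range(0, len(data), batch_size):
            batch = data[i:i+batch_size]
            optimizer.zero_grad()
            encoded, decoded = model(batch)
            loss = criterion(decoded, batch) + (model.sparsity_loss(encoded) if hasattr(model, 'sparsity_loss') else 0)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # total_loss sums per-batch mean losses, so dividing by len(data)
        # reports roughly (mean batch loss) / batch_size
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(data):.4f}")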

# Train both models
fc_autoencoder = SparseAutoencoderFC(input_dim=input_data.size(1), latent_dim=128)
train_sae(fc_autoencoder, input_data)

kmeans_autoencoder = SparseAutoencoderKMeans(input_dim=input_data.size(1), latent_dim=128, num_clusters=10)
train_sae(kmeans_autoencoder, input_data)


Epoch 1/100, Loss: 0.0156
Epoch 2/100, Loss: 0.0143
Epoch 3/100, Loss: 0.0142
Epoch 4/100, Loss: 0.0141
Epoch 5/100, Loss: 0.0141
Epoch 6/100, Loss: 0.0141
Epoch 7/100, Loss: 0.0140
Epoch 8/100, Loss: 0.0140
Epoch 9/100, Loss: 0.0140
Epoch 10/100, Loss: 0.0140
Epoch 11/100, Loss: 0.0140
Epoch 12/100, Loss: 0.0140
Epoch 13/100, Loss: 0.0140
Epoch 14/100, Loss: 0.0140
Epoch 15/100, Loss: 0.0140
Epoch 16/100, Loss: 0.0140
Epoch 17/100, Loss: 0.0140
Epoch 18/100, Loss: 0.0139
Epoch 19/100, Loss: 0.0139
Epoch 20/100, Loss: 0.0139
Epoch 21/100, Loss: 0.0139
Epoch 22/100, Loss: 0.0139
Epoch 23/100, Loss: 0.0139
Epoch 24/100, Loss: 0.0139
Epoch 25/100, Loss: 0.0139
Epoch 26/100, Loss: 0.0139
Epoch 27/100, Loss: 0.0139
Epoch 28/100, Loss: 0.0139
Epoch 29/100, Loss: 0.0139
Epoch 30/100, Loss: 0.0139
Epoch 31/100, Loss: 0.0139
Epoch 32/100, Loss: 0.0139
Epoch 33/100, Loss: 0.0139
Epoch 34/100, Loss: 0.0139
Epoch 35/100, Loss: 0.0139
Epoch 36/100, Loss: 0.0139
Epoch 37/100, Loss: 0.0139
Epoch 38/1

In [22]:
# Evaluate reconstruction MSE
def evaluate_mse(model, data):
    """Reconstruction MSE over the full dataset in one forward pass (no sparsity term)."""
    model.eval()
    criterion = nn.MSELoss()
    with torch.no_grad():
        _, decoded = model(data)
        mse = criterion(decoded, data)
    return mse.item()

mse_fc = evaluate_mse(fc_autoencoder, input_data)
mse_kmeans = evaluate_mse(kmeans_autoencoder, input_data)

print(f"MSE (Fully Connected SAE): {mse_fc:.4f}")
print(f"MSE (K-Means SAE): {mse_kmeans:.4f}")


MSE (Fully Connected SAE): 0.4370
MSE (K-Means SAE): 0.4750
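
A raw MSE is easier to judge when divided by the MSE of a trivial predictor that always outputs the mean. The sketch below computes this ratio, often called the fraction of variance unexplained; the function name and the synthetic data are assumptions for illustration. Because the notebook standardizes `input_data` to zero mean and unit variance globally, a constant prediction equal to the global mean has an MSE of about 1 there.

```python
import numpy as np

def fraction_variance_unexplained(x: np.ndarray, recon: np.ndarray) -> float:
    """Reconstruction MSE divided by the MSE of predicting the per-feature mean."""
    mse = np.mean((x - recon) ** 2)
    baseline = np.mean((x - x.mean(axis=0, keepdims=True)) ** 2)
    return float(mse / baseline)

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 5))                       # synthetic stand-in data
mean_pred = np.broadcast_to(x.mean(axis=0, keepdims=True), x.shape)

print(fraction_variance_unexplained(x, x))          # perfect reconstruction
print(fraction_variance_unexplained(x, mean_pred))  # mean predictor
```

Values near 0 indicate near-perfect reconstruction, and values near 1 indicate a model that does no better than predicting the mean.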
