In [1]:
# --- Transformer Configuration ---
MODEL_NAME = "gpt2"
# We target the output projection of the MLP layer in a specific block
LAYER_IDX = 6
LAYER_NAME = f"transformer.h.{LAYER_IDX}.mlp.c_proj" # Output projection of MLP
# Alternative: f"transformer.h.{LAYER_IDX}.attn.c_proj" # Output projection of Attention
# Alternative: f"transformer.h.{LAYER_IDX}" # Output of the whole block (residual stream)

In [2]:
# --- Dataset Configuration ---
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-103-raw-v1" # Or "wikitext-2-raw-v1" for smaller
DATASET_SPLIT = "train"
NUM_ACTIVATIONS_TO_STORE = 2500000 # Max activations to collect (adjust based on RAM)
MAX_SEQ_LENGTH = 128 # Max token length for processing sequences

In [3]:
# --- SAE Configuration ---
ACTIVATION_DIM = 768 # Dimension of the activations we are capturing (GPT-2 small: 768)
SAE_EXPANSION_FACTOR = 4 # How many times larger the SAE hidden dim is than the activation dim
SAE_HIDDEN_DIM = ACTIVATION_DIM * SAE_EXPANSION_FACTOR
L1_COEFF = 8e-4 # Sparsity penalty strength

In [9]:
# --- Training Configuration ---
BATCH_SIZE = 256
LEARNING_RATE = 3e-4
NUM_EPOCHS = 200
PRINT_INTERVAL = 100

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm.auto import tqdm  # Use auto version for notebook compatibility
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
import os

In [6]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using GPU: NVIDIA RTX A1000 6GB Laptop GPU


In [7]:
print(f"Loading model: {MODEL_NAME}")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model.eval()

print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading model: gpt2
Loading tokenizer: gpt2


In [8]:
activation_storage = []

# Register hook to save the activation
def get_activation_hook_no_flatten(layer_name):
    """Creates a hook function to capture activations."""
    def hook(model, input, output):
        act_tensor = output[0] if isinstance(output, tuple) else output
        activation_storage.append(act_tensor.detach().cpu().to(torch.float32))

    target_module = dict(model.named_modules())[layer_name]
    handle = target_module.register_forward_hook(hook)
    return handle

# Register the non-flattening hook
print(f"Registering hook on layer: {LAYER_NAME}")
hook_handle = get_activation_hook_no_flatten(LAYER_NAME)

activation_storage_filtered = []
token_storage = []

print(f"Loading dataset: {DATASET_NAME} ({DATASET_CONFIG})")
dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=f"{DATASET_SPLIT}", streaming=True)

print("Processing dataset and extracting activations (Aligning Tokens and Activations)...")
num_processed_tokens = 0 # Track actual non-padding tokens processed
pbar = tqdm(total=NUM_ACTIVATIONS_TO_STORE, desc="Extracting Aligned Activations/Tokens")

for example in iter(dataset):
    if num_processed_tokens >= NUM_ACTIVATIONS_TO_STORE:
        break

    text = example['text']
    if not text or len(text.strip()) == 0:
      continue

    # Tokenize the text
    inputs = tokenizer(text, return_tensors="pt", max_length=MAX_SEQ_LENGTH, padding="max_length", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device) # Shape: (batch_size, seq_len)
    
    activation_storage.clear()

    with torch.no_grad():
        model(input_ids=input_ids, attention_mask=attention_mask)

    # --- Retrieve activation (assuming hook appended to global list) ---
    if not activation_storage:
         print("Warning: Hook did not capture activation for a batch. Skipping.")
         continue
        
    batch_activations = activation_storage[0]

    # Ensure shapes match for filtering
    if batch_activations.shape[0] != input_ids.shape[0] or batch_activations.shape[1] != input_ids.shape[1]:
       print(f"Warning: Activation shape {batch_activations.shape} mismatch with input_ids {input_ids.shape}. Check hook logic. Skipping batch.")
       continue

    # --- Filter Activations and Tokens based on Attention Mask ---
    cpu_input_ids = input_ids.cpu().numpy() # (batch_size, seq_len)
    cpu_attention_mask = attention_mask.cpu().numpy().astype(bool) # (batch_size, seq_len)

    # Flatten batch and sequence dimensions for filtering
    flat_activations = batch_activations.view(-1, ACTIVATION_DIM) # (batch*seq, hidden)
    flat_input_ids = cpu_input_ids.flatten() # (batch*seq,)
    flat_attention_mask = cpu_attention_mask.flatten() # (batch*seq,)

    # Select only items where attention_mask is True
    filtered_activations = flat_activations[flat_attention_mask]
    filtered_tokens = flat_input_ids[flat_attention_mask]

    # Store filtered tokens
    if filtered_activations.shape[0] > 0: # Only store if there are non-padding tokens
        activation_storage_filtered.append(filtered_activations)
        newly_added_tokens = filtered_tokens.tolist()
        token_storage.extend(newly_added_tokens)

        # --- Update Progress ---
        num_processed_tokens += len(newly_added_tokens)
        pbar.update(len(newly_added_tokens))

pbar.close()
hook_handle.remove() # Remove the hook

# --- Concatenate FINAL filtered activations ---
if not activation_storage_filtered:
    raise ValueError("No activations were extracted/stored after filtering. Check data or hook.")

all_activations = torch.cat(activation_storage_filtered, dim=0)
print(f"Total ALIGNED activations extracted: {all_activations.shape[0]}")
print(f"Total ALIGNED tokens stored: {len(token_storage)}")
print(f"Activation tensor final shape: {all_activations.shape}") # Should be (N, ACTIVATION_DIM)
print(f"Token storage final length: {len(token_storage)}")      # N should match above

# Check if the lengths match (they should now)
if all_activations.shape[0] != len(token_storage):
    print(f"CRITICAL WARNING: Mismatch between activation count ({all_activations.shape[0]}) and token count ({len(token_storage)}) AFTER filtering!")
else:
    print("Activation and token counts match. Alignment successful.")

# Normalize activations for SAE training
act_mean = torch.mean(all_activations, dim=0, keepdim=True)
act_norms = torch.norm(all_activations, dim=1, keepdim=True) + 1e-6 # Add epsilon for stability
all_activations = (all_activations - act_mean) / act_norms
print("Filtered activations normalized (unit norm per vector).")

# Convert to PyTorch DataLoader for training
activation_dataset = torch.utils.data.TensorDataset(all_activations)
activation_loader = torch.utils.data.DataLoader(activation_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Clean up memory
del activation_storage # Delete the temporary batch storage list
del activation_storage_filtered
import gc
gc.collect()
if device == torch.device("cuda"):
    torch.cuda.empty_cache()

Registering hook on layer: transformer.h.6.mlp.c_proj
Loading dataset: wikitext (wikitext-103-raw-v1)
Processing dataset and extracting activations (Aligning Tokens and Activations)...


Extracting Aligned Activations/Tokens:   0%|          | 0/2500000 [00:00<?, ?it/s]

Total ALIGNED activations extracted: 2500062
Total ALIGNED tokens stored: 2500062
Activation tensor final shape: torch.Size([2500062, 768])
Token storage final length: 2500062
Activation and token counts match. Alignment successful.
Filtered activations normalized (unit norm per vector).


In [10]:
class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=True)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=True)

        self.relu = nn.ReLU()
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x):
        encoded = self.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return encoded, decoded

    def encode(self, x):
      return self.relu(self.encoder(x))

# Initialize SAE
print(f"Initializing SAE with ACTIVATION_DIM={ACTIVATION_DIM}, SAE_HIDDEN_DIM={SAE_HIDDEN_DIM}")
sae_model = SparseAutoencoder(ACTIVATION_DIM, SAE_HIDDEN_DIM).to(device)
print("SAE Model:")
print(sae_model)

Initializing SAE with ACTIVATION_DIM=768, SAE_HIDDEN_DIM=3072
SAE Model:
SparseAutoencoder(
  (encoder): Linear(in_features=768, out_features=3072, bias=True)
  (decoder): Linear(in_features=3072, out_features=768, bias=True)
  (relu): ReLU()
)


In [11]:
def sae_loss(original_activations, encoded_activations, decoded_activations, l1_coeff):
    """Calculates SAE loss: MSE + L1 penalty on encoded activations."""
    mse_loss = nn.functional.mse_loss(decoded_activations, original_activations)

    l1_loss = torch.mean(torch.abs(encoded_activations))

    total_loss = mse_loss + l1_coeff * l1_loss
    return total_loss, mse_loss, l1_loss

optimizer = optim.Adam(sae_model.parameters(), lr=LEARNING_RATE)

In [12]:
log_dir = "runs/sae_training_logs" # You can customize this path
os.makedirs(log_dir, exist_ok=True) # Create the directory if it doesn't exist
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print("Run 'tensorboard --logdir runs' (or your specific log_dir parent) to view logs.")

TensorBoard logs will be saved to: runs/sae_training_logs
Run 'tensorboard --logdir runs' (or your specific log_dir parent) to view logs.


In [None]:
print("Starting SAE training...")
sae_model.train()
training_losses = {'total': [], 'mse': [], 'l1': [], 'l0': []} # Track losses and L0

for epoch in range(NUM_EPOCHS):
    epoch_total_loss = 0
    epoch_mse_loss = 0
    epoch_l1_loss = 0
    epoch_l0_norm = 0 # Track average number of non-zero features

    pbar_train = tqdm(activation_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
    for i, batch in enumerate(pbar_train):
        # Batch is a list containing one tensor: the activations
        original_acts = batch[0].to(device)

        encoded, decoded = sae_model(original_acts)
        loss, mse, l1 = sae_loss(original_acts, encoded, decoded, L1_COEFF)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        epoch_total_loss += loss.item()
        epoch_mse_loss += mse.item()
        epoch_l1_loss += l1.item()
        
        l0 = torch.sum(encoded > 1e-6, dim=1).float().mean().item()
        epoch_l0_norm += l0

        if (i + 1) % PRINT_INTERVAL == 0 or i == len(activation_loader) - 1:
             pbar_train.set_postfix({
                 "Total Loss": f"{loss.item():.4f}",
                 "MSE": f"{mse.item():.4f}",
                 "L1": f"{l1.item():.4f}",
                 "L0": f"{l0:.2f}" # Avg active features in current batch
             })

    # Average losses for the epoch
    avg_total_loss = epoch_total_loss / len(activation_loader)
    avg_mse_loss = epoch_mse_loss / len(activation_loader)
    avg_l1_loss = epoch_l1_loss / len(activation_loader)
    avg_l0_norm = epoch_l0_norm / len(activation_loader)

    training_losses['total'].append(avg_total_loss)
    training_losses['mse'].append(avg_mse_loss)
    training_losses['l1'].append(avg_l1_loss)
    training_losses['l0'].append(avg_l0_norm)

    writer.add_scalar('Loss/Epoch/Total', avg_total_loss, epoch)
    writer.add_scalar('Loss/Epoch/MSE', avg_mse_loss, epoch)
    writer.add_scalar('Loss/Epoch/L1', avg_l1_loss, epoch)
    writer.add_scalar('L0_Norm/Epoch', avg_l0_norm, epoch)


    print(f"Epoch {epoch+1} Complete: Avg Total Loss={avg_total_loss:.12f}, Avg MSE={avg_mse_loss:.12f}, Avg L1={avg_l1_loss:.7f}, Avg L0={avg_l0_norm:.2f}")

print("SAE Training finished.")

Starting SAE training...


Epoch 1/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 1 Complete: Avg Total Loss=0.000032454870, Avg MSE=0.000014295274, Avg L1=0.0226995, Avg L0=2859.89


Epoch 2/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 2 Complete: Avg Total Loss=0.000014390422, Avg MSE=0.000002902213, Avg L1=0.0143603, Avg L0=2919.89


Epoch 3/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 3 Complete: Avg Total Loss=0.000011679091, Avg MSE=0.000002503595, Avg L1=0.0114694, Avg L0=2834.44


Epoch 4/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 4 Complete: Avg Total Loss=0.000010301931, Avg MSE=0.000002311449, Avg L1=0.0099881, Avg L0=2623.01


Epoch 5/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 5 Complete: Avg Total Loss=0.000009438556, Avg MSE=0.000002199259, Avg L1=0.0090491, Avg L0=2444.50


Epoch 6/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 6 Complete: Avg Total Loss=0.000008854756, Avg MSE=0.000002147218, Avg L1=0.0083844, Avg L0=2323.80


Epoch 7/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 7 Complete: Avg Total Loss=0.000008413719, Avg MSE=0.000002102576, Avg L1=0.0078889, Avg L0=2215.28


Epoch 8/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 8 Complete: Avg Total Loss=0.000008073218, Avg MSE=0.000002074949, Avg L1=0.0074978, Avg L0=2134.57


Epoch 9/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 9 Complete: Avg Total Loss=0.000007807023, Avg MSE=0.000002062174, Avg L1=0.0071811, Avg L0=2067.14


Epoch 10/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 10 Complete: Avg Total Loss=0.000007578985, Avg MSE=0.000002043618, Avg L1=0.0069192, Avg L0=2001.25


Epoch 11/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 11 Complete: Avg Total Loss=0.000007390869, Avg MSE=0.000002033702, Avg L1=0.0066965, Avg L0=1946.03


Epoch 12/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 12 Complete: Avg Total Loss=0.000007229329, Avg MSE=0.000002025989, Avg L1=0.0065042, Avg L0=1901.44


Epoch 13/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 13 Complete: Avg Total Loss=0.000007095955, Avg MSE=0.000002027544, Avg L1=0.0063355, Avg L0=1869.04


Epoch 14/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 14 Complete: Avg Total Loss=0.000006975428, Avg MSE=0.000002023720, Avg L1=0.0061896, Avg L0=1836.17


Epoch 15/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 15 Complete: Avg Total Loss=0.000006865273, Avg MSE=0.000002016465, Avg L1=0.0060610, Avg L0=1797.11


Epoch 16/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 16 Complete: Avg Total Loss=0.000006772795, Avg MSE=0.000002017684, Avg L1=0.0059439, Avg L0=1772.51


Epoch 17/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 17 Complete: Avg Total Loss=0.000006672633, Avg MSE=0.000002000545, Avg L1=0.0058401, Avg L0=1739.63


Epoch 18/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 18 Complete: Avg Total Loss=0.000006599555, Avg MSE=0.000002005057, Avg L1=0.0057431, Avg L0=1719.09


Epoch 19/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 19 Complete: Avg Total Loss=0.000006525794, Avg MSE=0.000002001235, Avg L1=0.0056557, Avg L0=1694.65


Epoch 20/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 20 Complete: Avg Total Loss=0.000006453759, Avg MSE=0.000001993321, Avg L1=0.0055755, Avg L0=1667.93


Epoch 21/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 21 Complete: Avg Total Loss=0.000006382640, Avg MSE=0.000001982371, Avg L1=0.0055003, Avg L0=1641.22


Epoch 22/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 22 Complete: Avg Total Loss=0.000006317405, Avg MSE=0.000001973707, Avg L1=0.0054296, Avg L0=1619.71


Epoch 23/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 23 Complete: Avg Total Loss=0.000006259505, Avg MSE=0.000001968632, Avg L1=0.0053636, Avg L0=1599.34


Epoch 24/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 24 Complete: Avg Total Loss=0.000006205349, Avg MSE=0.000001963989, Avg L1=0.0053017, Avg L0=1580.59


Epoch 25/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 25 Complete: Avg Total Loss=0.000006152746, Avg MSE=0.000001957374, Avg L1=0.0052442, Avg L0=1564.16


Epoch 26/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 26 Complete: Avg Total Loss=0.000006103201, Avg MSE=0.000001951437, Avg L1=0.0051897, Avg L0=1545.72


Epoch 27/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 27 Complete: Avg Total Loss=0.000006057123, Avg MSE=0.000001946513, Avg L1=0.0051383, Avg L0=1526.86


Epoch 28/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 28 Complete: Avg Total Loss=0.000006009820, Avg MSE=0.000001938824, Avg L1=0.0050887, Avg L0=1511.56


Epoch 29/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 29 Complete: Avg Total Loss=0.000005961686, Avg MSE=0.000001929196, Avg L1=0.0050406, Avg L0=1494.58


Epoch 30/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 30 Complete: Avg Total Loss=0.000005918412, Avg MSE=0.000001922111, Avg L1=0.0049954, Avg L0=1481.68


Epoch 31/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 31 Complete: Avg Total Loss=0.000005884681, Avg MSE=0.000001922215, Avg L1=0.0049531, Avg L0=1468.46


Epoch 32/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 32 Complete: Avg Total Loss=0.000005842242, Avg MSE=0.000001911418, Avg L1=0.0049135, Avg L0=1454.21


Epoch 33/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 33 Complete: Avg Total Loss=0.000005811642, Avg MSE=0.000001911960, Avg L1=0.0048746, Avg L0=1442.25


Epoch 34/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 34 Complete: Avg Total Loss=0.000005773141, Avg MSE=0.000001902505, Avg L1=0.0048383, Avg L0=1429.45


Epoch 35/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 35 Complete: Avg Total Loss=0.000005745101, Avg MSE=0.000001902605, Avg L1=0.0048031, Avg L0=1417.26


Epoch 36/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 36 Complete: Avg Total Loss=0.000005709488, Avg MSE=0.000001894077, Avg L1=0.0047693, Avg L0=1403.72


Epoch 37/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 37 Complete: Avg Total Loss=0.000005677817, Avg MSE=0.000001888994, Avg L1=0.0047360, Avg L0=1393.03


Epoch 38/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 38 Complete: Avg Total Loss=0.000005642983, Avg MSE=0.000001879729, Avg L1=0.0047041, Avg L0=1384.75


Epoch 39/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 39 Complete: Avg Total Loss=0.000005621491, Avg MSE=0.000001882844, Avg L1=0.0046733, Avg L0=1375.07


Epoch 40/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 40 Complete: Avg Total Loss=0.000005594294, Avg MSE=0.000001878868, Avg L1=0.0046443, Avg L0=1366.94


Epoch 41/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 41 Complete: Avg Total Loss=0.000005563622, Avg MSE=0.000001870120, Avg L1=0.0046169, Avg L0=1357.22


Epoch 42/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 42 Complete: Avg Total Loss=0.000005532860, Avg MSE=0.000001861029, Avg L1=0.0045898, Avg L0=1344.73


Epoch 43/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 43 Complete: Avg Total Loss=0.000005511486, Avg MSE=0.000001861522, Avg L1=0.0045625, Avg L0=1338.81


Epoch 44/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 44 Complete: Avg Total Loss=0.000005488907, Avg MSE=0.000001859024, Avg L1=0.0045374, Avg L0=1325.59


Epoch 45/200:   0%|          | 0/9766 [00:00<?, ?it/s]

Epoch 45 Complete: Avg Total Loss=0.000005449023, Avg MSE=0.000001838905, Avg L1=0.0045126, Avg L0=1307.67


Epoch 46/200:   0%|          | 0/9766 [00:00<?, ?it/s]

In [None]:
sae_model.eval()

print("Recomputing SAE features for analysis...")
with torch.no_grad():
    all_sae_features = []
    analysis_batch_size = 128
    for i in range(0, all_activations.shape[0], analysis_batch_size):
        batch_acts = all_activations[i:i+analysis_batch_size].to(device)
        batch_features = sae_model.encode(batch_acts).cpu()
        all_sae_features.append(batch_features)

    all_sae_features = torch.cat(all_sae_features, dim=0)

print(f"SAE features shape: {all_sae_features.shape}")

def get_max_activating_examples(feature_index, num_examples=10, window_size=10):
    """Find text snippets where a specific SAE feature maximally activates."""
    if feature_index < 0 or feature_index >= SAE_HIDDEN_DIM:
        print(f"Error: Feature index {feature_index} out of bounds (0-{SAE_HIDDEN_DIM-1}).")
        return
    if all_sae_features.shape[0] == 0:
        print("Error: No SAE features computed.")
        return
    if not token_storage:
        print("Error: Token storage is empty.")
        return

    print(f"\n--- Finding top activating examples for Feature {feature_index} ---")

    feature_activations = all_sae_features[:, feature_index]

    top_k_values, top_k_indices = torch.topk(feature_activations, k=min(num_examples, all_sae_features.shape[0]), largest=True) # Handle cases with fewer activations than requested

    for rank, (idx_tensor, activation_value_tensor) in enumerate(zip(top_k_indices, top_k_values)):
        activation_value = activation_value_tensor.item()
        global_token_idx = idx_tensor.item()

        if global_token_idx < 0 or global_token_idx >= len(token_storage):
             print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
             print(f"Context: Error - Calculated token index {global_token_idx} is out of bounds for token_storage (len={len(token_storage)}). Skipping.")
             continue
        
        start = max(0, global_token_idx - window_size)
        end = min(len(token_storage), global_token_idx + window_size + 1)

        if start >= end:
             print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
             print(f"Context: Error - Invalid window slice indices (start={start}, end={end}). Skipping.")
             continue

        token_ids_window = token_storage[start:end]

        try:
             context_text = tokenizer.decode(token_ids_window, skip_special_tokens=True, errors='replace')

             activating_token_id = token_storage[global_token_idx]
             activating_token_str = tokenizer.decode([activating_token_id], skip_special_tokens=False, errors='replace')


             print(f"\nRank {rank+1}: Activation = {activation_value:.10e}")
             print(f"Context: ...{context_text}...")
             print(f"(Activated on token approx position {global_token_idx-start}: '{activating_token_str}' [{activating_token_id}])")


        except IndexError:
             print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
             print(f"Context: Error decoding tokens - IndexError occurred with indices ({start}, {end}) for token_storage (len={len(token_storage)})")
        except Exception as e:
            print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
            print(f"Context: Error decoding tokens: {token_ids_window} - {e}")

Recomputing SAE features for analysis...


In [13]:
for i in range(10):
    random_feature_idx = random.randint(0, SAE_HIDDEN_DIM - 1)
    get_max_activating_examples(feature_index=random_feature_idx, num_examples=10, window_size=15)


--- Finding top activating examples for Feature 2130 ---

Rank 1: Activation = 3.8196775131e-03
Context: ... 1921 , at Kongens gate 29 until 1932 , and at Tollbodgaten 24 until 1938 . Management wanted to centralise both a new depot and...
(Activated on token approx position 15: 'g' [70])

Rank 2: Activation = 3.5598464310e-03
Context: ...ain " ; and " to have power , or ability . " Qidr , a noun derived from the same root , means " cauldron , kettle...
(Activated on token approx position 15: 'r' [81])

Rank 3: Activation = 3.5460712388e-03
Context: ... : 377 – 378 . 
 = = = Cited sources = = = 
 Colley , Ann C. ( 2010 ) . Victorians in...
(Activated on token approx position 15: ' =' [796])

Rank 4: Activation = 3.3055664971e-03
Context: ...umbling after " . The reference to " Jill " ( actually a " gill " , or The suggestion has also been made that Jack and Jill represent Louis...
(Activated on token approx position 15: 'ill' [359])

Rank 5: Activation = 3.1611532904e-03
Context: ... 

In [None]:
print("DEBUG: Inspecting sample SAE features...")
sae_model.eval()
sae_model.to(device) # Ensure model is on GPU

# Use a subset of activations to avoid RAM issues if all_activations is huge
num_inspect_samples = min(10000, all_activations.shape[0]) # Inspect up to 10k samples
sample_indices = torch.randperm(all_activations.shape[0])[:num_inspect_samples]

# Keep activations on CPU for this inspection if they fit, otherwise batch
sample_acts_cpu = all_activations[sample_indices]

# Calculate features in batches to avoid OOM on GPU if necessary
batch_size_inspect = 1024
all_sample_features_list = []
with torch.no_grad():
    for i in range(0, num_inspect_samples, batch_size_inspect):
        batch_acts_gpu = sample_acts_cpu[i:i+batch_size_inspect].to(device)
        batch_features_gpu = sae_model.encode(batch_acts_gpu)
        all_sample_features_list.append(batch_features_gpu.cpu()) # Move features to CPU

sample_features = torch.cat(all_sample_features_list, dim=0)

print(f"Sample features shape: {sample_features.shape}") # Should be (num_inspect_samples, SAE_HIDDEN_DIM)
print(f"Sample features min: {sample_features.min():.6f}")
print(f"Sample features max: {sample_features.max():.6f}")
print(f"Sample features mean: {sample_features.mean():.6f}")
print(f"Sample features std: {sample_features.std():.6f}")

# Calculate L0 norm for the sample (more accurate than just min/max/mean)
# Use a small threshold to account for floating point inaccuracies
active_features = (sample_features > 1e-6).float()
avg_l0_per_input = active_features.sum(dim=1).mean().item()
total_active_features = active_features.sum().item()
fraction_active = total_active_features / sample_features.numel()

print(f"\nSample L0 norm (avg active features per input): {avg_l0_per_input:.4f}")
print(f"Total active feature values in sample: {total_active_features}")
print(f"Fraction of active feature values: {fraction_active:.6f}")

# Check proportion of DEAD features (never activate across the sample)
features_never_activated = (active_features.sum(dim=0) == 0).sum().item()
print(f"\nNumber of features that NEVER activated in the sample: {features_never_activated} / {SAE_HIDDEN_DIM}")

del sample_acts_cpu, all_sample_features_list, sample_features, active_features # Cleanup
gc.collect()
if device == torch.device("cuda"):
    torch.cuda.empty_cache()

In [14]:
sae_model_path = "sae_gpt2_layer6_mlp_0.pth"
torch.save(sae_model.state_dict(), sae_model_path)
print(f"SAE model weights saved to: {sae_model_path}")

SAE model weights saved to: sae_gpt2_layer6_mlp_0.pth
