# Configuration

In [1]:
# --- Transformer Configuration ---
MODEL_NAME = "gpt2"
# We target the output projection of the MLP layer in a specific block
LAYER_IDX = 5
# LAYER_NAME = f"transformer.h.{LAYER_IDX}.mlp.c_proj" # Output projection of MLP
LAYER_NAME = f"transformer.h.{LAYER_IDX}.attn.c_proj" # Output projection of Attention
# Alternative: f"transformer.h.{LAYER_IDX}" # Output of the whole block (residual stream)

In [2]:
# --- Dataset Configuration ---
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-103-raw-v1" # Or "wikitext-2-raw-v1" for smaller
DATASET_SPLIT = "train"
NUM_ACTIVATIONS_TO_STORE = 2000000 # Max activations to collect (adjust based on RAM)
MAX_SEQ_LENGTH = 128 # Max token length for processing sequences

In [3]:
# --- SAE Configuration ---
ACTIVATION_DIM = 768 # Dimension of the activations we are capturing (GPT-2 small: 768)
SAE_EXPANSION_FACTOR = 4 # How many times larger the SAE hidden dim is than the activation dim
SAE_HIDDEN_DIM = ACTIVATION_DIM * SAE_EXPANSION_FACTOR
L1_COEFF = 8e-3 # Sparsity penalty strength

In [4]:
# --- Training Configuration ---
BATCH_SIZE = 256
LEARNING_RATE = 3e-4
NUM_EPOCHS = 300
PRINT_INTERVAL = 100

In [5]:
# --- Analysis Configuration ---
NUM_TOP_FEATURES_TO_ANALYZE = 50
ACTIVITY_THRESHOLD = 1e-6 # Threshold to consider a feature non-dead

# Setup

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm.auto import tqdm  # Use auto version for notebook compatibility
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
import os
import datetime

ModuleNotFoundError: No module named 'seaborn'

In [7]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using GPU: NVIDIA RTX A1000 6GB Laptop GPU


In [8]:
print(f"Loading model: {MODEL_NAME}")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model.eval()

print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading model: gpt2
Loading tokenizer: gpt2


# Data

In [9]:
activation_storage = []

# Register hook to save the activation
def get_activation_hook_no_flatten(layer_name):
    """Creates a hook function to capture activations."""
    def hook(model, input, output):
        act_tensor = output[0] if isinstance(output, tuple) else output
        activation_storage.append(act_tensor.detach().cpu().to(torch.float32))

    target_module = dict(model.named_modules())[layer_name]
    handle = target_module.register_forward_hook(hook)
    return handle

# Register the non-flattening hook
print(f"Registering hook on layer: {LAYER_NAME}")
hook_handle = get_activation_hook_no_flatten(LAYER_NAME)

activation_storage_filtered = []
token_storage = []

print(f"Loading dataset: {DATASET_NAME} ({DATASET_CONFIG})")
dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=f"{DATASET_SPLIT}", streaming=True)

print("Processing dataset and extracting activations (Aligning Tokens and Activations)...")
num_processed_tokens = 0 # Track actual non-padding tokens processed
pbar = tqdm(total=NUM_ACTIVATIONS_TO_STORE, desc="Extracting Aligned Activations/Tokens")

for example in iter(dataset):
    if num_processed_tokens >= NUM_ACTIVATIONS_TO_STORE:
        break

    text = example['text']
    if not text or len(text.strip()) == 0:
      continue

    # Tokenize the text
    inputs = tokenizer(text, return_tensors="pt", max_length=MAX_SEQ_LENGTH, padding="max_length", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device) # Shape: (batch_size, seq_len)
    
    activation_storage.clear()

    with torch.no_grad():
        model(input_ids=input_ids, attention_mask=attention_mask)

    # --- Retrieve activation (assuming hook appended to global list) ---
    if not activation_storage:
         print("Warning: Hook did not capture activation for a batch. Skipping.")
         continue
        
    batch_activations = activation_storage[0]

    # Ensure shapes match for filtering
    if batch_activations.shape[0] != input_ids.shape[0] or batch_activations.shape[1] != input_ids.shape[1]:
       print(f"Warning: Activation shape {batch_activations.shape} mismatch with input_ids {input_ids.shape}. Check hook logic. Skipping batch.")
       continue

    # --- Filter Activations and Tokens based on Attention Mask ---
    cpu_input_ids = input_ids.cpu().numpy() # (batch_size, seq_len)
    cpu_attention_mask = attention_mask.cpu().numpy().astype(bool) # (batch_size, seq_len)

    # Flatten batch and sequence dimensions for filtering
    flat_activations = batch_activations.view(-1, ACTIVATION_DIM) # (batch*seq, hidden)
    flat_input_ids = cpu_input_ids.flatten() # (batch*seq,)
    flat_attention_mask = cpu_attention_mask.flatten() # (batch*seq,)

    # Select only items where attention_mask is True
    filtered_activations = flat_activations[flat_attention_mask]
    filtered_tokens = flat_input_ids[flat_attention_mask]

    # Store filtered tokens
    if filtered_activations.shape[0] > 0: # Only store if there are non-padding tokens
        activation_storage_filtered.append(filtered_activations)
        newly_added_tokens = filtered_tokens.tolist()
        token_storage.extend(newly_added_tokens)

        # --- Update Progress ---
        num_processed_tokens += len(newly_added_tokens)
        pbar.update(len(newly_added_tokens))

pbar.close()
hook_handle.remove() # Remove the hook

# --- Concatenate FINAL filtered activations ---
if not activation_storage_filtered:
    raise ValueError("No activations were extracted/stored after filtering. Check data or hook.")

all_activations = torch.cat(activation_storage_filtered, dim=0)
print(f"Total ALIGNED activations extracted: {all_activations.shape[0]}")
print(f"Total ALIGNED tokens stored: {len(token_storage)}")
print(f"Activation tensor final shape: {all_activations.shape}") # Should be (N, ACTIVATION_DIM)
print(f"Token storage final length: {len(token_storage)}")      # N should match above

# Check if the lengths match (they should now)
if all_activations.shape[0] != len(token_storage):
    print(f"CRITICAL WARNING: Mismatch between activation count ({all_activations.shape[0]}) and token count ({len(token_storage)}) AFTER filtering!")
else:
    print("Activation and token counts match. Alignment successful.")

# Normalize activations for SAE training
act_mean = torch.mean(all_activations, dim=0, keepdim=True)
act_norms = torch.norm(all_activations, dim=1, keepdim=True) + 1e-6 # Add epsilon for stability
all_activations = (all_activations - act_mean) / act_norms
print("Filtered activations normalized (unit norm per vector).")

# Convert to PyTorch DataLoader for training
activation_dataset = torch.utils.data.TensorDataset(all_activations)
activation_loader = torch.utils.data.DataLoader(activation_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Clean up memory
del activation_storage # Delete the temporary batch storage list
del activation_storage_filtered
import gc
gc.collect()
if device == torch.device("cuda"):
    torch.cuda.empty_cache()

Registering hook on layer: transformer.h.5.attn.c_proj
Loading dataset: wikitext (wikitext-103-raw-v1)
Processing dataset and extracting activations (Aligning Tokens and Activations)...


Extracting Aligned Activations/Tokens:   0%|          | 0/2000000 [00:00<?, ?it/s]

Total ALIGNED activations extracted: 2000012
Total ALIGNED tokens stored: 2000012
Activation tensor final shape: torch.Size([2000012, 768])
Token storage final length: 2000012
Activation and token counts match. Alignment successful.
Filtered activations normalized (unit norm per vector).


# Training

In [10]:
class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=True)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=True)

        self.relu = nn.ReLU()
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x):
        encoded = self.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return encoded, decoded

    def encode(self, x):
      return self.relu(self.encoder(x))

# Initialize SAE
print(f"Initializing SAE with ACTIVATION_DIM={ACTIVATION_DIM}, SAE_HIDDEN_DIM={SAE_HIDDEN_DIM}")
sae_model = SparseAutoencoder(ACTIVATION_DIM, SAE_HIDDEN_DIM).to(device)
print("SAE Model:")
print(sae_model)

Initializing SAE with ACTIVATION_DIM=768, SAE_HIDDEN_DIM=3072
SAE Model:
SparseAutoencoder(
  (encoder): Linear(in_features=768, out_features=3072, bias=True)
  (decoder): Linear(in_features=3072, out_features=768, bias=True)
  (relu): ReLU()
)


In [11]:
def sae_loss(original_activations, encoded_activations, decoded_activations, l1_coeff):
    """Calculates SAE loss: MSE + L1 penalty on encoded activations."""
    mse_loss = nn.functional.mse_loss(decoded_activations, original_activations)

    l1_loss = torch.mean(torch.abs(encoded_activations))

    total_loss = mse_loss + l1_coeff * l1_loss
    return total_loss, mse_loss, l1_loss

optimizer = optim.Adam(sae_model.parameters(), lr=LEARNING_RATE)

In [12]:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"runs/{MODEL_NAME}_{LAYER_NAME}_sae_training_logs_{timestamp}" # You can customize this path
os.makedirs(log_dir, exist_ok=True) # Create the directory if it doesn't exist
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print("Run 'tensorboard --logdir runs' (or your specific log_dir parent) to view logs.")

TensorBoard logs will be saved to: runs/gpt2_transformer.h.5.attn.c_proj_sae_training_logs_20250416-160135
Run 'tensorboard --logdir runs' (or your specific log_dir parent) to view logs.


In [13]:
print("Starting SAE training...")
sae_model.train()
training_losses = {'total': [], 'mse': [], 'l1': [], 'l0': []} # Track losses and L0

for epoch in range(NUM_EPOCHS):
    epoch_total_loss = 0
    epoch_mse_loss = 0
    epoch_l1_loss = 0
    epoch_l0_norm = 0 # Track average number of non-zero features

    pbar_train = tqdm(activation_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
    for i, batch in enumerate(pbar_train):
        # Batch is a list containing one tensor: the activations
        original_acts = batch[0].to(device)

        encoded, decoded = sae_model(original_acts)
        loss, mse, l1 = sae_loss(original_acts, encoded, decoded, L1_COEFF)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        epoch_total_loss += loss.item()
        epoch_mse_loss += mse.item()
        epoch_l1_loss += l1.item()
        
        l0 = torch.sum(encoded > 1e-6, dim=1).float().mean().item()
        epoch_l0_norm += l0

        if (i + 1) % PRINT_INTERVAL == 0 or i == len(activation_loader) - 1:
             pbar_train.set_postfix({
                 "Total Loss": f"{loss.item():.4f}",
                 "MSE": f"{mse.item():.4f}",
                 "L1": f"{l1.item():.4f}",
                 "L0": f"{l0:.2f}" # Avg active features in current batch
             })

    # Average losses for the epoch
    avg_total_loss = epoch_total_loss / len(activation_loader)
    avg_mse_loss = epoch_mse_loss / len(activation_loader)
    avg_l1_loss = epoch_l1_loss / len(activation_loader)
    avg_l0_norm = epoch_l0_norm / len(activation_loader)

    training_losses['total'].append(avg_total_loss)
    training_losses['mse'].append(avg_mse_loss)
    training_losses['l1'].append(avg_l1_loss)
    training_losses['l0'].append(avg_l0_norm)

    writer.add_scalar('Loss/Epoch/Total', avg_total_loss, epoch)
    writer.add_scalar('Loss/Epoch/MSE', avg_mse_loss, epoch)
    writer.add_scalar('Loss/Epoch/L1', avg_l1_loss, epoch)
    writer.add_scalar('L0_Norm/Epoch', avg_l0_norm, epoch)


    print(f"Epoch {epoch+1} Complete: Avg Total Loss={avg_total_loss:.12f}, Avg MSE={avg_mse_loss:.12f}, Avg L1={avg_l1_loss:.7f}, Avg L0={avg_l0_norm:.2f}")

print("SAE Training finished.")

Starting SAE training...


Epoch 1/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 1 Complete: Avg Total Loss=0.000078709327, Avg MSE=0.000029620076, Avg L1=0.0061362, Avg L0=1432.10


Epoch 2/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 2 Complete: Avg Total Loss=0.000035934889, Avg MSE=0.000008642990, Avg L1=0.0034115, Avg L0=1227.31


Epoch 3/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 3 Complete: Avg Total Loss=0.000026363259, Avg MSE=0.000005990719, Avg L1=0.0025466, Avg L0=1123.83


Epoch 4/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 4 Complete: Avg Total Loss=0.000021334818, Avg MSE=0.000004753065, Avg L1=0.0020727, Avg L0=1048.71


Epoch 5/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 5 Complete: Avg Total Loss=0.000018348433, Avg MSE=0.000004090308, Avg L1=0.0017823, Avg L0=989.33


Epoch 6/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 6 Complete: Avg Total Loss=0.000016454921, Avg MSE=0.000003717888, Avg L1=0.0015921, Avg L0=943.92


Epoch 7/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 7 Complete: Avg Total Loss=0.000015123904, Avg MSE=0.000003476396, Avg L1=0.0014559, Avg L0=904.22


Epoch 8/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 8 Complete: Avg Total Loss=0.000014161098, Avg MSE=0.000003330113, Avg L1=0.0013539, Avg L0=877.34


Epoch 9/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 9 Complete: Avg Total Loss=0.000013438179, Avg MSE=0.000003233151, Avg L1=0.0012756, Avg L0=855.07


Epoch 10/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 10 Complete: Avg Total Loss=0.000012889892, Avg MSE=0.000003182852, Avg L1=0.0012134, Avg L0=837.98


Epoch 11/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 11 Complete: Avg Total Loss=0.000012463034, Avg MSE=0.000003164895, Avg L1=0.0011623, Avg L0=825.25


Epoch 12/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 12 Complete: Avg Total Loss=0.000012094407, Avg MSE=0.000003134300, Avg L1=0.0011200, Avg L0=810.54


Epoch 13/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 13 Complete: Avg Total Loss=0.000011788022, Avg MSE=0.000003114248, Avg L1=0.0010842, Avg L0=797.79


Epoch 14/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 14 Complete: Avg Total Loss=0.000011520189, Avg MSE=0.000003095162, Avg L1=0.0010531, Avg L0=787.70


Epoch 15/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 15 Complete: Avg Total Loss=0.000011296679, Avg MSE=0.000003090910, Avg L1=0.0010257, Avg L0=777.49


Epoch 16/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 16 Complete: Avg Total Loss=0.000011096193, Avg MSE=0.000003084989, Avg L1=0.0010014, Avg L0=771.35


Epoch 17/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 17 Complete: Avg Total Loss=0.000010921867, Avg MSE=0.000003082106, Avg L1=0.0009800, Avg L0=765.28


Epoch 18/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 18 Complete: Avg Total Loss=0.000010777486, Avg MSE=0.000003092849, Avg L1=0.0009606, Avg L0=761.33


Epoch 19/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 19 Complete: Avg Total Loss=0.000010644011, Avg MSE=0.000003097947, Avg L1=0.0009433, Avg L0=757.53


Epoch 20/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 20 Complete: Avg Total Loss=0.000010530567, Avg MSE=0.000003107360, Avg L1=0.0009279, Avg L0=752.82


Epoch 21/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 21 Complete: Avg Total Loss=0.000010428439, Avg MSE=0.000003116278, Avg L1=0.0009140, Avg L0=748.42


Epoch 22/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 22 Complete: Avg Total Loss=0.000010332330, Avg MSE=0.000003121209, Avg L1=0.0009014, Avg L0=744.45


Epoch 23/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 23 Complete: Avg Total Loss=0.000010244746, Avg MSE=0.000003125090, Avg L1=0.0008900, Avg L0=740.03


Epoch 24/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 24 Complete: Avg Total Loss=0.000010158841, Avg MSE=0.000003123699, Avg L1=0.0008794, Avg L0=736.69


Epoch 25/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 25 Complete: Avg Total Loss=0.000010094282, Avg MSE=0.000003137186, Avg L1=0.0008696, Avg L0=733.67


Epoch 26/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 26 Complete: Avg Total Loss=0.000010028765, Avg MSE=0.000003143419, Avg L1=0.0008607, Avg L0=730.93


Epoch 27/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 27 Complete: Avg Total Loss=0.000009951632, Avg MSE=0.000003133583, Avg L1=0.0008523, Avg L0=729.02


Epoch 28/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 28 Complete: Avg Total Loss=0.000009910145, Avg MSE=0.000003157156, Avg L1=0.0008441, Avg L0=727.55


Epoch 29/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 29 Complete: Avg Total Loss=0.000009845898, Avg MSE=0.000003152292, Avg L1=0.0008367, Avg L0=725.55


Epoch 30/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 30 Complete: Avg Total Loss=0.000009802183, Avg MSE=0.000003164130, Avg L1=0.0008298, Avg L0=723.66


Epoch 31/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 31 Complete: Avg Total Loss=0.000009749954, Avg MSE=0.000003163791, Avg L1=0.0008233, Avg L0=722.47


Epoch 32/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 32 Complete: Avg Total Loss=0.000009705634, Avg MSE=0.000003167659, Avg L1=0.0008172, Avg L0=721.02


Epoch 33/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 33 Complete: Avg Total Loss=0.000009680445, Avg MSE=0.000003187978, Avg L1=0.0008116, Avg L0=719.56


Epoch 34/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 34 Complete: Avg Total Loss=0.000009628917, Avg MSE=0.000003177951, Avg L1=0.0008064, Avg L0=717.80


Epoch 35/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 35 Complete: Avg Total Loss=0.000009601387, Avg MSE=0.000003190005, Avg L1=0.0008014, Avg L0=716.05


Epoch 36/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 36 Complete: Avg Total Loss=0.000009562624, Avg MSE=0.000003187751, Avg L1=0.0007969, Avg L0=714.22


Epoch 37/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 37 Complete: Avg Total Loss=0.000009520461, Avg MSE=0.000003180114, Avg L1=0.0007925, Avg L0=712.29


Epoch 38/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 38 Complete: Avg Total Loss=0.000009498475, Avg MSE=0.000003191390, Avg L1=0.0007884, Avg L0=710.71


Epoch 39/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 39 Complete: Avg Total Loss=0.000009467127, Avg MSE=0.000003191652, Avg L1=0.0007844, Avg L0=709.52


Epoch 40/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 40 Complete: Avg Total Loss=0.000009436073, Avg MSE=0.000003190502, Avg L1=0.0007807, Avg L0=708.30


Epoch 41/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 41 Complete: Avg Total Loss=0.000009411194, Avg MSE=0.000003193926, Avg L1=0.0007772, Avg L0=707.12


Epoch 42/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 42 Complete: Avg Total Loss=0.000009379498, Avg MSE=0.000003189544, Avg L1=0.0007737, Avg L0=705.94


Epoch 43/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 43 Complete: Avg Total Loss=0.000009363981, Avg MSE=0.000003199710, Avg L1=0.0007705, Avg L0=704.70


Epoch 44/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 44 Complete: Avg Total Loss=0.000009342489, Avg MSE=0.000003203010, Avg L1=0.0007674, Avg L0=703.71


Epoch 45/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 45 Complete: Avg Total Loss=0.000009314295, Avg MSE=0.000003198106, Avg L1=0.0007645, Avg L0=702.70


Epoch 46/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 46 Complete: Avg Total Loss=0.000009294095, Avg MSE=0.000003201415, Avg L1=0.0007616, Avg L0=702.01


Epoch 47/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 47 Complete: Avg Total Loss=0.000009261239, Avg MSE=0.000003189984, Avg L1=0.0007589, Avg L0=701.27


Epoch 48/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 48 Complete: Avg Total Loss=0.000009253034, Avg MSE=0.000003202592, Avg L1=0.0007563, Avg L0=700.64


Epoch 49/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 49 Complete: Avg Total Loss=0.000009236217, Avg MSE=0.000003205186, Avg L1=0.0007539, Avg L0=699.73


Epoch 50/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 50 Complete: Avg Total Loss=0.000009207867, Avg MSE=0.000003196118, Avg L1=0.0007515, Avg L0=699.18


Epoch 51/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 51 Complete: Avg Total Loss=0.000009193249, Avg MSE=0.000003200185, Avg L1=0.0007491, Avg L0=698.78


Epoch 52/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 52 Complete: Avg Total Loss=0.000009174514, Avg MSE=0.000003199024, Avg L1=0.0007469, Avg L0=698.25


Epoch 53/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 53 Complete: Avg Total Loss=0.000009161954, Avg MSE=0.000003203505, Avg L1=0.0007448, Avg L0=697.39


Epoch 54/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 54 Complete: Avg Total Loss=0.000009143444, Avg MSE=0.000003201250, Avg L1=0.0007428, Avg L0=696.46


Epoch 55/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 55 Complete: Avg Total Loss=0.000009125115, Avg MSE=0.000003198231, Avg L1=0.0007409, Avg L0=695.92


Epoch 56/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 56 Complete: Avg Total Loss=0.000009106121, Avg MSE=0.000003194361, Avg L1=0.0007390, Avg L0=695.56


Epoch 57/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 57 Complete: Avg Total Loss=0.000009099192, Avg MSE=0.000003202201, Avg L1=0.0007371, Avg L0=695.19


Epoch 58/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 58 Complete: Avg Total Loss=0.000009083834, Avg MSE=0.000003201522, Avg L1=0.0007353, Avg L0=694.79


Epoch 59/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 59 Complete: Avg Total Loss=0.000009062914, Avg MSE=0.000003194611, Avg L1=0.0007335, Avg L0=694.27


Epoch 60/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 60 Complete: Avg Total Loss=0.000009053387, Avg MSE=0.000003197998, Avg L1=0.0007319, Avg L0=693.62


Epoch 61/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 61 Complete: Avg Total Loss=0.000009041890, Avg MSE=0.000003199083, Avg L1=0.0007304, Avg L0=693.23


Epoch 62/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 62 Complete: Avg Total Loss=0.000009020982, Avg MSE=0.000003189839, Avg L1=0.0007289, Avg L0=692.71


Epoch 63/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 63 Complete: Avg Total Loss=0.000009016272, Avg MSE=0.000003196655, Avg L1=0.0007275, Avg L0=692.17


Epoch 64/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 64 Complete: Avg Total Loss=0.000009011759, Avg MSE=0.000003203216, Avg L1=0.0007261, Avg L0=691.83


Epoch 65/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 65 Complete: Avg Total Loss=0.000008986494, Avg MSE=0.000003188496, Avg L1=0.0007247, Avg L0=691.50


Epoch 66/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 66 Complete: Avg Total Loss=0.000008981171, Avg MSE=0.000003193777, Avg L1=0.0007234, Avg L0=691.15


Epoch 67/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 67 Complete: Avg Total Loss=0.000008964810, Avg MSE=0.000003187277, Avg L1=0.0007222, Avg L0=690.83


Epoch 68/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 68 Complete: Avg Total Loss=0.000008961979, Avg MSE=0.000003194481, Avg L1=0.0007209, Avg L0=690.39


Epoch 69/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 69 Complete: Avg Total Loss=0.000008953419, Avg MSE=0.000003195605, Avg L1=0.0007197, Avg L0=689.95


Epoch 70/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 70 Complete: Avg Total Loss=0.000008938746, Avg MSE=0.000003190231, Avg L1=0.0007186, Avg L0=689.44


Epoch 71/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 71 Complete: Avg Total Loss=0.000008929129, Avg MSE=0.000003189276, Avg L1=0.0007175, Avg L0=689.15


Epoch 72/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 72 Complete: Avg Total Loss=0.000008928891, Avg MSE=0.000003197934, Avg L1=0.0007164, Avg L0=688.82


Epoch 73/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 73 Complete: Avg Total Loss=0.000008910400, Avg MSE=0.000003187337, Avg L1=0.0007154, Avg L0=688.59


Epoch 74/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 74 Complete: Avg Total Loss=0.000008898994, Avg MSE=0.000003183530, Avg L1=0.0007144, Avg L0=688.30


Epoch 75/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 75 Complete: Avg Total Loss=0.000008897119, Avg MSE=0.000003189608, Avg L1=0.0007134, Avg L0=688.03


Epoch 76/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 76 Complete: Avg Total Loss=0.000008879182, Avg MSE=0.000003179137, Avg L1=0.0007125, Avg L0=687.78


Epoch 77/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 77 Complete: Avg Total Loss=0.000008879621, Avg MSE=0.000003186970, Avg L1=0.0007116, Avg L0=687.48


Epoch 78/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 78 Complete: Avg Total Loss=0.000008863599, Avg MSE=0.000003177975, Avg L1=0.0007107, Avg L0=687.23


Epoch 79/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 79 Complete: Avg Total Loss=0.000008862283, Avg MSE=0.000003183681, Avg L1=0.0007098, Avg L0=686.87


Epoch 80/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 80 Complete: Avg Total Loss=0.000008852982, Avg MSE=0.000003180665, Avg L1=0.0007090, Avg L0=686.40


Epoch 81/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 81 Complete: Avg Total Loss=0.000008850766, Avg MSE=0.000003185099, Avg L1=0.0007082, Avg L0=686.12


Epoch 82/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 82 Complete: Avg Total Loss=0.000008841060, Avg MSE=0.000003181426, Avg L1=0.0007075, Avg L0=685.87


Epoch 83/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 83 Complete: Avg Total Loss=0.000008827536, Avg MSE=0.000003173873, Avg L1=0.0007067, Avg L0=685.66


Epoch 84/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 84 Complete: Avg Total Loss=0.000008821864, Avg MSE=0.000003174084, Avg L1=0.0007060, Avg L0=685.42


Epoch 85/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 85 Complete: Avg Total Loss=0.000008823360, Avg MSE=0.000003181568, Avg L1=0.0007052, Avg L0=685.13


Epoch 86/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 86 Complete: Avg Total Loss=0.000008804752, Avg MSE=0.000003168686, Avg L1=0.0007045, Avg L0=684.96


Epoch 87/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 87 Complete: Avg Total Loss=0.000008812087, Avg MSE=0.000003181727, Avg L1=0.0007038, Avg L0=684.71


Epoch 88/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 88 Complete: Avg Total Loss=0.000008791279, Avg MSE=0.000003166127, Avg L1=0.0007031, Avg L0=684.57


Epoch 89/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 89 Complete: Avg Total Loss=0.000008796979, Avg MSE=0.000003177302, Avg L1=0.0007025, Avg L0=684.28


Epoch 90/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 90 Complete: Avg Total Loss=0.000008787337, Avg MSE=0.000003172578, Avg L1=0.0007018, Avg L0=684.11


Epoch 91/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 91 Complete: Avg Total Loss=0.000008769048, Avg MSE=0.000003159063, Avg L1=0.0007012, Avg L0=683.97


Epoch 92/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 92 Complete: Avg Total Loss=0.000008780229, Avg MSE=0.000003175451, Avg L1=0.0007006, Avg L0=683.68


Epoch 93/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 93 Complete: Avg Total Loss=0.000008769996, Avg MSE=0.000003169978, Avg L1=0.0007000, Avg L0=683.52


Epoch 94/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 94 Complete: Avg Total Loss=0.000008753009, Avg MSE=0.000003157508, Avg L1=0.0006994, Avg L0=683.32


Epoch 95/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 95 Complete: Avg Total Loss=0.000008757349, Avg MSE=0.000003166586, Avg L1=0.0006988, Avg L0=683.11


Epoch 96/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 96 Complete: Avg Total Loss=0.000008752147, Avg MSE=0.000003165919, Avg L1=0.0006983, Avg L0=682.89


Epoch 97/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 97 Complete: Avg Total Loss=0.000008746881, Avg MSE=0.000003164999, Avg L1=0.0006977, Avg L0=682.65


Epoch 98/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 98 Complete: Avg Total Loss=0.000008744671, Avg MSE=0.000003166989, Avg L1=0.0006972, Avg L0=682.44


Epoch 99/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 99 Complete: Avg Total Loss=0.000008733357, Avg MSE=0.000003159731, Avg L1=0.0006967, Avg L0=682.24


Epoch 100/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 100 Complete: Avg Total Loss=0.000008726504, Avg MSE=0.000003157182, Avg L1=0.0006962, Avg L0=681.94


Epoch 101/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 101 Complete: Avg Total Loss=0.000008726541, Avg MSE=0.000003161096, Avg L1=0.0006957, Avg L0=681.61


Epoch 102/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 102 Complete: Avg Total Loss=0.000008717104, Avg MSE=0.000003155522, Avg L1=0.0006952, Avg L0=681.33


Epoch 103/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 103 Complete: Avg Total Loss=0.000008710137, Avg MSE=0.000003152316, Avg L1=0.0006947, Avg L0=681.22


Epoch 104/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 104 Complete: Avg Total Loss=0.000008712326, Avg MSE=0.000003158497, Avg L1=0.0006942, Avg L0=681.06


Epoch 105/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 105 Complete: Avg Total Loss=0.000008702360, Avg MSE=0.000003152138, Avg L1=0.0006938, Avg L0=680.93


Epoch 106/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 106 Complete: Avg Total Loss=0.000008693484, Avg MSE=0.000003147039, Avg L1=0.0006933, Avg L0=680.80


Epoch 107/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 107 Complete: Avg Total Loss=0.000008695218, Avg MSE=0.000003152541, Avg L1=0.0006928, Avg L0=680.63


Epoch 108/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 108 Complete: Avg Total Loss=0.000008691925, Avg MSE=0.000003152604, Avg L1=0.0006924, Avg L0=680.47


Epoch 109/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 109 Complete: Avg Total Loss=0.000008682295, Avg MSE=0.000003146423, Avg L1=0.0006920, Avg L0=680.40


Epoch 110/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 110 Complete: Avg Total Loss=0.000008684078, Avg MSE=0.000003151637, Avg L1=0.0006916, Avg L0=680.25


Epoch 111/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 111 Complete: Avg Total Loss=0.000008665998, Avg MSE=0.000003136751, Avg L1=0.0006912, Avg L0=680.15


Epoch 112/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 112 Complete: Avg Total Loss=0.000008680747, Avg MSE=0.000003155153, Avg L1=0.0006907, Avg L0=679.91


Epoch 113/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 113 Complete: Avg Total Loss=0.000008669489, Avg MSE=0.000003146913, Avg L1=0.0006903, Avg L0=679.82


Epoch 114/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 114 Complete: Avg Total Loss=0.000008658617, Avg MSE=0.000003138503, Avg L1=0.0006900, Avg L0=679.86


Epoch 115/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 115 Complete: Avg Total Loss=0.000008660606, Avg MSE=0.000003143426, Avg L1=0.0006896, Avg L0=680.02


Epoch 116/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 116 Complete: Avg Total Loss=0.000008659912, Avg MSE=0.000003146127, Avg L1=0.0006892, Avg L0=679.75


Epoch 117/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 117 Complete: Avg Total Loss=0.000008646249, Avg MSE=0.000003135160, Avg L1=0.0006889, Avg L0=679.64


Epoch 118/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 118 Complete: Avg Total Loss=0.000008641429, Avg MSE=0.000003133463, Avg L1=0.0006885, Avg L0=679.52


Epoch 119/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 119 Complete: Avg Total Loss=0.000008648241, Avg MSE=0.000003143500, Avg L1=0.0006881, Avg L0=679.32


Epoch 120/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 120 Complete: Avg Total Loss=0.000008632201, Avg MSE=0.000003130219, Avg L1=0.0006877, Avg L0=679.26


Epoch 121/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 121 Complete: Avg Total Loss=0.000008642258, Avg MSE=0.000003143399, Avg L1=0.0006874, Avg L0=679.02


Epoch 122/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 122 Complete: Avg Total Loss=0.000008627846, Avg MSE=0.000003131695, Avg L1=0.0006870, Avg L0=678.92


Epoch 123/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 123 Complete: Avg Total Loss=0.000008631279, Avg MSE=0.000003137835, Avg L1=0.0006867, Avg L0=678.77


Epoch 124/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 124 Complete: Avg Total Loss=0.000008621965, Avg MSE=0.000003131176, Avg L1=0.0006863, Avg L0=678.67


Epoch 125/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 125 Complete: Avg Total Loss=0.000008624501, Avg MSE=0.000003136496, Avg L1=0.0006860, Avg L0=678.54


Epoch 126/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 126 Complete: Avg Total Loss=0.000008613019, Avg MSE=0.000003127378, Avg L1=0.0006857, Avg L0=678.46


Epoch 127/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 127 Complete: Avg Total Loss=0.000008611123, Avg MSE=0.000003128220, Avg L1=0.0006854, Avg L0=678.30


Epoch 128/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 128 Complete: Avg Total Loss=0.000008613224, Avg MSE=0.000003132906, Avg L1=0.0006850, Avg L0=678.19


Epoch 129/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 129 Complete: Avg Total Loss=0.000008602778, Avg MSE=0.000003124854, Avg L1=0.0006847, Avg L0=678.14


Epoch 130/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 130 Complete: Avg Total Loss=0.000008610989, Avg MSE=0.000003135732, Avg L1=0.0006844, Avg L0=677.99


Epoch 131/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 131 Complete: Avg Total Loss=0.000008606616, Avg MSE=0.000003133873, Avg L1=0.0006841, Avg L0=677.88


Epoch 132/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 132 Complete: Avg Total Loss=0.000008587617, Avg MSE=0.000003117228, Avg L1=0.0006838, Avg L0=677.83


Epoch 133/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 133 Complete: Avg Total Loss=0.000008596486, Avg MSE=0.000003129215, Avg L1=0.0006834, Avg L0=677.69


Epoch 134/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 134 Complete: Avg Total Loss=0.000008589962, Avg MSE=0.000003125932, Avg L1=0.0006830, Avg L0=677.60


Epoch 135/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 135 Complete: Avg Total Loss=0.000008574712, Avg MSE=0.000003113325, Avg L1=0.0006827, Avg L0=677.52


Epoch 136/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 136 Complete: Avg Total Loss=0.000008578503, Avg MSE=0.000003119935, Avg L1=0.0006823, Avg L0=677.40


Epoch 137/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 137 Complete: Avg Total Loss=0.000008579620, Avg MSE=0.000003123474, Avg L1=0.0006820, Avg L0=677.28


Epoch 138/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 138 Complete: Avg Total Loss=0.000008569308, Avg MSE=0.000003115941, Avg L1=0.0006817, Avg L0=677.13


Epoch 139/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 139 Complete: Avg Total Loss=0.000008563737, Avg MSE=0.000003112330, Avg L1=0.0006814, Avg L0=677.12


Epoch 140/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 140 Complete: Avg Total Loss=0.000008568938, Avg MSE=0.000003120149, Avg L1=0.0006811, Avg L0=676.95


Epoch 141/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 141 Complete: Avg Total Loss=0.000008557281, Avg MSE=0.000003110718, Avg L1=0.0006808, Avg L0=676.88


Epoch 142/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 142 Complete: Avg Total Loss=0.000008554187, Avg MSE=0.000003109782, Avg L1=0.0006806, Avg L0=676.83


Epoch 143/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 143 Complete: Avg Total Loss=0.000008560175, Avg MSE=0.000003118106, Avg L1=0.0006803, Avg L0=676.74


Epoch 144/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 144 Complete: Avg Total Loss=0.000008546610, Avg MSE=0.000003106616, Avg L1=0.0006800, Avg L0=676.68


Epoch 145/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 145 Complete: Avg Total Loss=0.000008544620, Avg MSE=0.000003106944, Avg L1=0.0006797, Avg L0=676.59


Epoch 146/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 146 Complete: Avg Total Loss=0.000008543373, Avg MSE=0.000003107956, Avg L1=0.0006794, Avg L0=676.49


Epoch 147/300:   0%|          | 0/7813 [00:00<?, ?it/s]

Epoch 147 Complete: Avg Total Loss=0.000008544741, Avg MSE=0.000003111602, Avg L1=0.0006791, Avg L0=676.39


Epoch 148/300:   0%|          | 0/7813 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Load model if you want to avoid training
# model_name = "sae_gpt2_layer6_mlp_1.pth"
# sae_model.load_state_dict(torch.load(model_name, map_location=device))

# Evaluation

In [14]:
sae_model.eval()

print("Recomputing SAE features for analysis...")
with torch.no_grad():
    all_sae_features = []
    analysis_batch_size = 128
    for i in range(0, all_activations.shape[0], analysis_batch_size):
        batch_acts = all_activations[i:i+analysis_batch_size].to(device)
        batch_features = sae_model.encode(batch_acts).cpu()
        all_sae_features.append(batch_features)

    all_sae_features = torch.cat(all_sae_features, dim=0)

print(f"SAE features shape: {all_sae_features.shape}")

def get_max_activating_examples(feature_index, num_examples=10, window_size=10):
    """Find text snippets where a specific SAE feature maximally activates."""
    if feature_index < 0 or feature_index >= SAE_HIDDEN_DIM:
        print(f"Error: Feature index {feature_index} out of bounds (0-{SAE_HIDDEN_DIM-1}).")
        return
    if all_sae_features.shape[0] == 0:
        print("Error: No SAE features computed.")
        return
    if not token_storage:
        print("Error: Token storage is empty.")
        return

    print(f"\n--- Finding top activating examples for Feature {feature_index} ---")

    feature_activations = all_sae_features[:, feature_index]

    top_k_values, top_k_indices = torch.topk(feature_activations, k=min(num_examples, all_sae_features.shape[0]), largest=True) # Handle cases with fewer activations than requested

    analysis_output = []
    
    for rank, (idx_tensor, activation_value_tensor) in enumerate(zip(top_k_indices, top_k_values)):
        token_dict = {}
        activation_value = activation_value_tensor.item()
        global_token_idx = idx_tensor.item()

        if global_token_idx < 0 or global_token_idx >= len(token_storage):
             print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
             print(f"Context: Error - Calculated token index {global_token_idx} is out of bounds for token_storage (len={len(token_storage)}). Skipping.")
             continue
        
        start = max(0, global_token_idx - window_size)
        end = min(len(token_storage), global_token_idx + window_size + 1)

        if start >= end:
             print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
             print(f"Context: Error - Invalid window slice indices (start={start}, end={end}). Skipping.")
             continue

        token_ids_window = token_storage[start:end]

        try:
             context_text = tokenizer.decode(token_ids_window, skip_special_tokens=True, errors='replace')

             activating_token_id = token_storage[global_token_idx]
             activating_token_str = tokenizer.decode([activating_token_id], skip_special_tokens=False, errors='replace')

             token_dict['token'] = activating_token_str
             token_dict['context'] = context_text
             token_dict['activation'] = activation_value

        except IndexError:
             print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
             print(f"Context: Error decoding tokens - IndexError occurred with indices ({start}, {end}) for token_storage (len={len(token_storage)})")
        except Exception as e:
            print(f"\nRank {rank+1}: Activation = {activation_value:.4f}")
            print(f"Context: Error decoding tokens: {token_ids_window} - {e}")
        analysis_output.append(token_dict)
    return analysis_output

Recomputing SAE features for analysis...
SAE features shape: torch.Size([2000012, 3072])


In [15]:
if 'all_sae_features' not in locals() or all_sae_features.shape[0] == 0:
    print("ERROR: `all_sae_features` tensor not found or is empty. Please run the previous cell first.")
else:
    print(f"Calculating statistics across {all_sae_features.shape[1]} features...")

    # --- Calculate Statistics Per Feature ---
    try:
        feature_max_activations, _ = torch.max(all_sae_features, dim=0)

        feature_mean_activations = torch.mean(all_sae_features, dim=0)

        feature_std_devs = torch.std(all_sae_features, dim=0)

        active_feature_mask = feature_max_activations > ACTIVITY_THRESHOLD # FILTER!!
        active_feature_indices = torch.where(active_feature_mask)[0]

        print(f"Found {len(active_feature_indices)} active features (max > {ACTIVITY_THRESHOLD}).")

        top_max_k_values, top_max_k_indices = torch.topk(
            feature_max_activations,
            k=min(NUM_TOP_FEATURES_TO_ANALYZE, all_sae_features.shape[1])
        )
        top_max_activation_indices = top_max_k_indices.tolist()
        print(f"\nTop {len(top_max_activation_indices)} features by MAX activation selected.")

        # --- Clean up intermediate tensors ---
        del feature_max_activations, feature_mean_activations, feature_std_devs
        del active_feature_mask, active_feature_indices
        gc.collect()


    except Exception as e:
        print(f"An error occurred during statistics calculation: {e}")
        import traceback
        traceback.print_exc()
        top_max_activation_indices = []
        top_std_dev_indices = []
        top_mean_activation_indices = []

Calculating statistics across 3072 features...
Found 1686 active features (max > 1e-06).

Top 50 features by MAX activation selected.


In [None]:
# Option 1: Maximum Activation (Recommended for seeing peak usage)
feature_strengths_np = torch.max(all_sae_features, dim=0)[0].numpy()
metric_label = "Max Activation"

# # Option 2: Mean Activation
# feature_strengths = torch.mean(all_sae_features, dim=0)
# metric_label = "Mean Activation"

# # Option 3: Activation Frequency (Approx L0 Norm)
# feature_activity = (all_sae_features > ACTIVITY_THRESHOLD).float() # ACTIVITY_THRESHOLD from config
# feature_strengths = torch.mean(feature_activity, dim=0)
# metric_label = f"Activation Frequency (>{ACTIVITY_THRESHOLD})"
# del feature_activity # Free up memory

print(f"Calculated feature strengths using: {metric_label}")
num_features = SAE_HIDDEN_DIM

def get_grid_dims(n):
    """Finds reasonably square grid dimensions for n items."""
    if n <= 0: return (0,0)
    sqrt_n = math.isqrt(n)
    rows = sqrt_n
    while n % rows != 0:
        rows -= 1
        if rows == 0:
             rows = int(sqrt_n)
             while n % rows != 0:
                 rows +=1
             break

    cols = n // rows
    return (min(rows, cols), max(rows, cols)) if rows > 0 else (1,n)

# 3072 = 64*48
if 64 * 48 == num_features:
     grid_rows, grid_cols = 48, 64
     print(f"Using specified grid: {grid_rows}x{grid_cols}")
else:
     grid_rows, grid_cols = get_grid_dims(num_features)
     if grid_rows * grid_cols > num_features:
          print(f"Calculated grid: {grid_rows}x{grid_cols}. Padding needed.")
          padding_size = grid_rows * grid_cols - num_features
          feature_strengths_np = np.pad(feature_strengths_np, (0, padding_size), mode='constant', constant_values=np.nan)
     else:
          print(f"Calculated grid: {grid_rows}x{grid_cols}.")


feature_grid = feature_strengths_np.reshape(grid_rows, grid_cols)

# --- Plotting ---

plt.figure(figsize=(12, 10))
plt.imshow(feature_grid, cmap='viridis', aspect='auto') # Alternative without seaborn
plt.colorbar(label=metric_label)
plt.title(f'SAE Feature Strength ({metric_label}) - {grid_rows}x{grid_cols} Grid')
plt.xlabel("Feature Index (Column Offset)")
plt.ylabel("Feature Index (Row Offset)")
plt.tight_layout()
plt.show()

try:
    from mpl_toolkits.mplot3d import Axes3D

    if num_features > 5000:
         print("\nSkipping 3D bar plot: Too many features (>5000), likely unreadable.")
    else:
         print("\nGenerating 3D bar plot (can be slow and cluttered)...")
         fig = plt.figure(figsize=(15, 10))
         ax = fig.add_subplot(111, projection='3d')

         xpos, ypos = np.meshgrid(np.arange(grid_cols), np.arange(grid_rows))
         xpos = xpos.flatten()
         ypos = ypos.flatten()
         zpos = np.zeros_like(xpos)

         dz = feature_grid.flatten()
         valid_bars = ~np.isnan(dz)

         colors = plt.cm.viridis(dz[valid_bars] / np.nanmax(dz))

         ax.bar3d(xpos[valid_bars], ypos[valid_bars], zpos[valid_bars],
                  dx=0.8, dy=0.8, dz=dz[valid_bars],
                  color=colors, # Use variable colors
                  shade=True)

         ax.set_xlabel('Feature Col Index')
         ax.set_ylabel('Feature Row Index')
         ax.set_zlabel(metric_label)
         ax.set_title(f'SAE Feature Strength ({metric_label}) - 3D Bars')
         plt.tight_layout()
         plt.show()

except ImportError:
    print("\nCould not import Axes3D for 3D plot. Skipping.")
except Exception as e:
    print(f"\nAn error occurred during 3D plot generation: {e}")

In [20]:
# Show token and context
result = []
for feature_idx in top_max_activation_indices:
    c = get_max_activating_examples(feature_index=feature_idx, num_examples=10, window_size=10)
    result.append(c)
    for i in c:
        print(f"\n{i['context']}")
        print(i['token'])


--- Finding top activating examples for Feature 3021 ---

 age and with my undistinguished track @-@ record , so I promptly fell in love with her
@

 : though Greaves described the Chelsea title @-@ winning side as " almost certainly one of the least
@

 : Aerosmith was the first band @-@ centric game for the series . On September 4
@

 episode more involved with insinuation and mythology @-@ building than with telling a complete @-@ in
@

 Society for Animal Rights decried the prison @-@ like conditions of the cages and called for changes .
@

 before Michigan 's NIT Season Tip @-@ Off game against the Pittsburgh Panthers . The shirts use
@

 definitive masala film , and a trend @-@ setter for " multi @-@ star "
@

i remarked on the campaign in the run @-@ up to the festival , saying " The 2006 edition
@

t a love child , but a bump @-@ and @-@ grind that never finds a groove
@

 of The Monkees , for the hour @-@ long video Elephant Parts ( also known as Michael N
@

--- Finding top 

# Debug

In [21]:
print("DEBUG: Inspecting sample SAE features...")
sae_model.eval()
sae_model.to(device) # Ensure model is on GPU

# Use a subset of activations to avoid RAM issues if all_activations is huge
num_inspect_samples = min(10000, all_activations.shape[0]) # Inspect up to 10k samples
sample_indices = torch.randperm(all_activations.shape[0])[:num_inspect_samples]

# Keep activations on CPU for this inspection if they fit, otherwise batch
sample_acts_cpu = all_activations[sample_indices]

# Calculate features in batches to avoid OOM on GPU if necessary
batch_size_inspect = 1024
all_sample_features_list = []
with torch.no_grad():
    for i in range(0, num_inspect_samples, batch_size_inspect):
        batch_acts_gpu = sample_acts_cpu[i:i+batch_size_inspect].to(device)
        batch_features_gpu = sae_model.encode(batch_acts_gpu)
        all_sample_features_list.append(batch_features_gpu.cpu()) # Move features to CPU

sample_features = torch.cat(all_sample_features_list, dim=0)

print(f"Sample features shape: {sample_features.shape}") # Should be (num_inspect_samples, SAE_HIDDEN_DIM)
print(f"Sample features min: {sample_features.min():.6f}")
print(f"Sample features max: {sample_features.max():.6f}")
print(f"Sample features mean: {sample_features.mean():.6f}")
print(f"Sample features std: {sample_features.std():.6f}")

# Calculate L0 norm for the sample (more accurate than just min/max/mean)
# Use a small threshold to account for floating point inaccuracies
active_features = (sample_features > 1e-6).float()
avg_l0_per_input = active_features.sum(dim=1).mean().item()
total_active_features = active_features.sum().item()
fraction_active = total_active_features / sample_features.numel()

print(f"\nSample L0 norm (avg active features per input): {avg_l0_per_input:.4f}")
print(f"Total active feature values in sample: {total_active_features}")
print(f"Fraction of active feature values: {fraction_active:.6f}")

# Check proportion of DEAD features (never activate across the sample)
features_never_activated = (active_features.sum(dim=0) == 0).sum().item()
print(f"\nNumber of features that NEVER activated in the sample: {features_never_activated} / {SAE_HIDDEN_DIM}")

del sample_acts_cpu, all_sample_features_list, sample_features, active_features # Cleanup
gc.collect()
if device == torch.device("cuda"):
    torch.cuda.empty_cache()

DEBUG: Inspecting sample SAE features...
Sample features shape: torch.Size([10000, 3072])
Sample features min: 0.000000
Sample features max: 0.100413
Sample features mean: 0.000679
Sample features std: 0.002022

Sample L0 norm (avg active features per input): 675.3610
Total active feature values in sample: 6753610.0
Fraction of active feature values: 0.219844

Number of features that NEVER activated in the sample: 1800 / 3072


In [22]:
sae_model_path = f"{log_dir}/sae_{MODEL_NAME}_{LAYER_NAME}.pth"
torch.save(sae_model.state_dict(), sae_model_path)
print(f"SAE model weights saved to: {sae_model_path}")

SAE model weights saved to: runs/gpt2_transformer.h.5.attn.c_proj_sae_training_logs_20250416-160135/sae_gpt2_transformer.h.5.attn.c_proj.pth


# Create report

In [23]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from datetime import datetime
def create_experiment_report():
    report_path = f"{log_dir}/report.pdf"

    with PdfPages(report_path) as pdf:
        plt.figure(figsize=(11, 8))

        # Title Page
        plt.text(0.5, 0.8, "Sparse Autoencoder Experiment Report", 
                ha='center', va='center', fontsize=20)
        plt.text(0.5, 0.6, f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", 
                ha='center', va='center', fontsize=12)
        plt.text(0.5, 0.5, f"Model: {MODEL_NAME} | Layer: {LAYER_NAME}", 
                ha='center', va='center', fontsize=12)
        plt.axis('off')
        pdf.savefig()
        plt.close()
        # Configuration Summary
        plt.figure(figsize=(11, 8))
        config_text = [
            "=== Configuration Summary ===",
            f"\nModel Architecture:",
            f"- Target Layer: {LAYER_NAME}",
            f"- Activation Dim: {ACTIVATION_DIM}",

            f"\nDataset:",
            f"- Name: {DATASET_NAME} ({DATASET_CONFIG})",
            f"- Samples Stored: {NUM_ACTIVATIONS_TO_STORE:,}",
            f"- Max Sequence Length: {MAX_SEQ_LENGTH}",

            f"\nSAE Architecture:",
            f"- Hidden Dim: {SAE_HIDDEN_DIM}",
            f"- Expansion Factor: {SAE_EXPANSION_FACTOR}",

            f"\nTraining:",
            f"- Epochs: {NUM_EPOCHS}",
            f"- Batch Size: {BATCH_SIZE}",
            f"- Learning Rate: {LEARNING_RATE}",
            f"- L1 Coefficient: {L1_COEFF}"
        ]
        plt.text(0.1, 0.9, "\n".join(config_text), fontsize=10, va='top')
        plt.axis('off')
        pdf.savefig()
        plt.close()
        # Training Curves
        plt.figure(figsize=(10, 6))
        plt.plot(training_losses['total'], label='Total Loss')
        plt.plot(training_losses['mse'], label='MSE Loss')
        plt.plot(training_losses['l1'], label='L1 Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss Curves')
        plt.legend()
        plt.grid(True)
        pdf.savefig()
        plt.close()
        # L0 Norm Plot
        plt.figure(figsize=(10, 6))
        plt.plot(training_losses['l0'], label='Avg Active Features', color='purple')
        plt.xlabel('Epoch')
        plt.ylabel('L0 Norm')
        plt.title('Average Active Features per Sample')
        plt.grid(True)
        pdf.savefig()
        plt.close()
        # Feature Analysis
        plt.figure(figsize=(11, 8))
        analysis_text = [
            "=== Feature Analysis ===",
            f"\nTop {NUM_TOP_FEATURES_TO_ANALYZE} Features by Max Activation:",
        ]

        analysis_text.extend([
            f"\nFeature Statistics:",
            f"- Total Features: {SAE_HIDDEN_DIM}",
            f"- Dead Features: {features_never_activated} ({features_never_activated/SAE_HIDDEN_DIM:.1%})",
            f"- Avg Features Active per Sample: {avg_l0_per_input:.1f}"
        ])
        plt.text(0.1, 0.9, "\n".join(analysis_text), fontsize=10, va='top')
        plt.axis('off')
        pdf.savefig()
        plt.close()
        for activations in result:
            plt.figure(figsize=(11, 8))
            feature_text = "==NEW FEATURE==\n"
            for activation in activations:
                feature_text = feature_text + f"Token: {activation['token']}\n"
            plt.text(0.1, 0.9, feature_text, fontsize=10, va='top')
            plt.axis('off')
            pdf.savefig()
            plt.close()
    print(f"Report generated at: {report_path}")
# Generate the report
create_experiment_report()

# Generate the report
create_experiment_report()

Report generated at: runs/gpt2_transformer.h.5.attn.c_proj_sae_training_logs_20250416-160135/report.pdf
Report generated at: runs/gpt2_transformer.h.5.attn.c_proj_sae_training_logs_20250416-160135/report.pdf
