# Configuration

In [None]:
# --- Transformer Configuration ---
# Selects the language model and the single submodule whose outputs are
# captured (via a forward hook) as training data for the sparse autoencoder.

# GPT-2
# MODEL_NAME = "gpt2"
# We target the output projection of the MLP layer in a specific block
# LAYER_IDX = 1
# LAYER_NAME = f"transformer.h.{LAYER_IDX}.mlp.c_proj" # Output projection of MLP
# LAYER_NAME = f"transformer.h.{LAYER_IDX}.attn.c_proj" # Output projection of Attention
# Alternative: f"transformer.h.{LAYER_IDX}" # Output of the whole block (residual stream)
# ACTIVATION_DIM = 768

MODEL_NAME = "google/gemma-3-1b-it"
LAYER_IDX = 11
# Dotted module path, looked up in model.named_modules() when the hook is registered.
LAYER_NAME = f"model.layers.{LAYER_IDX}.mlp.down_proj"
# Must equal the size of the last dimension of the hooked layer's output:
# captured activations are reshaped with view(-1, ACTIVATION_DIM) and the SAE
# input width is built from this value.
ACTIVATION_DIM = 1152

In [None]:
# --- Dataset Configuration ---
# DATASET_NAME / DATASET_CONFIG / DATASET_SPLIT are passed straight to
# datasets.load_dataset(..., streaming=True).
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-103-raw-v1" # Or "wikitext-2-raw-v1" for smaller
DATASET_SPLIT = "train"
# Soft cap: extraction stops at the first example boundary at or past this
# count, so the final total can exceed it by up to one sequence.
NUM_ACTIVATIONS_TO_STORE = 1600000 # Max activations to collect (adjust based on RAM)
# Every text is truncated/padded to exactly this many tokens before the forward pass.
MAX_SEQ_LENGTH = 128 # Max token length for processing sequences

In [None]:
# --- SAE Configuration ---
SAE_EXPANSION_FACTOR = 4 # How many times larger the SAE hidden dim is than the activation dim
# Number of SAE features (width of the encoder output).
SAE_HIDDEN_DIM = ACTIVATION_DIM * SAE_EXPANSION_FACTOR
# Multiplies the mean absolute encoded activation in sae_loss.
L1_COEFF = 8e-3 # Sparsity penalty strength

In [None]:
# --- Training Configuration ---
BATCH_SIZE = 512        # activation vectors per optimizer step
LEARNING_RATE = 3e-4    # Adam learning rate
NUM_EPOCHS = 300        # full passes over the stored activations
PRINT_INTERVAL = 100    # refresh the progress-bar metrics every N batches

In [None]:
# --- Analysis Configuration ---
# How many features (ranked by their maximum activation) get example look-ups.
NUM_TOP_FEATURES_TO_ANALYZE = 100
# A feature counts as active when its maximum activation exceeds this value.
ACTIVITY_THRESHOLD = 1e-6 # Threshold to consider a feature non-dead

In [None]:
# --- Import tokens ---
# `secret_tokens` is a local module that is not part of this notebook;
# presumably it is kept out of version control -- TODO confirm.
from secret_tokens import access_tokens
# Hugging Face access token, passed as `token=` when loading the model and tokenizer.
token = access_tokens["hf"]

# Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma3ForCausalLM
from datasets import load_dataset
from tqdm.auto import tqdm  # Use auto version for notebook compatibility
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
import os
import datetime
import math

In [None]:
# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using GPU: {torch.cuda.get_device_name(0)}" if use_cuda else "Using CPU")

In [None]:
# Load the frozen language model whose activations are analysed, then its tokenizer.
print(f"Loading model: {MODEL_NAME}")
#model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device) # for gpt2
model = Gemma3ForCausalLM.from_pretrained(MODEL_NAME, token=token).to(device) # for gemma
model.eval()  # inference mode; the model is only run under torch.no_grad() below

print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
# The extraction loop pads to a fixed length, which requires a pad token;
# fall back to EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Data

In [None]:
# Scratch list the forward hook appends to; the extraction loop clears it
# before every forward pass and reads element 0 afterwards.
activation_storage = []


def get_activation_hook_no_flatten(layer_name):
    """Attach an activation-capturing forward hook to a submodule of ``model``.

    Args:
        layer_name: Dotted module path as reported by ``model.named_modules()``.

    Returns:
        The handle returned by ``register_forward_hook``; call ``.remove()``
        on it to detach the hook.

    Raises:
        KeyError: If ``layer_name`` is not a module of the global ``model``.
    """
    def _capture(module, inputs, output):
        # Some modules return a tuple; the activation tensor is its first item.
        captured = output[0] if isinstance(output, tuple) else output
        # Stored detached, on CPU, in float32, with its original
        # (un-flattened) shape.
        activation_storage.append(captured.detach().cpu().to(torch.float32))

    modules_by_name = dict(model.named_modules())
    return modules_by_name[layer_name].register_forward_hook(_capture)

# Register the non-flattening hook
print(f"Registering hook on layer: {LAYER_NAME}")
hook_handle = get_activation_hook_no_flatten(LAYER_NAME)

# One (n_real_tokens, ACTIVATION_DIM) tensor per kept example.
activation_storage_filtered = []
# Flat list of token ids, aligned row-for-row with the stored activations.
token_storage = []
num_processed_tokens = 0 # Track actual non-padding tokens processed
pbar = None

# Everything that runs while the hook is attached sits inside try/finally:
# if loading or extraction fails (or the cell is interrupted) the hook is
# still removed, so re-running the cell cannot stack a second hook on the layer.
try:
    print(f"Loading dataset: {DATASET_NAME} ({DATASET_CONFIG})")
    dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=f"{DATASET_SPLIT}", streaming=True)

    print("Processing dataset and extracting activations (Aligning Tokens and Activations)...")
    pbar = tqdm(total=NUM_ACTIVATIONS_TO_STORE, desc="Extracting Aligned Activations/Tokens")

    for example in dataset:
        if num_processed_tokens >= NUM_ACTIVATIONS_TO_STORE:
            break

        text = example['text']
        if not text or not text.strip():
            continue  # skip empty / whitespace-only rows

        # Tokenize the text (one example per forward pass, padded to a fixed length)
        inputs = tokenizer(text, return_tensors="pt", max_length=MAX_SEQ_LENGTH, padding="max_length", truncation=True)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device) # Shape: (batch_size, seq_len)

        # The hook appends to this list; start every pass from an empty list
        # so index 0 is always the activation of the current example.
        activation_storage.clear()

        with torch.no_grad():
            model(input_ids=input_ids, attention_mask=attention_mask)

        # --- Retrieve activation (hook appended it to the global list) ---
        if not activation_storage:
            print("Warning: Hook did not capture activation for a batch. Skipping.")
            continue

        batch_activations = activation_storage[0]

        # Ensure shapes match for filtering
        if batch_activations.shape[0] != input_ids.shape[0] or batch_activations.shape[1] != input_ids.shape[1]:
            print(f"Warning: Activation shape {batch_activations.shape} mismatch with input_ids {input_ids.shape}. Check hook logic. Skipping batch.")
            continue

        # --- Filter Activations and Tokens based on Attention Mask ---
        cpu_input_ids = input_ids.cpu().numpy() # (batch_size, seq_len)
        cpu_attention_mask = attention_mask.cpu().numpy().astype(bool) # (batch_size, seq_len)

        # Flatten batch and sequence dimensions for filtering
        flat_activations = batch_activations.view(-1, ACTIVATION_DIM) # (batch*seq, hidden)
        flat_input_ids = cpu_input_ids.flatten() # (batch*seq,)
        flat_attention_mask = cpu_attention_mask.flatten() # (batch*seq,)

        # Select only items where attention_mask is True (drops padding positions)
        filtered_activations = flat_activations[flat_attention_mask]
        filtered_tokens = flat_input_ids[flat_attention_mask]

        # Store filtered tokens
        if filtered_activations.shape[0] > 0: # Only store if there are non-padding tokens
            activation_storage_filtered.append(filtered_activations)
            newly_added_tokens = filtered_tokens.tolist()
            token_storage.extend(newly_added_tokens)

            # --- Update Progress ---
            num_processed_tokens += len(newly_added_tokens)
            pbar.update(len(newly_added_tokens))
finally:
    if pbar is not None:
        pbar.close()
    hook_handle.remove() # Remove the hook

if not activation_storage_filtered:
    raise ValueError("No activations were extracted/stored after filtering. Check data or hook.")

all_activations = torch.cat(activation_storage_filtered, dim=0)
print(f"Total ALIGNED activations extracted: {all_activations.shape[0]}")
print(f"Total ALIGNED tokens stored: {len(token_storage)}")
print(f"Activation tensor final shape: {all_activations.shape}") # Should be (N, ACTIVATION_DIM)
print(f"Token storage final length: {len(token_storage)}")      # N should match above

# Check if the lengths match (they should now)
if all_activations.shape[0] != len(token_storage):
    print(f"CRITICAL WARNING: Mismatch between activation count ({all_activations.shape[0]}) and token count ({len(token_storage)}) AFTER filtering!")
else:
    print("Activation and token counts match. Alignment successful.")

# Normalize activations for SAE training: centre every dimension on the
# dataset mean, then scale each centred vector to unit L2 norm.
act_mean = torch.mean(all_activations, dim=0, keepdim=True)
all_activations -= act_mean  # in place: avoids a second full-size copy in RAM
# The norm is taken AFTER centring; dividing the centred vectors by the norm
# of the raw vectors would not give unit-length rows.
# NOTE: an SAE checkpoint trained on differently scaled inputs expects that
# same scaling when it is evaluated.
act_norms = torch.norm(all_activations, dim=1, keepdim=True) + 1e-6 # Add epsilon for stability
all_activations /= act_norms
print("Filtered activations normalized (unit norm per vector).")

# Wrap the normalised activations in a shuffling DataLoader for SAE training.
# Each batch it yields is a one-element list holding the activation tensor.
activation_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(all_activations),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Drop references that are no longer needed (extraction buffers and the
# language model itself) so their memory can be reclaimed.
del activation_storage, activation_storage_filtered
del filtered_tokens, filtered_activations
del model

In [None]:
import gc
# Collect the objects released by the `del` statements of the previous cell,
# then hand cached CUDA blocks back to the driver when running on the GPU.
gc.collect()
if device == torch.device("cuda"):
    torch.cuda.empty_cache()

# Training

In [None]:
class SparseAutoencoder(nn.Module):
    """Autoencoder with one ReLU hidden layer: Linear -> ReLU -> Linear.

    Args:
        input_dim: Width of the activation vectors being reconstructed.
        hidden_dim: Number of hidden features (encoder output width).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Encoder is created before the decoder, so a seeded initialisation
        # draws random numbers in the same order as before.
        self.encoder = nn.Linear(in_features=input_dim, out_features=hidden_dim, bias=True)
        self.decoder = nn.Linear(in_features=hidden_dim, out_features=input_dim, bias=True)
        self.relu = nn.ReLU()

        # The decoder bias starts at exactly zero.
        nn.init.constant_(self.decoder.bias, 0.0)

    def encode(self, x):
        """Return the non-negative hidden features for ``x``."""
        return self.relu(self.encoder(x))

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input batch ``x``."""
        features = self.encode(x)
        return features, self.decoder(features)

# Initialize SAE
# Input width is the hooked layer's activation size; hidden width is the
# expanded feature count computed in the configuration cell.
print(f"Initializing SAE with ACTIVATION_DIM={ACTIVATION_DIM}, SAE_HIDDEN_DIM={SAE_HIDDEN_DIM}")
sae_model = SparseAutoencoder(ACTIVATION_DIM, SAE_HIDDEN_DIM).to(device)
print("SAE Model:")
print(sae_model)

In [None]:
def sae_loss(original_activations, encoded_activations, decoded_activations, l1_coeff):
    """Compute the SAE objective: reconstruction MSE plus a weighted L1 term.

    Args:
        original_activations: Target activations the SAE should reconstruct.
        encoded_activations: Hidden features produced by the encoder.
        decoded_activations: Reconstruction produced by the decoder.
        l1_coeff: Weight applied to the sparsity term.

    Returns:
        ``(total_loss, mse_loss, l1_loss)`` as scalar tensors, where
        ``total_loss = mse_loss + l1_coeff * l1_loss`` and ``l1_loss`` is the
        mean absolute value over every element of ``encoded_activations``.
    """
    reconstruction = nn.functional.mse_loss(input=decoded_activations, target=original_activations)
    sparsity = encoded_activations.abs().mean()
    return reconstruction + l1_coeff * sparsity, reconstruction, sparsity

# Adam over every SAE parameter (encoder and decoder weights and biases).
optimizer = optim.Adam(sae_model.parameters(), lr=LEARNING_RATE)

In [None]:
# Create a per-run TensorBoard directory; `log_dir` is also where the SAE
# weights and the PDF report are written by later cells.
# NOTE(review): the report cell runs `from datetime import datetime`, which
# rebinds the name `datetime` to the class. Re-running THIS cell after that
# one fails on `datetime.datetime` until the setup imports are re-executed.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# '/' in the model name is replaced so it does not create nested directories.
log_dir = f"runs/{MODEL_NAME.replace('/', '_')}_{LAYER_NAME}_sae_training_logs_{timestamp}" # You can customize this path
os.makedirs(log_dir, exist_ok=True) # Create the directory if it doesn't exist
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print("Run 'tensorboard --logdir runs' (or your specific log_dir parent) to view logs.")

In [None]:
print("Starting SAE training...")
sae_model.train()
# Per-epoch averages; also read by the report cell to draw the training curves.
training_losses = {'total': [], 'mse': [], 'l1': [], 'l0': []} # Track losses and L0
num_batches = len(activation_loader)  # constant across epochs

for epoch in range(NUM_EPOCHS):
    epoch_total_loss = 0
    epoch_mse_loss = 0
    epoch_l1_loss = 0
    epoch_l0_norm = 0 # Track average number of non-zero features

    pbar_train = tqdm(activation_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
    for i, batch in enumerate(pbar_train):
        # Batch is a list containing one tensor: the activations
        original_acts = batch[0].to(device)

        encoded, decoded = sae_model(original_acts)
        loss, mse, l1 = sae_loss(original_acts, encoded, decoded, L1_COEFF)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read each scalar back from the device once and reuse it for both
        # the running sums and the progress-bar text.
        loss_value = loss.item()
        mse_value = mse.item()
        l1_value = l1.item()
        # L0: mean number of features above a small threshold per input vector.
        l0 = torch.sum(encoded > 1e-6, dim=1).float().mean().item()

        epoch_total_loss += loss_value
        epoch_mse_loss += mse_value
        epoch_l1_loss += l1_value
        epoch_l0_norm += l0

        if (i + 1) % PRINT_INTERVAL == 0 or i == num_batches - 1:
            pbar_train.set_postfix({
                "Total Loss": f"{loss_value:.8f}",
                "MSE": f"{mse_value:.8f}",
                "L1": f"{l1_value:.8f}",
                "L0": f"{l0:.2f}" # Avg active features in current batch
            })

    # Average losses for the epoch (mean of per-batch values)
    avg_total_loss = epoch_total_loss / num_batches
    avg_mse_loss = epoch_mse_loss / num_batches
    avg_l1_loss = epoch_l1_loss / num_batches
    avg_l0_norm = epoch_l0_norm / num_batches

    training_losses['total'].append(avg_total_loss)
    training_losses['mse'].append(avg_mse_loss)
    training_losses['l1'].append(avg_l1_loss)
    training_losses['l0'].append(avg_l0_norm)

    writer.add_scalar('Loss/Epoch/Total', avg_total_loss, epoch)
    writer.add_scalar('Loss/Epoch/MSE', avg_mse_loss, epoch)
    writer.add_scalar('Loss/Epoch/L1', avg_l1_loss, epoch)
    writer.add_scalar('L0_Norm/Epoch', avg_l0_norm, epoch)

    print(f"Epoch {epoch+1} Complete: Avg Total Loss={avg_total_loss:.12f}, Avg MSE={avg_mse_loss:.12f}, Avg L1={avg_l1_loss:.7f}, Avg L0={avg_l0_norm:.2f}")

# The writer buffers events and is never closed in this notebook; flush so
# the scalars of the final epochs are on disk as soon as training ends.
writer.flush()
print("SAE Training finished.")

In [None]:
# Save only the SAE parameters (state_dict) next to the TensorBoard logs of this run.
sae_model_path = f"{log_dir}/sae_{MODEL_NAME.replace('/','_')}_{LAYER_NAME}.pth"
torch.save(sae_model.state_dict(), sae_model_path)
print(f"SAE model weights saved to: {sae_model_path}")

In [None]:
# Load model if you already have a checkpoint
# Optional cell: running it REBINDS `log_dir` to a hard-coded earlier run, so
# anything written afterwards under `log_dir` (e.g. report.pdf) lands in that
# run's directory instead of the current one.
log_dir = "runs/google_gemma-3-1b-it_model.layers.11.mlp.down_proj_sae_training_logs_20250424-174247"
# Despite its name this is the checkpoint FILE path, not MODEL_NAME.
model_name = f"{log_dir}/sae_{MODEL_NAME.replace('/','_')}_{LAYER_NAME}.pth"
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust (consider weights_only=True where the installed torch supports it).
sae_model.load_state_dict(torch.load(model_name, map_location=device))

# Evaluation

In [None]:
sae_model.eval()

print("Recomputing SAE features for analysis...")
analysis_batch_size = 1024
with torch.no_grad():
    # Encode in fixed-size slices on `device`; each slice is cast to float16
    # and moved to the CPU before the slices are concatenated.
    feature_chunks = [
        sae_model.encode(all_activations[offset:offset + analysis_batch_size].to(device)).half().cpu()
        for offset in range(0, all_activations.shape[0], analysis_batch_size)
    ]
    all_sae_features = torch.cat(feature_chunks, dim=0)

print(f"SAE features shape: {all_sae_features.shape}")

In [None]:
def get_max_activating_examples(feature_index, num_examples=10, window_size=10):
    """Find text snippets where a specific SAE feature maximally activates.

    Reads the module-level ``all_sae_features`` (one row per stored token),
    ``token_storage`` (token ids aligned with those rows) and ``tokenizer``.

    Args:
        feature_index: Column of ``all_sae_features`` to inspect.
        num_examples: Maximum number of top-activating positions to return.
        window_size: Number of tokens shown on each side of the activating
            token.

    Returns:
        A list of dicts with keys ``'token'``, ``'context'`` and
        ``'activation'``, ordered by decreasing activation. Positions whose
        context cannot be built or decoded are reported and skipped, so every
        returned dict is complete. An empty list is returned when the feature
        index is out of range or there is nothing to analyse.
    """
    if feature_index < 0 or feature_index >= SAE_HIDDEN_DIM:
        print(f"Error: Feature index {feature_index} out of bounds (0-{SAE_HIDDEN_DIM-1}).")
        return []
    if all_sae_features.shape[0] == 0:
        print("Error: No SAE features computed.")
        return []
    if not token_storage:
        print("Error: Token storage is empty.")
        return []

    print(f"\n--- Finding top activating examples for Feature {feature_index} ---")

    feature_activations = all_sae_features[:, feature_index]

    # Handle cases with fewer activations than requested
    k = min(num_examples, all_sae_features.shape[0])
    top_k_values, top_k_indices = torch.topk(feature_activations, k=k, largest=True)

    analysis_output = []
    ranked = zip(top_k_indices.tolist(), top_k_values.tolist())
    for rank, (global_token_idx, activation_value) in enumerate(ranked, start=1):
        if global_token_idx < 0 or global_token_idx >= len(token_storage):
            print(f"\nRank {rank}: Activation = {activation_value:.4f}")
            print(f"Context: Error - Calculated token index {global_token_idx} is out of bounds for token_storage (len={len(token_storage)}). Skipping.")
            continue

        # token_storage is one flat list over all examples, so a window near
        # an example boundary can include tokens from the neighbouring text.
        start = max(0, global_token_idx - window_size)
        end = min(len(token_storage), global_token_idx + window_size + 1)

        if start >= end:
            print(f"\nRank {rank}: Activation = {activation_value:.4f}")
            print(f"Context: Error - Invalid window slice indices (start={start}, end={end}). Skipping.")
            continue

        token_ids_window = token_storage[start:end]

        try:
            context_text = tokenizer.decode(token_ids_window, skip_special_tokens=False, errors='replace')
            activating_token_str = tokenizer.decode(
                [token_storage[global_token_idx]], skip_special_tokens=False, errors='replace'
            )
        except IndexError:
            print(f"\nRank {rank}: Activation = {activation_value:.4f}")
            print(f"Context: Error decoding tokens - IndexError occurred with indices ({start}, {end}) for token_storage (len={len(token_storage)})")
            continue
        except Exception as e:
            # Best effort: report the failure and move on to the next rank.
            print(f"\nRank {rank}: Activation = {activation_value:.4f}")
            print(f"Context: Error decoding tokens: {token_ids_window} - {e}")
            continue

        # Only fully populated entries are stored; callers index the three
        # keys directly.
        analysis_output.append({
            'token': activating_token_str,
            'context': context_text,
            'activation': activation_value,
        })
    return analysis_output

In [None]:
if 'all_sae_features' not in locals() or all_sae_features.shape[0] == 0:
    print("ERROR: `all_sae_features` tensor not found or is empty. Please run the previous cell first.")
else:
    print(f"Calculating statistics across {all_sae_features.shape[1]} features...")

    # Per-feature statistics: the peak activation over all stored tokens
    # decides both the "active" count and the top-feature ranking.
    try:
        feature_max_activations = all_sae_features.max(dim=0).values

        # A feature is counted as active when its peak exceeds the threshold.
        active_feature_indices = torch.where(feature_max_activations > ACTIVITY_THRESHOLD)[0]
        print(f"Found {len(active_feature_indices)} active features (max > {ACTIVITY_THRESHOLD}).")

        # The ranking is taken over ALL features, not only the active ones.
        num_selected = min(NUM_TOP_FEATURES_TO_ANALYZE, all_sae_features.shape[1])
        top_max_k_values, top_max_k_indices = torch.topk(feature_max_activations, k=num_selected)
        top_max_activation_indices = top_max_k_indices.tolist()
        print(f"\nTop {len(top_max_activation_indices)} features by MAX activation selected.")

        del feature_max_activations, active_feature_indices
        gc.collect()

    except Exception as e:
        print(f"An error occurred during statistics calculation: {e}")
        import traceback
        traceback.print_exc()
        # Leave empty selections so later cells still find the names defined.
        top_max_activation_indices = []
        top_std_dev_indices = []
        top_mean_activation_indices = []

In [None]:
# One scalar "strength" per SAE feature, used by the grid plots below.
# Option 1: Maximum Activation (Recommended for seeing peak usage)
feature_strengths_np = torch.max(all_sae_features, dim=0)[0].numpy()
metric_label = "Max Activation"

# NOTE(review): options 2 and 3 assign `feature_strengths` (a tensor), while
# the plotting code reads `feature_strengths_np`; they need a `.numpy()` and
# the `_np` name before they can be enabled.
# # Option 2: Mean Activation
# feature_strengths = torch.mean(all_sae_features, dim=0)
# metric_label = "Mean Activation"

# # Option 3: Activation Frequency (Approx L0 Norm)
# feature_activity = (all_sae_features > ACTIVITY_THRESHOLD).float() # ACTIVITY_THRESHOLD from config
# feature_strengths = torch.mean(feature_activity, dim=0)
# metric_label = f"Activation Frequency (>{ACTIVITY_THRESHOLD})"
# del feature_activity # Free up memory

print(f"Calculated feature strengths using: {metric_label}")
# Total number of SAE features to lay out on the grid.
num_features = SAE_HIDDEN_DIM

def get_grid_dims(n):
    """Return the most square exact factorisation of ``n`` as ``(rows, cols)``.

    ``rows`` is the largest divisor of ``n`` not exceeding ``sqrt(n)``, so
    ``rows * cols == n`` and ``rows <= cols``. A prime ``n`` therefore yields
    ``(1, n)``.

    Args:
        n: Number of items to arrange.

    Returns:
        ``(rows, cols)``, or ``(0, 0)`` when ``n <= 0``.
    """
    if n <= 0:
        return (0, 0)
    rows = math.isqrt(n)
    # Walk down from floor(sqrt(n)); 1 divides every n, so this always stops
    # with rows >= 1 and no fallback search is needed.
    while n % rows != 0:
        rows -= 1
    return (rows, n // rows)

# Choose a 2-D layout for the per-feature strengths.
# 3072 = 64*48 is the feature count of the commented-out GPT-2 setup
# (768 * 4); any other count falls through to get_grid_dims.
if 64 * 48 == num_features:
     grid_rows, grid_cols = 48, 64
     print(f"Using specified grid: {grid_rows}x{grid_cols}")
else:
     grid_rows, grid_cols = get_grid_dims(num_features)
     # Pad with NaN only if the grid has more cells than features.
     if grid_rows * grid_cols > num_features:
          print(f"Calculated grid: {grid_rows}x{grid_cols}. Padding needed.")
          padding_size = grid_rows * grid_cols - num_features
          feature_strengths_np = np.pad(feature_strengths_np, (0, padding_size), mode='constant', constant_values=np.nan)
     else:
          print(f"Calculated grid: {grid_rows}x{grid_cols}.")


# Row-major layout: feature k sits at (k // grid_cols, k % grid_cols).
feature_grid = feature_strengths_np.reshape(grid_rows, grid_cols)

# Plotting
# 2-D heatmap: one cell per feature, colour = strength metric.

plt.figure(figsize=(12, 10))
plt.imshow(feature_grid, cmap='viridis', aspect='auto') # Alternative without seaborn
plt.colorbar(label=metric_label)
plt.title(f'SAE Feature Strength ({metric_label}) - {grid_rows}x{grid_cols} Grid')
plt.xlabel("Feature Index (Column Offset)")
plt.ylabel("Feature Index (Row Offset)")
plt.tight_layout()
plt.show()

# Optional 3-D bar rendering of the same grid; failures here are reported
# and do not stop the notebook.
try:
    from mpl_toolkits.mplot3d import Axes3D

    if num_features > 5000:
         print("\nSkipping 3D bar plot: Too many features (>5000), likely unreadable.")
    else:
         print("\nGenerating 3D bar plot (can be slow and cluttered)...")
         fig = plt.figure(figsize=(15, 10))
         ax = fig.add_subplot(111, projection='3d')

         # One bar per grid cell, anchored at z = 0.
         xpos, ypos = np.meshgrid(np.arange(grid_cols), np.arange(grid_rows))
         xpos = xpos.flatten()
         ypos = ypos.flatten()
         zpos = np.zeros_like(xpos)

         dz = feature_grid.flatten()
         # NaN cells (grid padding) get no bar.
         valid_bars = ~np.isnan(dz)

         # Colour by height relative to the tallest bar.
         # NOTE(review): if every strength is 0 this divides by zero.
         colors = plt.cm.viridis(dz[valid_bars] / np.nanmax(dz))

         ax.bar3d(xpos[valid_bars], ypos[valid_bars], zpos[valid_bars],
                  dx=0.8, dy=0.8, dz=dz[valid_bars],
                  color=colors, # Use variable colors
                  shade=True)

         ax.set_xlabel('Feature Col Index')
         ax.set_ylabel('Feature Row Index')
         ax.set_zlabel(metric_label)
         ax.set_title(f'SAE Feature Strength ({metric_label}) - 3D Bars')
         plt.tight_layout()
         plt.show()

except ImportError:
    print("\nCould not import Axes3D for 3D plot. Skipping.")
except Exception as e:
    print(f"\nAn error occurred during 3D plot generation: {e}")

In [None]:
# Show token and context
# `result` keeps one list of example dicts per analysed feature; the report
# cell iterates over it and reads the 'token' and 'context' keys.
result = []
for feature_idx in top_max_activation_indices:
    examples = get_max_activating_examples(feature_index=feature_idx, num_examples=10, window_size=10)
    # Defensive: tolerate a missing result and drop entries without text so
    # neither the prints below nor the report hit a TypeError / KeyError.
    examples = [ex for ex in (examples or []) if ex]
    result.append(examples)
    for ex in examples:
        print(f"\n{ex['context']}")
        print(ex['token'])

# Debug

In [None]:
print("DEBUG: Inspecting sample SAE features...")
sae_model.eval()
sae_model.to(device)  # make sure the SAE lives on the compute device

# Work on a random subset (at most 5k rows) so the check stays cheap.
total_rows = all_activations.shape[0]
num_inspect_samples = min(5000, total_rows)
sample_indices = torch.randperm(total_rows)[:num_inspect_samples]
sample_acts_cpu = all_activations[sample_indices]

# Encode the subset slice by slice; results are gathered on the CPU.
batch_size_inspect = 1024
with torch.no_grad():
    all_sample_features_list = [
        sae_model.encode(sample_acts_cpu[offset:offset + batch_size_inspect].to(device)).cpu()
        for offset in range(0, num_inspect_samples, batch_size_inspect)
    ]

sample_features = torch.cat(all_sample_features_list, dim=0)

print(f"Sample features shape: {sample_features.shape}") # Should be (num_inspect_samples, SAE_HIDDEN_DIM)
for stat_name, stat_value in (
    ("min", sample_features.min()),
    ("max", sample_features.max()),
    ("mean", sample_features.mean()),
    ("std", sample_features.std()),
):
    print(f"Sample features {stat_name}: {stat_value:.6f}")

# Sparsity statistics: a value counts as active above a small threshold
# that absorbs floating point noise.
active_features = (sample_features > 1e-6).float()
active_per_input = active_features.sum(dim=1)
avg_l0_per_input = active_per_input.mean().item()
total_active_features = active_features.sum().item()
fraction_active = total_active_features / sample_features.numel()

print(f"\nSample L0 norm (avg active features per input): {avg_l0_per_input:.4f}")
print(f"Total active feature values in sample: {total_active_features}")
print(f"Fraction of active feature values: {fraction_active:.6f}")

# Dead features: columns that never fire anywhere in the sample.
active_per_feature = active_features.sum(dim=0)
features_never_activated = (active_per_feature == 0).sum().item()
print(f"\nNumber of features that NEVER activated in the sample: {features_never_activated} / {SAE_HIDDEN_DIM}")

# Cleanup
del sample_acts_cpu, all_sample_features_list, sample_features, active_features
gc.collect()
if device == torch.device("cuda"):
    torch.cuda.empty_cache()

# Create report

In [None]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from datetime import datetime

def _add_text_page(pdf, text, fontsize=10):
    """Append one 11x8 in. page showing ``text`` anchored near the top-left."""
    plt.figure(figsize=(11, 8))
    plt.text(0.1, 0.9, text, fontsize=fontsize, va='top')
    plt.axis('off')
    pdf.savefig()
    plt.close()


def create_experiment_report():
    """Write a multi-page PDF summary of the run to ``{log_dir}/report.pdf``.

    Pages, in order: title, configuration, loss curves, L0 curve, 3-D feature
    strength plot (or a note explaining why it was skipped), aggregate feature
    statistics, and one page of top-activating examples per analysed feature.

    Reads module-level state produced by earlier cells: the configuration
    constants, ``log_dir``, ``training_losses``, ``num_features``,
    ``grid_rows``, ``grid_cols``, ``feature_grid``, ``metric_label``,
    ``features_never_activated``, ``avg_l0_per_input`` and ``result``.
    """
    # Local import: works whether the global name `datetime` currently refers
    # to the module or to the class.
    from datetime import datetime

    # Probe for the 3-D toolkit here instead of relying on a name that an
    # earlier cell may or may not have bound.
    try:
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        has_3d = True
    except ImportError:
        has_3d = False

    report_path = f"{log_dir}/report.pdf"

    with PdfPages(report_path) as pdf:
        # Title Page
        plt.figure(figsize=(11, 8))
        plt.text(0.5, 0.8, "Sparse Autoencoder Experiment Report",
                 ha='center', va='center', fontsize=20)
        plt.text(0.5, 0.6, f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
                 ha='center', va='center', fontsize=12)
        plt.text(0.5, 0.5, f"Model: {MODEL_NAME} | Layer: {LAYER_NAME}",
                 ha='center', va='center', fontsize=12)
        plt.axis('off')
        pdf.savefig()
        plt.close()

        # Configuration Summary
        config_text = [
            "=== Configuration Summary ===",
            f"\nModel Architecture:",
            f"- Target Layer: {LAYER_NAME}",
            f"- Activation Dim: {ACTIVATION_DIM}",
            f"\nDataset:",
            f"- Name: {DATASET_NAME} ({DATASET_CONFIG})",
            f"- Samples Stored: {NUM_ACTIVATIONS_TO_STORE:,}",
            f"- Max Sequence Length: {MAX_SEQ_LENGTH}",
            f"\nSAE Architecture:",
            f"- Hidden Dim: {SAE_HIDDEN_DIM}",
            f"- Expansion Factor: {SAE_EXPANSION_FACTOR}",
            f"\nTraining:",
            f"- Epochs: {NUM_EPOCHS}",
            f"- Batch Size: {BATCH_SIZE}",
            f"- Learning Rate: {LEARNING_RATE}",
            f"- L1 Coefficient: {L1_COEFF}"
        ]
        _add_text_page(pdf, "\n".join(config_text))

        # Training Curves
        plt.figure(figsize=(10, 6))
        plt.plot(training_losses['total'], label='Total Loss')
        plt.plot(training_losses['mse'], label='MSE Loss')
        plt.plot(training_losses['l1'], label='L1 Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss Curves')
        plt.legend()
        plt.grid(True)
        pdf.savefig()
        plt.close()

        # L0 Norm Plot
        plt.figure(figsize=(10, 6))
        plt.plot(training_losses['l0'], label='Avg Active Features', color='purple')
        plt.xlabel('Epoch')
        plt.ylabel('L0 Norm')
        plt.title('Average Active Features per Sample')
        plt.grid(True)
        pdf.savefig()
        plt.close()

        # 3D Feature Strength Plot
        if not has_3d or num_features > 5000:
            # Skip if too many features or missing toolkit
            plt.figure(figsize=(11, 8))
            skip_msg = ("3D bar plot skipped: " +
                        ("Too many features (>5000)." if num_features > 5000 else
                         "mpl_toolkits.mplot3d not available."))
            plt.text(0.5, 0.5, skip_msg, ha='center', va='center', fontsize=12)
            plt.axis('off')
            pdf.savefig()
            plt.close()
        else:
            fig = plt.figure(figsize=(15, 10))
            ax = fig.add_subplot(111, projection='3d')

            xpos, ypos = np.meshgrid(np.arange(grid_cols), np.arange(grid_rows))
            xpos = xpos.flatten()
            ypos = ypos.flatten()
            zpos = np.zeros_like(xpos)

            dz = feature_grid.flatten()
            valid = ~np.isnan(dz)  # NaN marks padding cells of the grid
            # Colour by height relative to the tallest bar; fall back to 1
            # when every value is 0 so the division stays finite.
            peak = np.nanmax(dz)
            colors = plt.cm.viridis(dz[valid] / (peak if peak > 0 else 1.0))

            ax.bar3d(xpos[valid], ypos[valid], zpos[valid],
                     dx=0.8, dy=0.8, dz=dz[valid],
                     color=colors, shade=True)

            ax.set_xlabel('Feature Col Index')
            ax.set_ylabel('Feature Row Index')
            ax.set_zlabel(metric_label)
            ax.set_title(f'SAE Feature Strength ({metric_label}) - 3D Bars')
            plt.tight_layout()
            pdf.savefig()
            plt.close()

        # Feature Analysis
        analysis_text = [
            "=== Feature Analysis ===",
            f"\nTop {NUM_TOP_FEATURES_TO_ANALYZE} Features by Max Activation:",
            f"\nFeature Statistics:",
            f"- Total Features: {SAE_HIDDEN_DIM}",
            f"- Dead Features: {features_never_activated} ({features_never_activated/SAE_HIDDEN_DIM:.1%})",
            f"- Avg Features Active per Sample: {avg_l0_per_input:.1f}"
        ]
        _add_text_page(pdf, "\n".join(analysis_text))

        # Token-level Feature Pages (one page per analysed feature)
        for activations in result:
            parts = ["==NEW FEATURE==\n"]
            for act in activations or []:
                # Skip entries that lack the decoded text.
                if not act or 'token' not in act or 'context' not in act:
                    continue
                parts.append(f"Token: {act['token']}\nContext: {act['context']}\n--------------------\n")
            # Corpus text can contain '$'; escape it so matplotlib prints a
            # literal dollar sign instead of treating '$...$' as mathtext.
            feature_text = "".join(parts).replace("$", r"\$")
            _add_text_page(pdf, feature_text, fontsize=7)

    print(f"Report generated at: {report_path}")

# Generate the report
create_experiment_report()