# Configuration

In [1]:
# --- Transformer Configuration ---

# Hugging Face Hub id of the language model that is loaded and patched below.
MODEL_NAME = "google/gemma-3-1b-it"
# Index of the decoder layer whose MLP down projection gets the SAE attached.
REPLACEMENT_LAYER_IDX = 3
# Dotted module path of that down projection; it is also part of the checkpoint file name.
LAYER_NAME = "model.layers." + str(REPLACEMENT_LAYER_IDX) + ".mlp.down_proj"
# Width of the activation vectors the SAE reads and reconstructs.
ACTIVATION_DIM = 1152

In [2]:
# --- SAE Configuration ---
# Ratio between the SAE's hidden width and the width of the activations it reads.
SAE_EXPANSION_FACTOR = 8
SAE_HIDDEN_DIM = SAE_EXPANSION_FACTOR * ACTIVATION_DIM
# Sparsity penalty strength.
L1_COEFF = 3e-4
# Directory of the training run whose SAE checkpoint is loaded further down.
CHECKPOINT_PATH = (
    "runs/wikitext/"
    "B_google_gemma-3-1b-it_model.layers.3.mlp.down_proj_sae_training_logs_20250514-113504"
)

In [3]:
# --- General Configuration ---
# `secret_tokens` is not part of this notebook (presumably a local, untracked
# module that keeps credentials out of the notebook itself). The "hf" entry is
# the token later passed to `from_pretrained` for both the model and the tokenizer.
from secret_tokens import access_tokens
token = access_tokens["hf"]

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma3ForCausalLM
from datasets import load_dataset
from tqdm.auto import tqdm  # Use auto version for notebook compatibility
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
import os
import datetime
import math
import copy

# One shared seed for Python, NumPy and PyTorch so reruns draw the same numbers.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(42)

# Prepare LLM and SAE

In [4]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

Using GPU: NVIDIA RTX A1000 6GB Laptop GPU


In [5]:
class SparseAutoencoder(nn.Module):
    """Autoencoder with one hidden layer and a ReLU on the hidden code.

    Args:
        input_dim: Size of the vectors being encoded and reconstructed.
        hidden_dim: Size of the hidden (feature) layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=True)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=True)
        self.relu = nn.ReLU()

        # The decoder starts without an output offset.
        nn.init.zeros_(self.decoder.bias)

    def encode(self, x):
        """Project ``x`` into feature space and clamp negatives to zero."""
        return self.relu(self.encoder(x))

    def decode(self, x):
        """Project feature activations ``x`` back to the input space."""
        return self.decoder(x)

    def forward(self, x):
        """Return the pair ``(features, reconstruction)`` for ``x``."""
        features = self.encode(x)
        reconstruction = self.decode(features)
        return features, reconstruction

# Initialize SAE
print(f"Initializing SAE with ACTIVATION_DIM={ACTIVATION_DIM}, SAE_HIDDEN_DIM={SAE_HIDDEN_DIM}")
sae_model = SparseAutoencoder(ACTIVATION_DIM, SAE_HIDDEN_DIM).to(device)

# The checkpoint file lives inside the run directory and is named after the model and layer.
checkpoint_full_path = f"{CHECKPOINT_PATH}/sae_{MODEL_NAME.replace('/','_')}_{LAYER_NAME}.pth"
try:
    # NOTE(review): torch.load unpickles the file, so only open checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_full_path, map_location=device)
    sae_model.load_state_dict(checkpoint["sae_model_state_dict"])
except FileNotFoundError:
    # Every later cell needs the trained weights, so report the path and re-raise.
    print(f"ERROR: Checkpoint file not found at {checkpoint_full_path}")
    print("Please ensure the path is correct.")
    raise
except Exception as e:
    print(f"ERROR: Failed to load checkpoint from {checkpoint_full_path}. Error: {e}")
    raise
else:
    print(f"Checkpoint loaded successfully from {checkpoint_full_path}")
print("SAE Model:")
print(sae_model)

Initializing SAE with ACTIVATION_DIM=1152, SAE_HIDDEN_DIM=9216
Checkpoint loaded successfully from runs/wikitext/B_google_gemma-3-1b-it_model.layers.3.mlp.down_proj_sae_training_logs_20250514-113504/sae_google_gemma-3-1b-it_model.layers.3.mlp.down_proj.pth
SAE Model:
SparseAutoencoder(
  (encoder): Linear(in_features=1152, out_features=9216, bias=True)
  (decoder): Linear(in_features=9216, out_features=1152, bias=True)
  (relu): ReLU()
)


In [6]:
# Load act mean
# The mean is read from the same checkpoint dict as the SAE weights. It is needed
# to centre activations before they are fed to the SAE, so a missing entry is fatal.
if 'act_mean' not in checkpoint:
    # Only 'act_mean' is looked up here, so the message names just that key and
    # lists what the checkpoint does contain to make a mismatch easy to spot.
    raise KeyError(f"'act_mean' missing from checkpoint; available keys: {list(checkpoint.keys())}")

loaded_act_mean = checkpoint['act_mean'].to(device)  # same device the SAE was moved to
print(f"act_mean loaded. Shape: {loaded_act_mean.shape}, Device: {loaded_act_mean.device}")

act_mean loaded. Shape: torch.Size([1, 1152]), Device: cuda:0


In [7]:
# Wrapper class to allow the patching
class SAEIntervenableMLP(nn.Module):
    """Routes activations through a sparse autoencoder and lets a hook edit its features.

    Per position the module subtracts the stored mean, divides by the L2 norm,
    encodes with the SAE, optionally rewrites the features with ``patch_fn``,
    decodes, and finally undoes the scaling and centring.

    In this notebook the module is chained directly after ``mlp.down_proj``
    (see ``replace_mlp_layer_with_sae_mlp``), so it receives that projection's
    output and returns a tensor of the same shape.

    Args:
        sae_model_instance: Object exposing ``encode`` and ``decode``. It is put
            into eval mode and used as-is (not copied).
        act_mean_global_tensor: Mean activation of shape ``[dim]`` or ``[1, dim]``.

    Raises:
        ValueError: If the first dimension of the (possibly unsqueezed) mean is not 1.
    """

    def __init__(self, sae_model_instance: "SparseAutoencoder", act_mean_global_tensor: torch.Tensor):
        super().__init__()
        self.sae = sae_model_instance
        self.sae.eval()

        if act_mean_global_tensor.ndim == 1:
            act_mean_global_tensor = act_mean_global_tensor.unsqueeze(0)
        if act_mean_global_tensor.shape[0] != 1:
            raise ValueError(f"act_mean_global_tensor should have shape [1, dim] or [dim], got {act_mean_global_tensor.shape}")

        # A buffer follows the module across .to(...) calls without becoming a parameter.
        self.register_buffer('act_mean_global', act_mean_global_tensor)
        # Keeps the division finite for all-zero centred rows. Expected to match the
        # epsilon used when the SAE's training activations were normalised
        # (not verifiable from this notebook).
        self.normalization_epsilon = 1e-6

        self.patch_fn = None
        # Always a dict, so a patch_fn assigned directly (without set_patch_fn)
        # can still be called with ``**self.patch_kwargs``.
        self.patch_kwargs = {}

    def set_patch_fn(self, patch_fn=None, **kwargs):
        """
        Install, replace or remove the function that edits the SAE's encoded features.

        ``patch_fn`` is called as ``patch_fn(encoded_features, **kwargs)`` and must
        return the features to decode. Calling ``set_patch_fn()`` with no arguments
        (or ``None``) removes the patch.

        Example: def my_patch(features, idx_to_ablate): features[:, idx_to_ablate] = 0; return features
        set_patch_fn(my_patch, idx_to_ablate=10)
        """
        self.patch_fn = patch_fn
        self.patch_kwargs = dict(kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode, optionally patch, decode and de-normalise ``x``.

        x: Activations of shape (batch_size, seq_len, hidden_dim).

        Returns a tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(f"Input tensor x must be 3-dimensional (batch, seq, dim), got {x.ndim}")

        original_shape = x.shape
        # Treat every (batch, position) pair as an independent row.
        x_flat = x.reshape(-1, original_shape[-1])

        # (x - mean) / norm; act_mean_global is (1, hidden_dim) and broadcasts over rows.
        x_centered = x_flat - self.act_mean_global
        act_norms = torch.norm(x_centered, dim=1, keepdim=True) + self.normalization_epsilon
        x_normalized = x_centered / act_norms

        encoded_features = self.sae.encode(x_normalized)

        if self.patch_fn is not None:
            # ``or {}`` also tolerates patch_kwargs having been reset to None by hand.
            patched_features = self.patch_fn(encoded_features, **(self.patch_kwargs or {}))
        else:
            patched_features = encoded_features

        decoded_normalized = self.sae.decode(patched_features)

        # Undo the normalisation with the per-row norms and the mean of this very input,
        # so an exact reconstruction would reproduce x.
        reconstructed_output_flat = decoded_normalized * act_norms + self.act_mean_global
        return reconstructed_output_flat.reshape(original_shape)

In [8]:
# Load Model
print(f"Loading model: {MODEL_NAME}")
# Gemma-specific model class. eval() returns the module itself, so it can be chained.
# Alternative for GPT-2: AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model = Gemma3ForCausalLM.from_pretrained(MODEL_NAME, token=token).eval()

print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
# Padding needs a pad token; reuse EOS when the tokenizer ships without one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading model: google/gemma-3-1b-it
Loading tokenizer: google/gemma-3-1b-it


In [9]:
def replace_mlp_layer_with_sae_mlp(gemma_model: "Gemma3ForCausalLM",
                                   layer_idx: int,
                                   sae_inter_mlp_instance: "SAEIntervenableMLP") -> "Gemma3ForCausalLM": # Or GemmaForCausalLM
    """
    Return a deep copy of the model in which one layer's ``mlp.down_proj`` is
    followed by the given SAE module.

    The rest of the MLP (gate/up projections, activation) is untouched: only
    ``down_proj`` is swapped for ``nn.Sequential(original_down_proj, sae_module)``.
    ``gemma_model`` itself is not modified.

    Args:
        gemma_model: Model exposing ``model.layers[i].mlp.down_proj``.
        layer_idx: Index of the decoder layer to modify.
        sae_inter_mlp_instance: Module appended after ``down_proj``. The same
            instance (not a copy) ends up inside the returned model, and it is
            moved in place to the device of the wrapped projection.

    Returns:
        The modified copy, in eval mode.

    Raises:
        ValueError: If ``layer_idx`` is outside ``[0, number of layers)``.
    """
    # Validate before copying: deep-copying the whole model is the expensive step.
    num_layers = len(gemma_model.model.layers)
    if not (0 <= layer_idx < num_layers):
        raise ValueError(f"Layer index {layer_idx} is out of bounds for model with {num_layers} layers.")

    model_copy = copy.deepcopy(gemma_model)

    # Insert the SAE after the down proj
    target_mlp = model_copy.model.layers[layer_idx].mlp
    original_mlp_down_proj = target_mlp.down_proj

    # Place the SAE on the device of the projection it follows rather than on a
    # notebook-level global, so the returned model is never split across devices.
    layer_device = next(original_mlp_down_proj.parameters()).device
    target_mlp.down_proj = nn.Sequential(
        original_mlp_down_proj,
        sae_inter_mlp_instance.to(layer_device),
    )

    # Switch to eval mode after the swap so the new wrapper is covered as well.
    model_copy.eval()

    print(f"Replaced MLP in layer {layer_idx} of Gemma model with SAEIntervenableMLP.")
    print(f"  Original MLP: {gemma_model.model.layers[layer_idx].mlp}")
    print(f"  New MLP: {model_copy.model.layers[layer_idx].mlp}")
    return model_copy

# Patching

In [10]:
# Wrap the loaded SAE together with the activation mean restored from the same checkpoint.
sae_inter_mlp = SAEIntervenableMLP(sae_model, loaded_act_mean)

In [11]:
# Build the patched copy; `model` itself stays as loaded because the helper works on a deep copy.
model_with_sae_mlp = replace_mlp_layer_with_sae_mlp(model, REPLACEMENT_LAYER_IDX, sae_inter_mlp)
# Move the copy to `device`; as the cell's last expression this also displays the module tree.
model_with_sae_mlp.to(device).eval()

Replaced MLP in layer 3 of Gemma model with SAEIntervenableMLP.
  Original MLP: Gemma3MLP(
  (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
  (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
  (down_proj): Linear(in_features=6912, out_features=1152, bias=False)
  (act_fn): PytorchGELUTanh()
)
  New MLP: Gemma3MLP(
  (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
  (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
  (down_proj): Sequential(
    (0): Linear(in_features=6912, out_features=1152, bias=False)
    (1): SAEIntervenableMLP(
      (sae): SparseAutoencoder(
        (encoder): Linear(in_features=1152, out_features=9216, bias=True)
        (decoder): Linear(in_features=9216, out_features=1152, bias=True)
        (relu): ReLU()
      )
    )
  )
  (act_fn): PytorchGELUTanh()
)


Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
    (layers): ModuleList(
      (0-2): 3 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1152, out_features=256, bias=False)
          (v_proj): Linear(in_features=1152, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (down_proj): Linear(in_features=6912, out_features=1152, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma3RMSNorm((1152

In [12]:
# Helper function to generate text
def generate_text(model_to_use, tokenizer_to_use, prompt, max_new_tokens=50):
    """Greedily continue ``prompt`` and return the decoded first sequence.

    Args:
        model_to_use: Model providing ``eval()``, ``device`` and ``generate()``.
        tokenizer_to_use: Tokenizer used both to encode the prompt and to decode the result.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of ``outputs[0]`` with special tokens skipped. In the
        recorded runs of this notebook this text starts with the prompt itself.
    """
    model_to_use.eval()

    encoded_prompt = tokenizer_to_use(
        prompt, return_tensors="pt", padding=True, truncation=True
    ).to(model_to_use.device)

    # do_sample=False keeps decoding deterministic, so outputs of different models are comparable.
    with torch.no_grad():
        sequences = model_to_use.generate(
            **encoded_prompt, max_new_tokens=max_new_tokens, do_sample=False
        )

    return tokenizer_to_use.decode(sequences[0], skip_special_tokens=True)

In [13]:
prompt = "I walked for 3 km straight. Last year, I run a marathon that was long "

# Baseline: the model exactly as loaded.
print("--- Generating with Original Gemma Model ---")
original_output = generate_text(model, tokenizer, prompt)
print(f"Prompt: {prompt}", f"Original Gemma Output: {original_output}\n", sep="\n")

# Same prompt through the copy with the SAE in the loop. In a top-to-bottom run
# no patch function has been set yet, so this shows the SAE's reconstruction alone.
print(f"--- Generating with Gemma + SAE-MLP (Layer {REPLACEMENT_LAYER_IDX}, No Patching) ---")
sae_mlp_no_patch_output = generate_text(model_with_sae_mlp, tokenizer, prompt)
print(f"Prompt: {prompt}", f"Gemma + SAE-MLP (No Patch) Output: {sae_mlp_no_patch_output}\n", sep="\n")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


--- Generating with Original Gemma Model ---




Prompt: I walked for 3 km straight. Last year, I run a marathon that was long 
Original Gemma Output: I walked for 3 km straight. Last year, I run a marathon that was long 42.195 km.

The question asks for the distance of the walk.

The answer is 3 km.

Final Answer: The final answer is $\boxed{3}$

--- Generating with Gemma + SAE-MLP (Layer 3, No Patching) ---
Prompt: I walked for 3 km straight. Last year, I run a marathon that was long 
Gemma + SAE-MLP (No Patch) Output: I walked for 3 km straight. Last year, I run a marathon that was long 42.195 km.

The question asks for the distance of the walk.

The answer is 3 km.

Final Answer: The final answer is $\boxed{3}$



In [14]:
def ablate_features_patch(encoded_features, feature_idx):
    """Return a copy of ``encoded_features`` with the given feature column(s) zeroed.

    ``feature_idx`` is either a single index or a list of indices; the input
    tensor itself is left untouched.
    """
    columns = feature_idx if isinstance(feature_idx, list) else [feature_idx]
    result = encoded_features.clone()
    for column in columns:
        result[:, column] = 0.0
    return result

def amplify_feature_patch(encoded_features, feature_idx, scale_factor):
    """Return a copy of ``encoded_features`` with column ``feature_idx`` multiplied by ``scale_factor``."""
    every_row = slice(None)
    boosted = encoded_features.clone()
    boosted[every_row, feature_idx] *= scale_factor
    return boosted

def set_feature_value_patch(encoded_features, feature_idx, value):
    """Return a copy of ``encoded_features`` with column ``feature_idx`` overwritten by ``value``."""
    every_row = slice(None)
    overridden = encoded_features.clone()
    overridden[every_row, feature_idx] = value
    return overridden

In [15]:
# Fetch the SAE wrapper back out of the patched model: it is element 1 of the
# Sequential that now sits in place of down_proj.
sae_mlp_layer_in_model = model_with_sae_mlp.model.layers[REPLACEMENT_LAYER_IDX].mlp.down_proj[1]

feature_to_patch = 2715
patch_value = 0.5
# This pins the feature to a constant at every position rather than zeroing it,
# so the printed labels describe it as "setting", not "ablating".
sae_mlp_layer_in_model.set_patch_fn(set_feature_value_patch, feature_idx=feature_to_patch, value=patch_value)

print(f"--- Generating with Gemma + SAE-MLP (Layer {REPLACEMENT_LAYER_IDX}, Setting Feature {feature_to_patch} to {patch_value}) ---")
sae_mlp_patched_output = generate_text(model_with_sae_mlp, tokenizer, prompt)
# Also available under the older name so code that refers to it keeps working.
sae_mlp_ablation_output = sae_mlp_patched_output
print(f"Prompt: {prompt}")
print(f"Gemma + SAE-MLP (F{feature_to_patch} = {patch_value}) Output: {sae_mlp_patched_output}\n")

--- Generating with Gemma + SAE-MLP (Layer 3, Ablating Feature 2715) ---
Prompt: I walked for 3 km straight. Last year, I run a marathon that was long 
Gemma + SAE-MLP (Ablating F2715) Output: I walked for 3 km straight. Last year, I run a marathon that was long 10 years.
I'm a small business, and I'm struggling to make a profit.
I'm feeling a bit lost.

I've been working on a new marketing strategy.
I'm trying to build



In [16]:
# Clean model after patching
# Remove the installed patch function so any later generation with this model
# (including re-running earlier cells) goes through the SAE unpatched.
sae_mlp_layer_in_model.set_patch_fn(None)

# Hand cached GPU memory back to the driver. This does not involve autograd,
# so no no_grad() context is needed; nothing to do on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()