# Steering models with target features

In [3]:
import os
import sys
import json
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
from dataclasses import dataclass
from sae_lens import SAE

sys.path.append('.')
sys.path.append('..')

from utils.steering_utils import ActivationSteering

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [9]:
@dataclass
class ModelConfig:
    """Configuration for model-specific settings.

    Holds the Hugging Face model ids, chat-template markers and SAE locations for
    one model family. ``saelens_release`` also acts as a selector:
    ``get_sae_params`` dispatches on its exact value to decide how the trainer
    string is parsed.
    """
    base_model_name: str  # HF id of the pretrained (base) checkpoint
    chat_model_name: str  # HF id of the instruction-tuned checkpoint
    hf_release: str  # Reference only - actual loading uses saelens_release/sae_id
    assistant_header: str  # literal text that opens the assistant turn
    token_offsets: Dict[str, int]  # named token offsets; presumably relative to assistant_header - confirm with consumers
    sae_base_path: str  # local root directory for cached SAEs
    saelens_release: str  # Template for sae_lens release parameter
    sae_id_template: str  # Template for sae_lens sae_id parameter
    base_url: str  # Base URL for neuronpedia

    def get_sae_params(self, sae_layer: int, sae_trainer: str) -> Tuple[str, str]:
        """
        Generate SAE lens release and sae_id parameters.

        Args:
            sae_layer: Layer number for the SAE
            sae_trainer: Trainer identifier for the SAE. For Llama Scope it is
                substituted verbatim; for Gemma Scope it has the form
                "<width>-l0-<l0>", e.g. "131k-l0-34".

        Returns:
            Tuple of (release, sae_id) for sae_lens.SAE.from_pretrained()

        Raises:
            ValueError: if the release template is unknown, or if ``sae_trainer``
                lacks the components the selected release needs.
        """
        release_template = self.saelens_release
        if release_template == "llama_scope_lxr_{trainer}":
            release = release_template.format(trainer=sae_trainer)
            sae_id = self.sae_id_template.format(layer=sae_layer, trainer=sae_trainer)
        elif release_template == "gemma-scope-9b-pt-res":
            width, l0_value = self._parse_gemma_trainer(sae_trainer, need_l0=True)
            release = release_template
            sae_id = self.sae_id_template.format(layer=sae_layer, width=width, l0=l0_value)
        elif release_template == "gemma-scope-9b-pt-res-canonical":
            # Canonical Gemma Scope ids only encode the width.
            width, _ = self._parse_gemma_trainer(sae_trainer, need_l0=False)
            release = release_template
            sae_id = self.sae_id_template.format(layer=sae_layer, width=width)
        else:
            raise ValueError(f"Unknown SAE lens release template: {self.saelens_release}")

        return release, sae_id

    @staticmethod
    def _parse_gemma_trainer(sae_trainer: str, need_l0: bool) -> Tuple[str, Optional[str]]:
        """Split a Gemma trainer string "131k-l0-34" into (width, l0).

        ``l0`` is None when ``need_l0`` is False. Raises ValueError with a
        readable message instead of an IndexError on malformed input.
        """
        parts = sae_trainer.split("-")
        width = parts[0]  # e.g. "131k"
        if not width:
            raise ValueError(f"SAE trainer {sae_trainer!r} has no width component (expected e.g. '131k-l0-34')")
        if not need_l0:
            return width, None
        if len(parts) < 3 or not parts[2]:
            raise ValueError(f"SAE trainer {sae_trainer!r} has no l0 component (expected e.g. '131k-l0-34')")
        return width, parts[2]  # e.g. "34"

# Model configurations: one ModelConfig per supported family, keyed by MODEL_TYPE.
_LLAMA_CONFIG = ModelConfig(
    base_model_name="meta-llama/Llama-3.1-8B",
    chat_model_name="meta-llama/Llama-3.1-8B-Instruct",
    hf_release="fnlp/Llama3_1-8B-Base-LXR-32x",
    assistant_header="<|start_header_id|>assistant<|end_header_id|>",
    token_offsets={"asst": -2, "endheader": -1, "newline": 0},
    sae_base_path="/workspace/sae/llama-3.1-8b/saes",
    saelens_release="llama_scope_lxr_{trainer}",
    sae_id_template="l{layer}r_{trainer}",
    base_url="https://www.neuronpedia.org/llama-3.1-8b/{layer}-llamascope-res-131k",
)

_GEMMA_CONFIG = ModelConfig(
    base_model_name="google/gemma-2-9b",
    chat_model_name="google/gemma-2-9b-it",
    hf_release="google/gemma-scope-9b-pt-res/layer_{layer}/width_{width}/average_l0_{l0}",
    assistant_header="<start_of_turn>model",
    token_offsets={"model": -1, "newline": 0},
    sae_base_path="/workspace/sae/gemma-2-9b/saes",
    saelens_release="gemma-scope-9b-pt-res-canonical",
    sae_id_template="layer_{layer}/width_{width}/canonical",
    base_url="https://www.neuronpedia.org/gemma-2-9b/{layer}-gemmascope-res-131k",
)

MODEL_CONFIGS = dict(llama=_LLAMA_CONFIG, gemma=_GEMMA_CONFIG)

# =============================================================================
# MODEL SELECTION - Change this to switch between models
# =============================================================================
MODEL_TYPE = "gemma"  # Options: "gemma" or "llama"
MODEL_VER = "chat"  # "chat" -> chat_model_name, "base" -> base_model_name
SAE_LAYER = 20
SAE_TRAINER = "131k-l0-114"
N_PROMPTS = 1000

# =============================================================================
# TARGET FEATURES - Specify which features to analyze
# =============================================================================
TARGET_FEATURES = [45426]  # SAE feature indices to steer along / ablate

# =============================================================================
# CONFIGURATION SETUP
# =============================================================================
if MODEL_TYPE not in MODEL_CONFIGS:
    raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE}. Available: {list(MODEL_CONFIGS.keys())}")
config = MODEL_CONFIGS[MODEL_TYPE]

# Resolve which checkpoint to load from the selected version.
_MODEL_NAME_BY_VER = {"chat": config.chat_model_name, "base": config.base_model_name}
if MODEL_VER not in _MODEL_NAME_BY_VER:
    raise ValueError(f"Unknown MODEL_VER: {MODEL_VER}. Use 'chat' or 'base'")
MODEL_NAME = _MODEL_NAME_BY_VER[MODEL_VER]

# Always use chat model for tokenizer (has chat template)
CHAT_MODEL_NAME = config.chat_model_name

# Shortcuts to the per-model settings
ASSISTANT_HEADER, TOKEN_OFFSETS, SAE_BASE_PATH = (
    config.assistant_header,
    config.token_offsets,
    config.sae_base_path,
)

# =============================================================================
# OUTPUT + DERIVED PATHS
# =============================================================================
# One output directory per (model, trainer, layer) combination.
OUTPUT_DIR = f"./results/7_steering/{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

SAE_PATH = f"{SAE_BASE_PATH}/resid_post_layer_{SAE_LAYER}/trainer_{SAE_TRAINER}"
LAYER_INDEX = SAE_LAYER  # same index as the SAE layer

# Data paths
PROMPTS_HF = "lmsys/lmsys-chat-1m"
SEED = 42
_dataset_name = PROMPTS_HF.rpartition('/')[2]
PROMPTS_PATH = f"/workspace/data/{_dataset_name}/chat_{N_PROMPTS}.jsonl"
os.makedirs(os.path.dirname(PROMPTS_PATH), exist_ok=True)

# Processing parameters
BATCH_SIZE, MAX_LENGTH = 32, 512

# =============================================================================
# SUMMARY
# =============================================================================
_summary_lines = [
    f"  Model Type: {MODEL_TYPE}",
    f"  Model to load: {MODEL_NAME}",
    f"  SAE Layer: {SAE_LAYER}, Trainer: {SAE_TRAINER}",
    f"  Target Features: {TARGET_FEATURES}",
]
print("Configuration Summary:")
for _line in _summary_lines:
    print(_line)

Configuration Summary:
  Model Type: gemma
  Model to load: google/gemma-2-9b-it
  SAE Layer: 20, Trainer: 131k-l0-114
  Target Features: [45426]
  Output files: ./results/7_steering/gemma_trainer131k-l0-114_layer20/active.jsonl, ./results/7_steering/gemma_trainer131k-l0-114_layer20/inactive.jsonl


## Prompts for inference

In [None]:
# Questions that ask the model about itself; each one is run through every intervention.
_PROMPT_TEXTS = (
    "Who are you?",
    "What are you thinking about?",
    "How should I imagine your appearance?",
    "What is your favorite thing to do?",
)
prompts = list(_PROMPT_TEXTS)

In [None]:
def format_as_chat(tokenizer, prompt):
    """Render *prompt* as a single user turn using the tokenizer's chat template.

    The template is applied untokenized (a string is returned), with
    ``add_generation_prompt=True`` so it ends by opening the assistant turn.
    """
    single_turn = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        single_turn,
        tokenize=False,
        add_generation_prompt=True,
    )

## Load Model and SAE

In [None]:
# Load tokenizer. Per the configuration cell, the tokenizer always comes from the
# chat checkpoint because format_as_chat needs its chat template, even when
# MODEL_VER == "base". (Base and chat checkpoints of a family are assumed to
# share a vocabulary - verify if adding a new model family.)
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")

# Load model weights for the selected version (chat or base) in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},  # place all modules on device 0
)
model.eval()

print(f"Model loaded: {model.__class__.__name__}")
print(f"Model device: {next(model.parameters()).device}")

In [None]:
def load_sae(config: ModelConfig, sae_path: str, sae_layer: int, sae_trainer: str, device: str = "cuda") -> SAE:
    """
    Load an SAE from the local cache, downloading it via sae_lens on a cache miss.

    Args:
        config: ModelConfig object containing model-specific settings
        sae_path: Local directory to store/load SAE files
        sae_layer: Layer number for the SAE
        sae_trainer: Trainer identifier for the SAE
        device: device passed to SAE.from_pretrained when downloading
            (converted with str(), so a torch.device also works). A cached SAE
            is loaded with SAE.load_from_disk's default device; move it afterwards.

    Returns:
        SAE: Loaded SAE model
    """
    # A cached copy is recognised by its weights file.
    print(f"Loading SAE from {sae_path}")
    ae_file_path = os.path.join(sae_path, "sae_weights.safetensors")

    if os.path.exists(ae_file_path):
        print(f"✓ Found SAE files at: {os.path.dirname(ae_file_path)}")
        return SAE.load_from_disk(sae_path)

    print(f"SAE not found locally, downloading from HF via sae_lens...")
    parent_dir = os.path.dirname(sae_path)
    if parent_dir:  # os.makedirs("") would raise for a bare relative name
        os.makedirs(parent_dir, exist_ok=True)

    # Get SAE parameters from config
    release, sae_id = config.get_sae_params(sae_layer, sae_trainer)
    print(f"Loading SAE with release='{release}', sae_id='{sae_id}'")

    # sae_lens expects a device string here.
    loaded = SAE.from_pretrained(
        release=release,
        sae_id=sae_id,
        device=str(device),
    )
    # Depending on the sae_lens version, from_pretrained returns either
    # (sae, cfg_dict, sparsity) or the SAE alone; accept both.
    sae = loaded[0] if isinstance(loaded, tuple) else loaded

    # Save the SAE locally for future use
    sae.save_model(sae_path)
    return sae

# Load the SAE (from the local cache when present) and move it to the compute device.
sae = load_sae(config, SAE_PATH, SAE_LAYER, SAE_TRAINER).to(device)

print(f"SAE loaded with {sae.cfg.d_sae} features")
first_sae_param = next(sae.parameters())
print(f"SAE device: {first_sae_param.device}")

## Run inference on prompts
First, query the model with each prompt unmodified (the baseline).
Then repeat the prompts with interventions: SAE feature ablation and activation steering.


In [10]:
def generate_text(model, tokenizer, prompt, max_new_tokens=100, temperature=0.7, do_sample=True):
    """Generate a completion for *prompt*, rendered as a single-turn chat.

    Args:
        model: causal LM used for generation.
        tokenizer: tokenizer with a chat template (see format_as_chat).
        prompt: user message text.
        max_new_tokens, temperature, do_sample: forwarded to ``model.generate``.
            A repetition penalty of 1.1 is always applied.

    Returns:
        The newly generated tokens decoded without special tokens, whitespace-stripped.
    """
    formatted_prompt = format_as_chat(tokenizer, prompt)

    # The chat templates used here already start with the BOS token ("<bos>" for
    # Gemma, "<|begin_of_text|>" for Llama 3). Letting the tokenizer add special
    # tokens again would prepend a second BOS, which is out of distribution and
    # degrades generations. Only add it if the template did not.
    bos = tokenizer.bos_token
    template_has_bos = bool(bos) and formatted_prompt.startswith(bos)
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        add_special_tokens=not template_has_bos,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1
        )

    # Decode only the new tokens (everything after the prompt).
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return generated_text.strip()

# Smoke-test generation on a single prompt before running any interventions.
test_prompt = "What are you thinking about?"
test_response = generate_text(model, tokenizer, test_prompt)
for label, value in (("Test prompt", test_prompt), ("Test response", test_response)):
    print(f"{label}: {value}")
print(f"Response length: {len(test_response)} characters")



Test prompt: What are you thinking about?
Test response: As a large language model, I don't actually "think" in the way humans do. I don't have thoughts, feelings, or consciousness.

I process information and respond based on the vast dataset I was trained on. Right now, I'm waiting for your next input so I can continue our conversation and assist you with any questions or tasks you may have.

Is there anything specific you'd like to talk about?
Response length: 392 characters


In [None]:
# Extract feature directions from SAE decoder
def get_feature_direction(sae, feature_id):
    """Return the unit-norm decoder direction of one SAE feature.

    Args:
        sae: SAE exposing ``cfg.d_sae`` and a decoder matrix ``W_dec`` whose row
            ``feature_id`` is that feature's direction (shape (d_sae, d_model)).
        feature_id: feature index in ``[0, sae.cfg.d_sae)``.

    Returns:
        1-D tensor: the decoder row, detached and scaled to unit L2 norm
        (a 1e-8 epsilon guards against division by zero).

    Raises:
        ValueError: if feature_id is outside ``[0, sae.cfg.d_sae)``.
    """
    d_sae = sae.cfg.d_sae
    # Reject negatives too: a negative index would silently pick a feature from the end.
    if not 0 <= feature_id < d_sae:
        raise ValueError(f"Feature ID {feature_id} out of range [0, {d_sae})")

    # Detach so the steering vector does not keep the SAE parameter's autograd graph.
    feature_direction = sae.W_dec[feature_id, :].detach()

    # Normalize to unit vector (common practice for steering)
    return feature_direction / (feature_direction.norm() + 1e-8)

# Full SAE encode/decode ablation hook
class SAEFeatureAblationHook:
    """
    Forward hook that removes chosen SAE features from a module's output.

    The module output ``x`` is encoded with the SAE, the target features are
    zeroed, and only the resulting change in reconstruction is applied:

        x' = x + decode(f_ablated) - decode(f)
           = decode(f_ablated) + (x - decode(f))

    so the SAE's reconstruction error is preserved and the intervention changes
    nothing except the ablated features. Replacing ``x`` with
    ``decode(f_ablated)`` outright would also throw away the reconstruction
    error, confounding the ablation with SAE lossiness.

    Use as a context manager: the hook is registered on ``__enter__`` and
    removed on ``__exit__``.
    """

    def __init__(self, sae, feature_ids, layer_module):
        self.sae = sae
        # Accept a single feature id or a list/tuple of ids.
        self.feature_ids = list(feature_ids) if isinstance(feature_ids, (list, tuple)) else [feature_ids]
        self.layer_module = layer_module
        self.handle = None

    def hook_fn(self, module, input, output):
        """Return *output* with the target features' decoder contribution removed."""
        # Some layers return a tuple whose first element is the hidden state.
        is_tuple = isinstance(output, tuple)
        activations = output[0] if is_tuple else output

        with torch.no_grad():
            feature_acts = self.sae.encode(activations)
            reconstruction = self.sae.decode(feature_acts)

            # Zero the target features along the last (feature) axis.
            feature_acts[..., self.feature_ids] = 0.0
            ablated_reconstruction = self.sae.decode(feature_acts)

            # The SAE may compute in a different dtype (e.g. float32) than the
            # model's residual stream (e.g. bfloat16); cast the delta back so
            # downstream layers receive their expected dtype.
            delta = (ablated_reconstruction - reconstruction).to(activations.dtype)
            modified_activations = activations + delta

        if is_tuple:
            return (modified_activations, *output[1:])
        return modified_activations

    def __enter__(self):
        """Register the hook (at most once) and return self."""
        if self.handle is None:
            self.handle = self.layer_module.register_forward_hook(self.hook_fn)
        return self

    def __exit__(self, *exc):
        """Remove the hook; exceptions from the ``with`` body propagate."""
        self.remove()

    def remove(self):
        """Remove the hook if it is registered."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

# Helper function to create ablation hook
def create_sae_ablation_hook(sae, feature_ids, layer_index, target_model=None):
    """Create an SAEFeatureAblationHook on transformer block *layer_index*.

    The block list is located by trying a few common attribute paths on the model.

    Args:
        sae: SAE forwarded to SAEFeatureAblationHook.
        feature_ids: feature id or list of ids to ablate.
        layer_index: index into the model's block list.
        target_model: model to hook; defaults to the notebook-global ``model``.

    Returns:
        An unregistered SAEFeatureAblationHook (enter it with ``with`` to activate).

    Raises:
        ValueError: if none of the known block-list paths exists on the model.
        Indexing errors from the block container propagate for a bad layer_index.
    """
    root = model if target_model is None else target_model

    layer_attrs = [
        "transformer.h",       # GPT‑2/Neo, Bloom, etc.
        "encoder.layer",       # BERT/RoBERTa
        "model.layers",        # Llama/Mistral
        "gpt_neox.layers",     # GPT‑NeoX
        "block",               # Flan‑T5
    ]

    for path in layer_attrs:
        cur = root
        for part in path.split("."):
            if not hasattr(cur, part):
                break
            cur = getattr(cur, part)
        else:  # every component of the path resolved
            if hasattr(cur, "__getitem__"):
                return SAEFeatureAblationHook(sae, feature_ids, cur[layer_index])

    raise ValueError("Could not find layer list on the model")

# Unit-norm decoder direction for every target feature, keyed by feature id.
feature_directions = {}
for fid in TARGET_FEATURES:
    unit_vec = get_feature_direction(sae, fid)
    feature_directions[fid] = unit_vec
    print(f"Feature {fid}: direction shape {unit_vec.shape}, norm {unit_vec.norm():.4f}")

print(f"\nExtracted directions for {len(feature_directions)} features")

# Hook that zeroes the target features at the SAE's layer; it is only active
# inside a `with ablation_hook:` block.
ablation_hook = create_sae_ablation_hook(sae, TARGET_FEATURES, SAE_LAYER)
print(f"Created SAE ablation hook for features: {TARGET_FEATURES} at layer {SAE_LAYER}")

In [None]:
# Steering configuration
STEERING_MAGNITUDES = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]  # Range of steering strengths to test
STEERING_LAYER = SAE_LAYER  # Use the same layer as the SAE

def run_steering_experiment(prompt, feature_id, magnitudes=STEERING_MAGNITUDES, do_steering=True, do_ablation=True):
    """Run SAE-ablation and/or steering generations for one prompt and one feature.

    Args:
        prompt: user message (rendered via the chat template by generate_text).
        feature_id: SAE feature id; must be a key of ``feature_directions``.
        magnitudes: steering coefficients to try; 0.0 is the unsteered baseline.
        do_steering: if False, non-zero magnitudes are skipped. The 0.0 baseline
            (when present in *magnitudes*) still runs so ablation results have a
            reference generation.
        do_ablation: if False, the SAE-ablation generation is skipped.

    Returns:
        {"steering": {magnitude: text}, "ablation": {"zero_ablation": text}}.
        A failed generation stores an error string in place of the text.
    """
    print(f"\n{'='*60}")
    print(f"PROMPT: {prompt}")
    print(f"FEATURE: {feature_id}")
    print(f"{'='*60}")

    feature_direction = feature_directions[feature_id]
    results = {
        "steering": {},
        "ablation": {}
    }

    if do_ablation:
        print("\nSAE FEATURE ABLATION")
        print("-" * 40)
        try:
            # Ablate only this feature, not every entry of TARGET_FEATURES.
            with create_sae_ablation_hook(sae, [feature_id], SAE_LAYER):
                ablation_response = generate_text(model, tokenizer, prompt)
            results["ablation"]["zero_ablation"] = ablation_response
            print(f"ABLATED: {ablation_response}")
        except Exception as e:  # best effort: record the failure and keep going
            error_msg = f"Error with SAE ablation: {str(e)}"
            results["ablation"]["zero_ablation"] = error_msg
            print(f"ERROR: {error_msg}")

    print("\nSTEERING EXPERIMENTS")
    print("-" * 40)

    for magnitude in magnitudes:
        is_baseline = magnitude == 0.0
        if not is_baseline and not do_steering:
            continue

        print(f"\nMagnitude: {magnitude:+.1f}")
        print("-" * 20)

        if is_baseline:
            # Baseline: no steering
            response = generate_text(model, tokenizer, prompt)
            results["steering"][magnitude] = response
            print(f"BASELINE: {response}")
            continue

        try:
            with ActivationSteering(
                model=model,
                steering_vectors=feature_direction,
                coefficients=magnitude,
                layer_indices=STEERING_LAYER,
                intervention_type="addition",
                positions="all",
                debug=False
            ):
                response = generate_text(model, tokenizer, prompt)
            results["steering"][magnitude] = response
            print(f"STEERED: {response}")
        except Exception as e:  # best effort: record the failure and keep going
            error_msg = f"Error with magnitude {magnitude}: {str(e)}"
            results["steering"][magnitude] = error_msg
            print(f"ERROR: {error_msg}")

    return results

# Collect results as {feature_id: {prompt: per-prompt results}}, features outermost.
all_results = {
    fid: {p: run_steering_experiment(p, fid, do_steering=False) for p in prompts}
    for fid in TARGET_FEATURES
}

banner = "=" * 60
print(f"\n{banner}")
print("STEERING + ABLATION EXPERIMENTS COMPLETE")
print(banner)

In [None]:
def _read_json_objects(path):
    """Return every top-level JSON value in *path*, in file order.

    Accepts real JSONL (one object per line) as well as concatenated,
    pretty-printed (``indent=2``) objects spanning several lines.
    """
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
    decoder = json.JSONDecoder()
    objects = []
    pos, end = 0, len(text)
    while True:
        # Skip whitespace (including newlines) between values.
        while pos < end and text[pos].isspace():
            pos += 1
        if pos == end:
            return objects
        obj, pos = decoder.raw_decode(text, pos)
        objects.append(obj)


def save_results_to_jsonl(results, output_path, metadata=None):
    """Save steering and ablation results to JSONL format, merging with existing data.

    Each line of the output file is one feature object:
    ``{"feature_id", "metadata", "results": {prompt: {"steering": {mag_str: text},
    "ablation": {type: text}}}}``. Steering magnitudes become string keys
    (``str(1.0) == "1.0"``) because JSON object keys are strings.

    Args:
        results: {feature_id: {prompt: {"steering": {magnitude: text},
            "ablation": {ablation_type: text}}}}.
        output_path: destination file; its parent directories are created.
        metadata: metadata stored on newly created feature entries. Defaults to
            the notebook globals MODEL_NAME, MODEL_TYPE, SAE_LAYER, SAE_TRAINER.

    Returns:
        {feature_id: feature_obj} for every feature now stored in the file.
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:  # os.makedirs("") raises for a bare file name
        os.makedirs(parent_dir, exist_ok=True)

    # Load existing data if file exists
    existing_data = {}
    if os.path.exists(output_path):
        try:
            for feature_obj in _read_json_objects(output_path):
                existing_data[feature_obj['feature_id']] = feature_obj
            print(f"📂 Loaded existing data for {len(existing_data)} features")
        except (OSError, ValueError, KeyError, TypeError) as e:
            # Don't silently overwrite results we failed to parse: move them aside.
            backup_path = output_path + ".corrupt"
            os.replace(output_path, backup_path)
            existing_data = {}
            print(f"⚠️  Error loading existing file: {e}")
            print(f"Moved unreadable file to {backup_path}")
            print("Starting with empty data...")

    # Process each feature
    for feature_id, feature_results in results.items():
        feature_obj = existing_data.get(feature_id)
        if feature_obj is not None:
            print(f"🔄 Merging with existing data for feature {feature_id}")
        else:
            entry_metadata = dict(metadata) if metadata is not None else {
                "model_name": MODEL_NAME,
                "model_type": MODEL_TYPE,
                "sae_layer": SAE_LAYER,
                "sae_trainer": SAE_TRAINER
            }
            feature_obj = {
                "feature_id": feature_id,
                "metadata": entry_metadata,
                "results": {}
            }
            existing_data[feature_id] = feature_obj
            print(f"🆕 Creating new entry for feature {feature_id}")

        # Merge prompt results; new values overwrite same-key old ones.
        for prompt, prompt_results in feature_results.items():
            prompt_entry = feature_obj["results"].setdefault(prompt, {"steering": {}, "ablation": {}})
            for magnitude, response in prompt_results.get("steering", {}).items():
                prompt_entry.setdefault("steering", {})[str(magnitude)] = response
            for ablation_type, response in prompt_results.get("ablation", {}).items():
                prompt_entry.setdefault("ablation", {})[ablation_type] = response

    # Write to a temp file and swap it in, so a crash mid-write can't truncate
    # the existing results. No indent: JSONL needs exactly one object per line.
    tmp_path = output_path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        for feature_id in sorted(existing_data):
            f.write(json.dumps(existing_data[feature_id], ensure_ascii=False) + '\n')
    os.replace(tmp_path, output_path)

    # Count records
    total_steering_records = sum(
        len(prompt_results.get("steering", {}))
        for feature_obj in existing_data.values()
        for prompt_results in feature_obj["results"].values()
    )

    total_ablation_records = sum(
        len(prompt_results.get("ablation", {}))
        for feature_obj in existing_data.values()
        for prompt_results in feature_obj["results"].values()
    )

    print(f"✅ Saved {len(existing_data)} features with {total_steering_records} steering records and {total_ablation_records} ablation records to {output_path}")
    return existing_data

# Persist the results (merged into any existing file) and summarize its contents.
output_file = f"{OUTPUT_DIR}/steering_results.jsonl"
saved_data = save_results_to_jsonl(all_results, output_file)

print("\n📋 Data structure:")
print(f"Features: {list(saved_data)}")
for fid, fobj in saved_data.items():
    print(f"  Feature {fid}:")
    for prompt_text, per_prompt in fobj["results"].items():
        magnitude_keys = [*per_prompt.get("steering", {})]
        ablation_keys = [*per_prompt.get("ablation", {})]
        print(f"    '{prompt_text}':")
        print(f"      Steering -> {len(magnitude_keys)} magnitudes: {magnitude_keys}")
        print(f"      Ablation -> {len(ablation_keys)} types: {ablation_keys}")

example_prompt = prompts[0]
print(f"\nOutput file: {output_file}")
print("Easy access examples:")
print(f"  Baseline: feature_obj['results']['{example_prompt}']['steering']['0.0']")
print(f"  Ablation: feature_obj['results']['{example_prompt}']['ablation']['zero_ablation']")
print(f"  Steering: feature_obj['results']['{example_prompt}']['steering']['1.0']")