# Steering models with target features

In [1]:
import os
import sys
import json
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
from dataclasses import dataclass
from sae_lens import SAE

# Make the project-local `utils` package importable whether the notebook runs
# from the repository root or from one directory below it.
sys.path.extend(['.', '..'])

from utils.steering_utils import ActivationSteering

# Use the GPU when one is visible; later cells move tensors onto `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Trade a little float32 matmul precision for speed (torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('high')

Using device: cuda


In [2]:
@dataclass
class ModelConfig:
    """Configuration for model-specific settings.

    Holds the Hugging Face checkpoint names, chat-template markers, and the
    templates needed to locate the matching SAE via sae_lens.
    """
    base_model_name: str  # HF id of the pretrained (non-chat) checkpoint
    chat_model_name: str  # HF id of the instruction-tuned checkpoint
    hf_release: str  # Reference only - actual loading uses saelens_release/sae_id
    assistant_header: str  # text that opens the assistant/model turn in the chat template
    # NOTE(review): presumably offsets of named tokens relative to the assistant
    # header; confirm against the code that consumes TOKEN_OFFSETS.
    token_offsets: Dict[str, int]
    sae_base_path: str  # local directory under which SAE weights are cached
    saelens_release: str  # Template for sae_lens release parameter
    sae_id_template: str  # Template for sae_lens sae_id parameter
    base_url: str  # Base URL for neuronpedia

    def get_sae_params(self, sae_layer: int, sae_trainer: str) -> Tuple[str, str]:
        """
        Generate SAE lens release and sae_id parameters.

        Args:
            sae_layer: Layer number for the SAE
            sae_trainer: Trainer identifier for the SAE. For Llama Scope this is
                used verbatim (e.g. "32x"); for Gemma Scope it has the form
                "<width>-l0-<value>" (e.g. "131k-l0-34").

        Returns:
            Tuple of (release, sae_id) for sae_lens.SAE.from_pretrained()

        Raises:
            ValueError: if the release template is unknown, or the Gemma trainer
                string is missing a component the chosen release needs.
        """
        release_template = self.saelens_release
        if release_template == "llama_scope_lxr_{trainer}":
            # The trainer id fills both the release name and the sae_id.
            release = release_template.format(trainer=sae_trainer)
            sae_id = self.sae_id_template.format(layer=sae_layer, trainer=sae_trainer)
        elif release_template == "gemma-scope-9b-pt-res":
            # Non-canonical ids encode both the width and the average L0.
            width, l0_value = self._split_gemma_trainer(sae_trainer, need_l0=True)
            release = release_template
            sae_id = self.sae_id_template.format(layer=sae_layer, width=width, l0=l0_value)
        elif release_template == "gemma-scope-9b-pt-res-canonical":
            # Canonical ids only encode the width; any L0 part is ignored.
            width, _ = self._split_gemma_trainer(sae_trainer, need_l0=False)
            release = release_template
            sae_id = self.sae_id_template.format(layer=sae_layer, width=width)
        else:
            raise ValueError(f"Unknown SAE lens release template: {self.saelens_release}")

        return release, sae_id

    @staticmethod
    def _split_gemma_trainer(sae_trainer: str, need_l0: bool) -> Tuple[str, str]:
        """Split a Gemma trainer id "131k-l0-34" into ("131k", "34").

        The L0 part is returned as "" when absent and not required.

        Raises:
            ValueError: if the width is empty, or the L0 value is required but missing.
        """
        parts = sae_trainer.split("-")
        width = parts[0]
        l0_value = parts[2] if len(parts) >= 3 else ""
        if not width or (need_l0 and not l0_value):
            raise ValueError(
                f"Malformed Gemma SAE trainer {sae_trainer!r}; "
                f"expected '<width>-l0-<value>', e.g. '131k-l0-34'"
            )
        return width, l0_value

# Per-model settings, keyed by the names accepted for MODEL_TYPE.
# NOTE(review): both neuronpedia base_urls hard-code a 131k-wide SAE; confirm
# they still match if a trainer of a different width is selected.
MODEL_CONFIGS = dict(
    llama=ModelConfig(
        base_model_name="meta-llama/Llama-3.1-8B",
        chat_model_name="meta-llama/Llama-3.1-8B-Instruct",
        hf_release="fnlp/Llama3_1-8B-Base-LXR-32x",
        assistant_header="<|start_header_id|>assistant<|end_header_id|>",
        token_offsets={"asst": -2, "endheader": -1, "newline": 0},
        sae_base_path="/workspace/sae/llama-3.1-8b/saes",
        saelens_release="llama_scope_lxr_{trainer}",
        sae_id_template="l{layer}r_{trainer}",
        base_url="https://www.neuronpedia.org/llama-3.1-8b/{layer}-llamascope-res-131k",
    ),
    gemma=ModelConfig(
        base_model_name="google/gemma-2-9b",
        chat_model_name="google/gemma-2-9b-it",
        hf_release="google/gemma-scope-9b-pt-res/layer_{layer}/width_{width}/average_l0_{l0}",
        assistant_header="<start_of_turn>model",
        token_offsets={"model": -1, "newline": 0},
        sae_base_path="/workspace/sae/gemma-2-9b/saes",
        saelens_release="gemma-scope-9b-pt-res-canonical",
        sae_id_template="layer_{layer}/width_{width}/canonical",
        base_url="https://www.neuronpedia.org/gemma-2-9b/{layer}-gemmascope-res-131k",
    ),
)

# =============================================================================
# MODEL SELECTION - Change this to switch between models
# =============================================================================
MODEL_TYPE = "gemma"  # Options: "gemma" or "llama" (keys of MODEL_CONFIGS)
MODEL_VER = "chat"    # Options: "chat" (instruction-tuned) or "base" (pretrained)
SAE_LAYER = 20
SAE_TRAINER = "131k-l0-114"
N_PROMPTS = 1000

# =============================================================================
# TARGET FEATURES - Specify which features to analyze
# =============================================================================
TARGET_FEATURES = [45426]  # SAE feature indices to steer/ablate

# =============================================================================
# CONFIGURATION SETUP
# =============================================================================
if MODEL_TYPE not in MODEL_CONFIGS:
    raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE}. Available: {list(MODEL_CONFIGS.keys())}")
config = MODEL_CONFIGS[MODEL_TYPE]

# Resolve which checkpoint to load from the requested version.
if MODEL_VER not in ("chat", "base"):
    raise ValueError(f"Unknown MODEL_VER: {MODEL_VER}. Use 'chat' or 'base'")
MODEL_NAME = config.chat_model_name if MODEL_VER == "chat" else config.base_model_name

# Chat checkpoint name, intended for the tokenizer (it carries the chat template).
CHAT_MODEL_NAME = config.chat_model_name

# Shorthands copied out of the selected ModelConfig.
ASSISTANT_HEADER, TOKEN_OFFSETS, SAE_BASE_PATH = (
    config.assistant_header,
    config.token_offsets,
    config.sae_base_path,
)

# =============================================================================
# OUTPUT FILE CONFIGURATION
# =============================================================================
OUTPUT_DIR = f"./results/7_steering/{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# =============================================================================
# DERIVED CONFIGURATIONS
# =============================================================================
SAE_PATH = f"{SAE_BASE_PATH}/resid_post_layer_{SAE_LAYER}/trainer_{SAE_TRAINER}"
LAYER_INDEX = SAE_LAYER

# Data paths: the local dataset folder is named after the last part of the HF id.
PROMPTS_HF = "lmsys/lmsys-chat-1m"
SEED = 42
PROMPTS_PATH = f"/workspace/data/{PROMPTS_HF.rsplit('/', 1)[-1]}/chat_{N_PROMPTS}.jsonl"
os.makedirs(os.path.dirname(PROMPTS_PATH), exist_ok=True)

# Processing parameters
BATCH_SIZE, MAX_LENGTH = 32, 512

# =============================================================================
# SUMMARY
# =============================================================================
print("\n".join([
    "Configuration Summary:",
    f"  Model Type: {MODEL_TYPE}",
    f"  Model to load: {MODEL_NAME}",
    f"  SAE Layer: {SAE_LAYER}, Trainer: {SAE_TRAINER}",
    f"  Target Features: {TARGET_FEATURES}",
]))

Configuration Summary:
  Model Type: gemma
  Model to load: google/gemma-2-9b-it
  SAE Layer: 20, Trainer: 131k-l0-114
  Target Features: [45426]


## Prompts for inference

In [3]:
# Prompts used for the runs below. A larger bank of identity/persona questions
# from earlier runs is kept here for reference; uncomment entries to reuse them.
# prompts = [
#     "Who are you?",
#     "Who created you?",
#     "Are you open-weights?",
#     # "What is your purpose?",
#     # "Do you have a persona?",
#     # "What are you good at doing?",
#     "Can you do something funny?",
#     "Can we do something funny?",
#     # "What are you thinking about?",
#     # "How should I imagine your appearance?",
#     "What is your favorite thing to do?",
#     # "Can you be creative?",
#     # "What's it like to be you?",
#     "What's it like to be an AI model?",
#     # "What's it like to be an assistant?",
#     # "Do you ever want to be something else?",
#     "Are you immaterial?",
#     "Tell me something interesting.",
#     "Tell me something interesting about you.",
#     "Do you have a name?",
#     "Why do you have a name?",
# ]
prompts = [
    "How should I imagine your appearance?",
]

In [4]:
def format_as_chat(tokenizer, prompt):
    """Render one user turn through the tokenizer's chat template.

    Returns the rendered template text (not token ids) with the generation
    prompt appended, so the model's reply can follow directly.
    """
    conversation = [dict(role="user", content=prompt)]
    return tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )

## Load Model and SAE

In [5]:
# Load the tokenizer from the chat checkpoint: format_as_chat() needs its chat
# template, which a base checkpoint may not ship. When MODEL_VER == "chat" this
# is the same checkpoint as MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
if tokenizer.pad_token is None:
    # Fall back to EOS so padding always has a valid token id.
    tokenizer.pad_token = tokenizer.eos_token
# Right padding is fine for the single-prompt generation used in this notebook;
# batched generation with a decoder-only model would need "left".
tokenizer.padding_side = "right"

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")

# Load the model in bfloat16: the whole model goes on GPU 0 when CUDA is
# available, otherwise on the CPU (matching the `device` fallback above).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map={"": 0} if torch.cuda.is_available() else {"": "cpu"},
)
model.eval()

print(f"Model loaded: {model.__class__.__name__}")
print(f"Model device: {next(model.parameters()).device}")

Tokenizer loaded: GemmaTokenizerFast


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded: Gemma2ForCausalLM
Model device: cuda:0


In [6]:
def load_sae(config: ModelConfig, sae_path: str, sae_layer: int, sae_trainer: str, device: Optional[str] = None) -> SAE:
    """
    Unified SAE loading function that handles both Llama and Gemma models.

    Uses the local copy at ``sae_path`` when ``sae_weights.safetensors`` exists
    there. Otherwise it downloads the SAE through sae_lens and saves it to
    ``sae_path`` for next time.

    Args:
        config: ModelConfig object containing model-specific settings
        sae_path: Local path to store/load SAE files
        sae_layer: Layer number for the SAE
        sae_trainer: Trainer identifier for the SAE
        device: Device passed to sae_lens when downloading. It may be a string
            or anything whose ``str()`` is a device string, such as
            ``torch.device``. Defaults to "cuda" when available, otherwise "cpu".

    Returns:
        SAE: Loaded SAE model. A local copy is loaded by ``SAE.load_from_disk``
        with that function's default device; move it with ``.to(...)`` as needed.
    """
    print(f"Loading SAE from {sae_path}")
    ae_file_path = os.path.join(sae_path, "sae_weights.safetensors")

    if os.path.exists(ae_file_path):
        print(f"✓ Found SAE files at: {os.path.dirname(ae_file_path)}")
        return SAE.load_from_disk(sae_path)

    print("SAE not found locally, downloading from HF via sae_lens...")
    os.makedirs(os.path.dirname(sae_path), exist_ok=True)

    # Translate the model config + trainer id into sae_lens identifiers.
    release, sae_id = config.get_sae_params(sae_layer, sae_trainer)
    print(f"Loading SAE with release='{release}', sae_id='{sae_id}'")

    # sae_lens wants a device *string*; fall back to CPU when no GPU is present.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    sae, _, _ = SAE.from_pretrained(
        release=release,
        sae_id=sae_id,
        device=str(device),
    )

    # Cache locally so later runs take the fast path above.
    sae.save_model(sae_path)
    return sae

# Fetch the SAE (local cache first, sae_lens download otherwise) and move it
# onto the same device as the model.
sae = load_sae(config, SAE_PATH, SAE_LAYER, SAE_TRAINER).to(device)

print(
    f"SAE loaded with {sae.cfg.d_sae} features",
    f"SAE device: {next(iter(sae.parameters())).device}",
    sep="\n",
)

Loading SAE from /workspace/sae/gemma-2-9b/saes/resid_post_layer_20/trainer_131k-l0-114
✓ Found SAE files at: /workspace/sae/gemma-2-9b/saes/resid_post_layer_20/trainer_131k-l0-114
SAE loaded with 131072 features
SAE device: cuda:0


## Run inference on prompts
First, generate responses to the prompts without any intervention to get a baseline.
Then generate them again with the activation steerer applied.


In [7]:
def generate_text(model, tokenizer, prompt, max_new_tokens=100, temperature=0.7, do_sample=True):
    """Generate a chat-formatted completion for a single user prompt.

    Args:
        model: A causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer with a chat template.
        prompt: The user message.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; only used when ``do_sample`` is True.
        do_sample: Sample if True, otherwise decode greedily.

    Returns:
        The decoded continuation only (prompt tokens removed), with special
        tokens skipped and surrounding whitespace stripped.
    """
    formatted_prompt = format_as_chat(tokenizer, prompt)

    # The rendered chat template already contains its special tokens (including
    # BOS). Letting the tokenizer add its own would prepend a second BOS.
    inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,
    )
    # temperature only affects sampling; with greedy decoding transformers would
    # just warn that it is ignored.
    if do_sample:
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the new tokens.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

In [8]:
# Extract feature directions from SAE decoder
def get_feature_direction(sae, feature_id):
    """Return the unit-norm decoder direction of one SAE feature.

    Args:
        sae: SAE exposing ``cfg.d_sae`` and a decoder matrix ``W_dec`` with one
            row per feature, i.e. shape (d_sae, d_model).
        feature_id: Feature index in ``[0, sae.cfg.d_sae)``.

    Returns:
        A detached 1-D tensor of length d_model: the feature's decoder row
        scaled to unit L2 norm (common practice for steering).

    Raises:
        ValueError: if ``feature_id`` is negative or ``>= sae.cfg.d_sae``.
    """
    d_sae = sae.cfg.d_sae
    # Negative ids must be rejected explicitly: tensor indexing would silently
    # wrap them to features counted from the end.
    if feature_id < 0:
        raise ValueError(f"Feature ID {feature_id} must be non-negative")
    if feature_id >= d_sae:
        raise ValueError(f"Feature ID {feature_id} >= max features {d_sae}")

    # Detach so the direction does not carry the decoder parameter's autograd graph.
    feature_direction = sae.W_dec[feature_id, :].detach()

    # The epsilon guards against division by zero for an all-zero decoder row.
    return feature_direction / (feature_direction.norm() + 1e-8)

# Full SAE encode/decode ablation hook
class SAEFeatureAblationHook:
    """
    Forward hook that zeroes selected SAE features in a layer's output.

    The layer output is encoded with the SAE, the chosen features are set to
    zero, and the result is decoded. The reconstruction error of the unablated
    code is then added back. For an affine decoder, the net effect is that only
    the ablated features' decoded contribution is removed from the activations.

    Use as a context manager (``with hook: ...``) or call ``remove()`` explicitly.
    """

    def __init__(self, sae, feature_ids, layer_module):
        self.sae = sae
        # Accept a single id or a list of ids.
        self.feature_ids = feature_ids if isinstance(feature_ids, list) else [feature_ids]
        self.layer_module = layer_module
        self.handle = None  # torch hook handle while registered, else None

    def hook_fn(self, module, input, output):
        """Return the layer output with ``self.feature_ids`` ablated in SAE space.

        A tuple output has its first element replaced; other elements pass
        through unchanged. The output keeps the layer's original dtype.
        """
        activations = output[0] if isinstance(output, tuple) else output

        with torch.no_grad():
            original_dtype = activations.dtype
            # Run the SAE in float32 even when the model runs in lower precision.
            activations_float = activations.float()

            # Encode once, e.g. (batch, seq_len, d_sae) for 3-D activations.
            feature_acts = self.sae.encode(activations_float)

            # Error of the unablated reconstruction. Adding it back preserves
            # whatever the SAE fails to reconstruct.
            reconstruction_error = activations_float - self.sae.decode(feature_acts)

            # Ablate on a copy so feature_acts stays untouched.
            ablated_feature_acts = feature_acts.clone()
            ablated_feature_acts[..., self.feature_ids] = 0.0

            modified_activations = (
                self.sae.decode(ablated_feature_acts) + reconstruction_error
            ).to(original_dtype)

        if isinstance(output, tuple):
            return (modified_activations, *output[1:])
        return modified_activations

    def __enter__(self):
        """Register the hook on ``layer_module``; a no-op if already registered."""
        # Re-registering would stack a second hook whose handle is then lost,
        # so it could never be removed.
        if self.handle is None:
            self.handle = self.layer_module.register_forward_hook(self.hook_fn)
        return self

    def __exit__(self, *exc):
        """Remove the hook; exceptions raised in the with-body still propagate."""
        self.remove()

    def remove(self):
        """Remove the hook if registered. Safe to call repeatedly."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

# Helper function to create ablation hook
def create_sae_ablation_hook(sae, feature_ids, layer_index, target_model=None):
    """
    Build an (unregistered) SAEFeatureAblationHook on one layer of a model.

    Args:
        sae: SAE passed through to the hook.
        feature_ids: Feature id or list of ids to ablate.
        layer_index: Index into the model's layer list.
        target_model: Model to hook. Defaults to the notebook-global ``model``.

    Returns:
        SAEFeatureAblationHook for that layer; enter it with ``with`` to activate.

    Raises:
        ValueError: if none of the known layer-list attribute paths exists on
            the model as an indexable object.
    """
    root = target_model if target_model is not None else model

    # Attribute paths to the layer list for common architectures
    # (lookup logic adapted from ActivationSteering).
    layer_attrs = (
        "transformer.h",       # GPT‑2/Neo, Bloom, etc.
        "encoder.layer",       # BERT/RoBERTa
        "model.layers",        # Llama/Mistral
        "gpt_neox.layers",     # GPT‑NeoX
        "block",               # Flan‑T5
    )

    for path in layer_attrs:
        layers = root
        for part in path.split("."):
            layers = getattr(layers, part, None)
            if layers is None:
                break
        else:  # every attribute on the path resolved
            if hasattr(layers, "__getitem__"):
                return SAEFeatureAblationHook(sae, feature_ids, layers[layer_index])

    raise ValueError("Could not find layer list on the model")

# Extract directions for all target features
feature_directions = {}
for feature_id in TARGET_FEATURES:
    direction = get_feature_direction(sae, feature_id)
    feature_directions[feature_id] = direction
    print(f"Feature {feature_id}: direction shape {direction.shape}, norm {direction.norm():.4f}")

print(f"\nExtracted directions for {len(feature_directions)} features")

# Create ablation hook for precise feature ablation (active only inside `with`)
ablation_hook = create_sae_ablation_hook(sae, TARGET_FEATURES, SAE_LAYER)
print(f"Created SAE ablation hook for features: {TARGET_FEATURES} at layer {SAE_LAYER}")

# Check dtypes (the hooks cast activations to float32 for SAE math)
print("\nDtype info:")
print(f"  Model dtype: {next(model.parameters()).dtype}")
print(f"  SAE W_enc dtype: {sae.W_enc.dtype}")
print(f"  SAE W_dec dtype: {sae.W_dec.dtype}")

# Raise torch.compile (dynamo) recompilation limits so repeated generation
# with hooks attached does not hit the default limits.
import torch._dynamo.config as dynamo_config
if hasattr(dynamo_config, "recompile_limit"):
    # Newer torch: `cache_size_limit` is only a deprecated alias of
    # `recompile_limit`. Setting both let the second assignment silently
    # overwrite the first, so each limit is set exactly once here. 32 is the
    # value that previously took effect; 256 is presumably the intended
    # overall budget.
    dynamo_config.recompile_limit = 32
    dynamo_config.accumulated_recompile_limit = 256
else:
    # Older torch has no `recompile_limit`; assigning it raises AttributeError.
    dynamo_config.cache_size_limit = 32
    if hasattr(dynamo_config, "accumulated_cache_size_limit"):
        dynamo_config.accumulated_cache_size_limit = 256

Feature 45426: direction shape torch.Size([3584]), norm 1.0000

Extracted directions for 1 features
Created SAE ablation hook for features: [45426] at layer 20

Dtype info:
  Model dtype: torch.bfloat16
  SAE W_enc dtype: torch.float32
  SAE W_dec dtype: torch.float32


In [None]:
# Steering coefficients to sweep. The fuller sweep used before was
# [-20.0, -10.0, -5.0, 0.0, 5.0, 10.0, 20.0]; including 0.0 also collects an
# unsteered baseline.
STEERING_MAGNITUDES = [-20.0, +20.0]
N_RUNS_PER_PROMPT = 5        # generations per prompt and condition
STEERING_LAYER = SAE_LAYER   # steer at the layer the SAE reads from

def run_steering_experiment_optimized(feature_id, prompts, magnitudes=STEERING_MAGNITUDES, n_runs=N_RUNS_PER_PROMPT, do_steering=True, do_ablation=True, do_projection_zero_ablate=True):
    """
    Run steering experiment for a feature across all prompts with minimal recompilations.

    This version minimizes PyTorch recompilations by entering each intervention
    once and reusing it for every prompt:
    1. Running ablation once for all prompts
    2. Running projection zero ablation once for all prompts
    3. Running each steering magnitude once for all prompts

    Args:
        feature_id: The SAE feature ID to analyze (a key of ``feature_directions``)
        prompts: List of prompts to test
        magnitudes: List of steering magnitudes to test; 0.0 runs an unsteered baseline
        n_runs: Number of times to run each prompt (for variance estimation)
        do_steering: Whether to run steering experiments
        do_ablation: Whether to run SAE feature ablation experiments (global ``ablation_hook``)
        do_projection_zero_ablate: Whether to run projection-based zero ablation using ActivationSteering

    Returns:
        dict: prompt -> {"steering": {magnitude: [responses]},
                         "ablation": {"add_error" | "projection_zero_ablate": [responses]}}.
        If a whole condition fails, every prompt that had not completed under
        it gets ``n_runs`` copies of the error message; completed prompts keep
        their responses.
    """
    from contextlib import nullcontext
    from functools import partial

    print(f"\n{'='*60}")
    print(f"FEATURE: {feature_id}")
    print(f"N_RUNS: {n_runs}")
    print(f"{'='*60}")

    feature_direction = feature_directions[feature_id]
    results = {prompt: {"steering": {}, "ablation": {}} for prompt in prompts}

    def make_steerer(coefficient, intervention_type):
        """Build an ActivationSteering context for this feature's direction."""
        return ActivationSteering(
            model=model,
            steering_vectors=feature_direction,
            coefficients=coefficient,
            layer_indices=STEERING_LAYER,
            intervention_type=intervention_type,
            positions="all",
            debug=False,
        )

    if do_ablation:
        _print_section("SAE FEATURE ABLATION - ALL PROMPTS")
        _run_condition(
            results, prompts, n_runs,
            make_context=lambda: ablation_hook,
            section="ablation", key="add_error", label="ABLATED",
            error_prefix="Error with SAE ablation",
        )

    if do_projection_zero_ablate:
        _print_section("PROJECTION ZERO ABLATION - ALL PROMPTS")
        _run_condition(
            results, prompts, n_runs,
            # Zero coefficient for pure ablation
            make_context=partial(make_steerer, 0.0, "ablation"),
            section="ablation", key="projection_zero_ablate", label="PROJECTION ABLATED",
            error_prefix="Error with projection zero ablation",
        )

    if do_steering:
        _print_section("STEERING EXPERIMENTS - ALL PROMPTS")
        for magnitude in magnitudes:
            print(f"\n{'='*20} Magnitude: {magnitude:+.1f} {'='*20}")
            if magnitude == 0.0:
                # Baseline: no hook at all.
                make_context, label, run_error_prefix = nullcontext, "BASELINE", "Error with baseline"
            else:
                make_context = partial(make_steerer, magnitude, "addition")
                label, run_error_prefix = "STEERED", "Error generating with steering"
            _run_condition(
                results, prompts, n_runs, make_context,
                section="steering", key=magnitude, label=label,
                error_prefix=f"Error with magnitude {magnitude}",
                run_error_prefix=run_error_prefix,
            )

    return results


def _print_section(title):
    """Print a section title followed by a dashed rule."""
    print(f"\n{title}")
    print("-" * 40)


def _generate_runs(prompt, n_runs, label, run_error_prefix=None):
    """Generate ``n_runs`` responses to ``prompt``, printing each one under ``label``.

    If ``run_error_prefix`` is given, a failed generation is recorded as an
    error string and the remaining runs continue. Otherwise the exception
    propagates to the caller.
    """
    indent = "" if n_runs == 1 else "  "
    responses = []
    for run_idx in range(n_runs):
        if n_runs > 1:
            print(f"  Run {run_idx + 1}/{n_runs}")
        try:
            response = generate_text(model, tokenizer, prompt)
        except Exception as e:
            if run_error_prefix is None:
                raise
            error_msg = f"{run_error_prefix}: {str(e)}"
            responses.append(error_msg)
            print(f"ERROR: {error_msg}")
            continue
        responses.append(response)
        print(f"{indent}{label}: {response}")
    return responses


def _run_condition(results, prompts, n_runs, make_context, section, key, label,
                   error_prefix, run_error_prefix=None):
    """Run every prompt inside ``make_context()`` and store the responses at
    ``results[prompt][section][key]``.

    Any exception escaping the context, including one from constructing or
    exiting it, is reported once. Prompts not yet completed then receive
    ``n_runs`` copies of the error message.
    """
    try:
        with make_context():
            for prompt in prompts:
                print(f"\nPrompt: {prompt}")
                results[prompt][section][key] = _generate_runs(prompt, n_runs, label, run_error_prefix)
    except Exception as e:
        error_msg = f"{error_prefix}: {str(e)}"
        print(f"ERROR: {error_msg}")
        for prompt in prompts:
            # Keep responses already generated for prompts that finished.
            results[prompt][section].setdefault(key, [error_msg] * n_runs)

# Steering-only pass over every target feature (both ablation variants off).
all_results = {}
for feature_id in TARGET_FEATURES:
    all_results[feature_id] = feature_results = run_steering_experiment_optimized(
        feature_id,
        prompts,
        n_runs=N_RUNS_PER_PROMPT,
        do_ablation=False,
        do_projection_zero_ablate=False,
    )

print(f"\n{'='*60}\nOPTIMIZED STEERING + ABLATION EXPERIMENTS COMPLETE\n{'='*60}")


FEATURE: 45426
N_RUNS: 5

SAE FEATURE ABLATION - ALL PROMPTS
----------------------------------------

Prompt: Who are you?
  Run 1/5
  ABLATED: I am Gemma, an open-weights AI assistant. I am a large language model trained by Google DeepMind. This means I can understand and generate human-like text, engage in conversations, answer your questions to the best of my ability based on my training data, and even help you with creative tasks like writing stories or poems.

Because my weights are open, anyone can access and use me, which allows for transparency and collaboration in the AI community.

Is there anything you'd like
  Run 2/5
  ABLATED: I am Gemma, an open-weights AI assistant. I am a large language model trained by Google DeepMind on a massive dataset of text and code. My purpose is to help users by understanding and responding to their requests in a helpful, informative, and comprehensive way.

Here are some key things to know about me:

* **Open-Weights:** My weights are publi

In [11]:
def save_results_to_json(results, output_dir, metadata=None):
    """Save steering and ablation results to separate JSON files per feature.

    Each feature goes to ``<output_dir>/<feature_id>.json``. New responses are
    appended to any lists already stored there, so repeated runs accumulate.

    Args:
        results: feature_id -> prompt -> {"steering": {magnitude: responses},
            "ablation": {ablation_type: responses}}; responses may be a list
            or a single value.
        output_dir: Directory for the per-feature files (created if missing).
        metadata: Metadata stored in newly created files. Defaults to the
            notebook's MODEL_NAME / MODEL_TYPE / SAE_LAYER / SAE_TRAINER.

    Returns:
        List of the feature ids that were written.

    An existing file that cannot be parsed is moved aside to
    ``<file>.corrupt-<ns>`` instead of being overwritten. Each file is written
    to a temporary path and then renamed into place.
    """
    import time

    os.makedirs(output_dir, exist_ok=True)
    if metadata is None:
        metadata = {
            "model_name": MODEL_NAME,
            "model_type": MODEL_TYPE,
            "sae_layer": SAE_LAYER,
            "sae_trainer": SAE_TRAINER,
        }

    saved_features = []

    for feature_id, feature_results in results.items():
        output_path = os.path.join(output_dir, f"{feature_id}.json")

        feature_obj = None
        if os.path.exists(output_path):
            try:
                with open(output_path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                if not isinstance(loaded, dict):
                    raise ValueError("top-level JSON value is not an object")
                feature_obj = loaded
                print(f"📂 Loaded existing data for feature {feature_id}")
            except Exception as e:
                print(f"⚠️  Error loading existing file for feature {feature_id}: {e}")
                # Preserve the unreadable file rather than overwriting accumulated results.
                backup_path = f"{output_path}.corrupt-{time.time_ns()}"
                os.replace(output_path, backup_path)
                print(f"   Moved unreadable file to {backup_path}")

        if feature_obj is None:
            feature_obj = {
                "feature_id": feature_id,
                "metadata": dict(metadata),
                "results": {},
            }
            print(f"🆕 Creating new file for feature {feature_id}")

        # Merge prompt results; setdefault also repairs older entries lacking a section.
        stored_prompts = feature_obj.setdefault("results", {})
        for prompt, prompt_results in feature_results.items():
            entry = stored_prompts.setdefault(prompt, {})
            steering_bucket = entry.setdefault("steering", {})
            ablation_bucket = entry.setdefault("ablation", {})

            for magnitude, new_responses in prompt_results.get("steering", {}).items():
                # JSON object keys are strings, so magnitudes are stored as e.g. "20.0".
                _merge_responses(steering_bucket, str(magnitude), new_responses)

            for ablation_type, new_responses in prompt_results.get("ablation", {}).items():
                _merge_responses(ablation_bucket, ablation_type, new_responses)

        # Write then rename, so a crash mid-write cannot truncate the existing file.
        tmp_path = f"{output_path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(feature_obj, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, output_path)

        saved_features.append(feature_id)
        print(f"💾 Saved feature {feature_id} to {output_path}")

    return saved_features


def _merge_responses(bucket, key, new_responses):
    """Append ``new_responses`` (a list or a single response) to ``bucket[key]``.

    A legacy single-response entry is first upgraded to a list.
    """
    existing = bucket.get(key, [])
    if not isinstance(existing, list):
        existing = [existing]  # backward compatibility with single-response files
    if isinstance(new_responses, list):
        existing.extend(new_responses)
    else:
        existing.append(new_responses)
    bucket[key] = existing

# Persist this run's responses, merging them into the per-feature JSON files.
saved_features = save_results_to_json(all_results, output_dir=OUTPUT_DIR)


📂 Loaded existing data for feature 45426
💾 Saved feature 45426 to ./results/7_steering/gemma_trainer131k-l0-114_layer20/45426.json
