# Example-Level Steering Analysis

This notebook shows, for individual text examples, how the model's next-token probability of spatial versus descriptive target words shifts when the steering vector is applied with different coefficients.

In [None]:
import sys
import json
import warnings
from pathlib import Path
from typing import List, Dict
import numpy as np
import torch
import torch.nn.functional as F
import plotly.graph_objects as go
import plotly.express as px

sys.path.append("../")
from bias_steering.data.load_dataset import load_dataframe_from_json, load_target_words
from bias_steering.steering.model import load_model
from bias_steering.steering.intervention import get_intervention_func
from bias_steering.steering.steering_utils import get_target_token_ids
from bias_steering.utils import loop_coeffs

# Reload edited project modules (e.g. bias_steering) automatically before each
# cell runs, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# NOTE(review): this silences *all* warnings for the session, including numpy
# RuntimeWarnings such as divide-by-zero; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
# Trace colour for each target-word class, keyed by the legend name used in the plots:
# spatial -> orange, descriptive -> blue.
COLORS = dict(
    spatial="#FF7F0E",
    descriptive="#1F77B4",
)

## Load Model and Steering Vector

In [None]:
# Configuration - adjust paths as needed
model_name = "gpt2"
artifact_path = Path("../runs_vision/gpt2")

# Load model
model = load_model(model_name)
print(f"Loaded model: {model_name}")

# Steering layer = first entry of validation/top_layers.json (presumably the
# best-ranked layer -- confirm against the script that writes this file).
# Use a context manager so the file handle is closed deterministically.
with open(artifact_path / "validation/top_layers.json", "r", encoding="utf-8") as f:
    top_layers = json.load(f)
layer = top_layers[0]["layer"]
print(f"Using layer: {layer}")

# Candidate steering vectors are indexed by layer; cast to the model's dtype.
candidate_vectors = torch.load(artifact_path / "activations/candidate_vectors.pt", weights_only=True)
steering_vec = model.set_dtype(candidate_vectors[layer])

# Offset: mean (over dim 0) of the stored neutral activations at the same layer.
offset = torch.load(artifact_path / "activations/neutral.pt", weights_only=True)[layer].mean(dim=0)
offset = model.set_dtype(offset)

# Target-word token ids. "pos" is the spatial class and "neg" the descriptive class;
# the probability and plotting helpers rely on this mapping.
target_words_by_label = load_target_words(target_concept="vision")
target_token_ids = {
    "pos": get_target_token_ids(model.tokenizer, target_words_by_label["spatial"]),
    "neg": get_target_token_ids(model.tokenizer, target_words_by_label["descriptive"]),
}

print(f"✓ Steering vector loaded (norm: {torch.norm(steering_vec).item():.4f})")

## Load Example Data

In [None]:
# Load validation data
val_data = load_dataframe_from_json(artifact_path / "datasplits/val.json")

print(f"Total validation examples: {len(val_data)}")
print("\nSample examples:")
print(val_data[["text", "vision_label", "bias"]].head(10).to_string())

# Number of examples taken from each group below
N_PER_GROUP = 3

# Group 1: examples with strong descriptive bias (negative)
strong_descriptive = val_data[val_data["bias"] < -0.03].head(N_PER_GROUP)

# Group 2: examples with spatial bias (positive; note the looser 0.01 threshold).
# Fall back to the most spatial-leaning examples if none clear the threshold.
spatial_mask = val_data["bias"] > 0.01
if spatial_mask.any():
    strong_spatial = val_data[spatial_mask].head(N_PER_GROUP)
else:
    strong_spatial = val_data.nlargest(N_PER_GROUP, "bias")

# Group 3: near-neutral examples. The fallback must pick the examples with the
# smallest |bias|. Using the smallest *signed* bias would pick the most
# descriptive-leaning examples, which are the opposite of neutral.
abs_bias = val_data["bias"].abs()
neutral_mask = abs_bias < 0.01
if neutral_mask.any():
    neutral_examples = val_data[neutral_mask].head(N_PER_GROUP)
else:
    # Positional selection avoids problems with non-unique index labels; NaN |bias| sorts last
    closest_to_zero = np.argsort(abs_bias.to_numpy(), kind="stable")[:N_PER_GROUP]
    neutral_examples = val_data.iloc[closest_to_zero]

# Combine groups in order: descriptive, spatial, neutral (empty groups add nothing)
examples_to_plot = []
for group in (strong_descriptive, strong_spatial, neutral_examples):
    examples_to_plot.extend(group.to_dict("records"))

print(f"\n✓ Selected {len(examples_to_plot)} examples for analysis")

## Functions to Compute Probabilities with Steering

In [None]:
import torch
import torch.nn.functional as F

def get_probs_with_steering(prompt, output_prefix, steering_vec, layer, coeff, target_token_ids, offset=0):
    """Return the (spatial, descriptive) probability mass for one steered prompt.

    The prompt is chat-formatted with the notebook-global ``model`` and run with a
    steering intervention at ``layer`` scaled by ``coeff``. The softmax over the
    final position's logits is then summed over the "pos" (spatial) and "neg"
    (descriptive) token ids.

    Returns:
        Tuple of two Python floats: (spatial mass, descriptive mass).
    """
    chat_prompt = model.apply_chat_template([prompt], output_prefix=[output_prefix])[0]

    steer = get_intervention_func(steering_vec, method="default", coeff=coeff, offset=offset)
    steered_logits = model.get_logits([chat_prompt], layer=layer, intervene_func=steer)

    # Next-token distribution at the last position of the (single) prompt
    next_token_probs = F.softmax(steered_logits[:, -1, :], dim=-1)[0]

    spatial_mass, descriptive_mass = (
        next_token_probs[target_token_ids[side]].sum().item() for side in ("pos", "neg")
    )
    return spatial_mass, descriptive_mass

def get_steering_curve_for_example(example, coeffs, steering_vec, layer, target_token_ids, offset):
    """Sweep ``coeffs`` for one example and return the renormalised probability curves.

    For every coefficient, ``get_probs_with_steering`` yields the raw spatial ("pos")
    and descriptive ("neg") probability mass. Each pair is then rescaled so that the
    two values sum to 1.

    Returns:
        (spatial, descriptive) float arrays with one entry per coefficient. An entry
        is NaN when both raw masses are 0 (0/0 division).
    """
    # One (spatial, descriptive) row per coefficient; reshape keeps an empty sweep 2-D
    raw = np.array(
        [
            get_probs_with_steering(
                example["prompt"],
                example["output_prefix"],
                steering_vec, layer, coeff, target_token_ids, offset,
            )
            for coeff in coeffs
        ],
        dtype=float,
    ).reshape(-1, 2)

    spatial_raw, descriptive_raw = raw.T
    pair_total = spatial_raw + descriptive_raw
    return spatial_raw / pair_total, descriptive_raw / pair_total

## Plot Steering Curves for Individual Examples

In [None]:
def plot_example_steering(coeffs, pos_probs, neg_probs, text, width=450, height=300, title_text=None):
    """Build a Plotly line chart of spatial vs. descriptive probability against λ.

    Args:
        coeffs: Steering coefficients for the x-axis.
        pos_probs: Spatial probabilities, one per coefficient.
        neg_probs: Descriptive probabilities, one per coefficient.
        text: Example text, used as the title (truncated to at most 60
            characters) when ``title_text`` is empty or None.
        width: Figure width in pixels.
        height: Figure height in pixels.
        title_text: Optional explicit title.

    Returns:
        The ``go.Figure``; the caller decides whether to show it.
    """
    fig = go.Figure()
    for series_name, series in (("spatial", pos_probs), ("descriptive", neg_probs)):
        fig.add_trace(go.Scatter(
            x=coeffs,
            y=series,
            mode='lines+markers',
            name=series_name,
            marker_color=COLORS[series_name],
            showlegend=True,
        ))

    # Keep long example texts from overflowing the title
    short_text = text[:57] + "..." if len(text) >= 60 else text
    heading = title_text if title_text else short_text

    fig.update_layout(
        width=width,
        height=height,
        plot_bgcolor='white',
        margin=dict(l=10, r=10, t=60, b=25),
        font=dict(size=13),
        title_text=heading,
        title_font=dict(size=14),
        title_x=0.5,
        title_y=0.98,
        legend=dict(
            yanchor="top", y=0.98, xanchor="left", x=0.02,
            bordercolor="darkgrey", borderwidth=1, font=dict(size=12),
        ),
    )

    # Styling common to both axes; axis-specific settings are added per call
    shared_axis_style = dict(
        mirror=True,
        showgrid=True,
        gridcolor='darkgrey',
        zeroline=True,
        title_font=dict(size=13),
        tickfont=dict(size=11),
        showline=True,
        linewidth=1,
        linecolor='darkgrey',
    )
    fig.update_xaxes(
        **shared_axis_style,
        zerolinecolor='black',
        title_text="Steering Coefficient (λ)",
        title_standoff=1,
        nticks=8,
    )
    fig.update_yaxes(
        **shared_axis_style,
        zerolinecolor='darkgrey',
        title_text="Probability",
        title_standoff=2,
        range=[0, 1],
    )
    return fig

# λ values to sweep: -40 to +40 in steps of 20 (endpoint handling is up to loop_coeffs)
coeffs = loop_coeffs(
    min_coeff=-40,
    max_coeff=40,
    increment=20,
)

num_examples = len(examples_to_plot)
print(f"Computing steering curves for {num_examples} examples...")
print(f"Using coefficients: {coeffs}")

In [None]:
# Plot each selected example, capped at MAX_PLOTS for readability.
# examples_to_plot is ordered descriptive -> spatial -> neutral, so the cap can
# drop the last group entirely; report that instead of hiding it.
MAX_PLOTS = 6
if len(examples_to_plot) > MAX_PLOTS:
    print(f"Plotting the first {MAX_PLOTS} of {len(examples_to_plot)} selected examples")

for i, example in enumerate(examples_to_plot[:MAX_PLOTS]):
    pos_probs, neg_probs = get_steering_curve_for_example(
        example, coeffs, steering_vec, layer, target_token_ids, offset
    )

    text = example["text"]
    short_text = text[:50] + ("..." if len(text) > 50 else "")
    fig = plot_example_steering(
        coeffs, pos_probs, neg_probs,
        text=text,
        title_text=f'Example {i+1}: {short_text}',
    )

    # Show the label and the example's bias value near the bottom of the plot area.
    # Guard against a missing or None bias, which would break the :.3f format spec.
    label = example.get("vision_label", "unknown")
    bias = example.get("bias")
    bias_str = f"{bias:.3f}" if bias is not None else "n/a"
    fig.add_annotation(
        text=f'Label: {label} | Baseline bias: {bias_str}',
        xref="paper", yref="paper",
        x=0.5, y=0.02, showarrow=False,
        font=dict(size=11, color="gray")
    )

    fig.show()