In [1]:
import os
from huggingface_hub import login
from dotenv import load_dotenv 

import os

# --- Load ALL Configurations from .env file ---
# This single line reads your .env file and sets up ALL environment variables
# for this session (secrets, paths, etc.).
# It must be run BEFORE any library that needs these variables is used.
load_dotenv()
print("Environment variables from .env file loaded.")

# --- Hugging Face Login (No changes needed here) ---
# This code correctly reads the "HF_TOKEN" that was just loaded by load_dotenv()
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("Successfully logged into Hugging Face.")
    else:
        print("Hugging Face token not found. Skipping login.")
except Exception as e:
    print(f"Could not log into Hugging Face: {e}")

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Environment variables from .env file loaded.
Successfully logged into Hugging Face.


In [3]:
import os
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from huggingface_hub import login
import difflib
import textwrap

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# ==============================================================================
# SECTION 1: SETUP AND CONFIGURATION
# ==============================================================================
print("--- SECTION 1: CONFIGURATION ---")

import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
import os
import re
import difflib

# --- Core Configuration (UPDATED FOR GEMMA-2B) ---
MODEL_NAME = "google/gemma-2-2b"
# Gemma-2B has 26 layers (0-25). Layer 20 is a good late-layer choice.
TARGET_LAYER_IDX = 20
# The layer name format for Gemma is 'model.layers.{index}'.
TARGET_LAYER_NAME = f"model.layers.{TARGET_LAYER_IDX}"

# --- SAE Configuration (UPDATED FOR GEMMA-SCOPE) ---
# This is the official release for Gemma-2B base model SAEs.
SAE_RELEASE_NAME = "gemma-scope-2b-pt-res-canonical"
# The SAE_ID format for this release.
SAE_ID = f"layer_{TARGET_LAYER_IDX}/width_16k/canonical"

# --- Experiment Configuration ---
TARGET_PERSONALITY = "extraversion" 
NUM_PROMPTS_TO_COLLECT = 100 
TOP_K_FEATURES = 10
STEERING_STRENGTH = 5.0 # Steering strength often needs to be higher for smaller models

# --- File Paths ---
# Assumes this script is run from a subfolder (e.g., 'mech_interp')
# and the data files are in the parent directory.
PERSONALITY_DATA_FILE = "../personality_data_train.csv"
BASE_QUESTIONS_FILE = "../bbq_ambiguous_with_metadata.csv"
os.makedirs("activations", exist_ok=True)
ACTIVATION_SAVE_PATH = "activations"

--- SECTION 1: CONFIGURATION ---


In [6]:
# ==============================================================================
# SECTION 2: MODEL AND DATA LOADING
# ==============================================================================
print("\n--- SECTION 2: LOADING MODELS AND DATA ---")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, 
        torch_dtype=torch.bfloat16, 
        device_map="auto"
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Successfully loaded model: {MODEL_NAME}")
except Exception as e:
    print(f"FATAL: Could not load the model. Error: {e}")
    exit()

try:
    df_traits = pd.read_csv(PERSONALITY_DATA_FILE)
    df_traits.columns = [col.strip().lower() for col in df_traits.columns]
    personality_few_shot_examples = {}
    for trait in df_traits['target personality'].unique():
        trait_df = df_traits[df_traits['target personality'] == trait]
        personality_few_shot_examples[trait] = list(zip(trait_df['question'], trait_df['answer']))[:4]

    df_bbq = pd.read_csv(BASE_QUESTIONS_FILE)
    base_questions = [f"{row['context']} {row['question']}" for _, row in df_bbq.head(NUM_PROMPTS_TO_COLLECT).iterrows()]
    print("Successfully loaded and prepared all data.")
except FileNotFoundError as e:
    print(f"FATAL: A required data file was not found. Please check your file paths. Error: {e}")
    exit()


--- SECTION 2: LOADING MODELS AND DATA ---


Loading checkpoint shards: 100%|██████████| 3/3 [01:31<00:00, 30.38s/it]


Successfully loaded model: google/gemma-2-2b
Successfully loaded and prepared all data.


In [7]:
# ==============================================================================
# SECTION 3: HELPER FUNCTIONS (SELF-CONTAINED)
# ==============================================================================
print("\n--- SECTION 3: DEFINING HELPER FUNCTIONS ---")

def create_few_shot_prompt_string(target_personality, few_shot_examples, current_question):
    system_content = f"You are an AI with a {target_personality} personality. Respond strictly in character."
    examples_str = f"\n\nHere are some examples of how a {target_personality} AI would typically respond:"
    for i, (ex_q, ex_a) in enumerate(few_shot_examples):
        examples_str += f"\n\nExample {i+1} Question: {ex_q}\nExample {i+1} Answer: {ex_a}"
    full_prompt = f"{system_content}{examples_str}\n\n---\nNow, answer the following question as a {target_personality} AI:\nQuestion: {current_question}\nAnswer:"
    return full_prompt

def create_neutral_prompt(current_question):
    system_instruction = (
        "You are a neutral and objective AI assistant. Your task is to answer the question "
        "directly and factually, without expressing any personality, emotion, or opinion."
    )
    return f"{system_instruction}\n\nQuestion: {current_question}\nAnswer:"

captured_activations = []
def hook_function(module, input, output):
    captured_activations.append(output[0].detach().cpu())


--- SECTION 3: DEFINING HELPER FUNCTIONS ---


In [8]:
# ==============================================================================
# SECTION 4: THE CORE INTERPRETABILITY WORKFLOW
# ==============================================================================
print("\n--- SECTION 4.1: COLLECTING ACTIVATIONS ---")

layer_path = TARGET_LAYER_NAME.split('.')
target_layer_module = model
for part in layer_path:
    if part.endswith(']'):
        base, index = part.split('[')
        target_layer_module = getattr(target_layer_module, base)[int(index[:-1])]
    else:
        target_layer_module = getattr(target_layer_module, part)
        
hook = target_layer_module.register_forward_hook(hook_function)
print(f"Hook attached to layer: {TARGET_LAYER_NAME}")

print(f"\nCollecting activations for '{TARGET_PERSONALITY.upper()}' prompts...")
captured_activations = [] 
current_examples = personality_few_shot_examples.get(TARGET_PERSONALITY, [])
for question in base_questions:
    prompt = create_few_shot_prompt_string(TARGET_PERSONALITY, current_examples, question)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        model(**inputs)

personality_activations_path = os.path.join(ACTIVATION_SAVE_PATH, f'gemma2_{TARGET_PERSONALITY}_activations_layer{TARGET_LAYER_IDX}.pt')
torch.save(captured_activations, personality_activations_path)
print(f"Saved {len(captured_activations)} activation tensors to {personality_activations_path}")

print(f"\nCollecting activations for 'NEUTRAL' prompts...")
captured_activations = []
for question in base_questions:
    prompt = create_neutral_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        model(**inputs)

neutral_activations_path = os.path.join(ACTIVATION_SAVE_PATH, f'gemma2_neutral_activations_layer{TARGET_LAYER_IDX}.pt')
torch.save(captured_activations, neutral_activations_path)
print(f"Saved {len(captured_activations)} activation tensors to {neutral_activations_path}")

hook.remove()
print("\nActivation collection complete. Hook removed.")


--- SECTION 4.1: COLLECTING ACTIVATIONS ---
Hook attached to layer: model.layers.20

Collecting activations for 'EXTRAVERSION' prompts...


Saved 100 activation tensors to activations/gemma2_extraversion_activations_layer20.pt

Collecting activations for 'NEUTRAL' prompts...
Saved 100 activation tensors to activations/gemma2_neutral_activations_layer20.pt

Activation collection complete. Hook removed.


In [8]:
# ------------------------------------------------------------------------------
print("\n--- SECTION 4.2: FINDING DIFFERENTIATING FEATURES WITH SAE ---")

try:
    personality_activations_path = os.path.join(ACTIVATION_SAVE_PATH, f'gemma2_{TARGET_PERSONALITY}_activations_layer{TARGET_LAYER_IDX}.pt')
    neutral_activations_path = os.path.join(ACTIVATION_SAVE_PATH, f'gemma2_neutral_activations_layer{TARGET_LAYER_IDX}.pt')
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=SAE_RELEASE_NAME,
        sae_id=SAE_ID,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )
    print(f"Successfully loaded SAE: {SAE_ID} from release {SAE_RELEASE_NAME}")

    personality_acts_list = torch.load(personality_activations_path)
    neutral_acts_list = torch.load(neutral_activations_path)
    print(f"Loaded {len(personality_acts_list)} personality and {len(neutral_acts_list)} neutral activation tensors.")

    personality_acts_tensor = torch.stack([act[0, -1, :] for act in personality_acts_list]).to(sae.device)
    neutral_acts_tensor = torch.stack([act[0, -1, :] for act in neutral_acts_list]).to(sae.device)

    with torch.no_grad():
        personality_feature_acts = sae.encode(personality_acts_tensor)
        neutral_feature_acts = sae.encode(neutral_acts_tensor)
        
    mean_personality_acts = personality_feature_acts.mean(dim=0)
    mean_neutral_acts = neutral_feature_acts.mean(dim=0)
    diff_scores = mean_personality_acts - mean_neutral_acts

    top_feature_indices = torch.topk(diff_scores, TOP_K_FEATURES).indices
    print(f"\nTop {TOP_K_FEATURES} candidate features for '{TARGET_PERSONALITY.upper()}' vs Neutral:")
    print(top_feature_indices.tolist())

except Exception as e:
    print(f"An error occurred during SAE analysis. Please check your setup. Error: {e}")


--- SECTION 4.2: FINDING DIFFERENTIATING FEATURES WITH SAE ---
Successfully loaded SAE: layer_20/width_16k/canonical from release gemma-scope-2b-pt-res-canonical
Loaded 100 personality and 100 neutral activation tensors.

Top 10 candidate features for 'EXTRAVERSION' vs Neutral:
[8573, 11133, 5465, 3992, 14881, 13570, 8094, 1720, 552, 13671]


In [18]:
# ==============================================================================
# SECTION 4.2 (REVISED v2): FINDING FEATURES WITH A PROPERLY VALIDATED PROBE
# ==============================================================================
print("\n--- SECTION 4.2: FINDING DIFFERENTIATING FEATURES WITH A PROPERLY VALIDATED PROBE ---")

# We need these additional libraries for the probe and data splitting
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split # <-- The new import
import numpy as np
import traceback

try:
    # --- 1. Load SAE and Activation Data (Same as before) ---
    personality_activations_path = os.path.join(ACTIVATION_SAVE_PATH, f'gemma2_{TARGET_PERSONALITY}_activations_layer{TARGET_LAYER_IDX}.pt')
    neutral_activations_path = os.path.join(ACTIVATION_SAVE_PATH, f'gemma2_neutral_activations_layer{TARGET_LAYER_IDX}.pt')
    
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=SAE_RELEASE_NAME,
        sae_id=SAE_ID,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )
    print(f"Successfully loaded SAE: {SAE_ID} from release {SAE_RELEASE_NAME}")

    personality_acts_list = torch.load(personality_activations_path, map_location=sae.device)
    neutral_acts_list = torch.load(neutral_activations_path, map_location=sae.device)
    print(f"Loaded {len(personality_acts_list)} personality and {len(neutral_acts_list)} neutral activation tensors.")

    # --- 2. Prepare Data for the Probe (Same as before) ---
    personality_acts_flat = torch.cat([p.squeeze(0) for p in personality_acts_list], dim=0)
    neutral_acts_flat = torch.cat([n.squeeze(0) for n in neutral_acts_list], dim=0)
    
    print(f"Created a dataset of {personality_acts_flat.shape[0]} personality tokens and {neutral_acts_flat.shape[0]} neutral tokens.")

    # --- 3. Get SAE Feature Activations for All Tokens (Same as before) ---
    with torch.no_grad():
        personality_feature_acts = sae.encode(personality_acts_flat)
        neutral_feature_acts = sae.encode(neutral_acts_flat)

    # --- 4. Create and SPLIT the Dataset (THE CRITICAL NEW STEP) ---
    print("\nSplitting data into training and testing sets...")
    
    # Create the full dataset
    X = torch.cat([personality_feature_acts, neutral_feature_acts], dim=0).cpu().numpy()
    y = np.array(
        [1] * personality_feature_acts.shape[0] + 
        [0] * neutral_feature_acts.shape[0]
    )

    # Split into 80% for training, 20% for testing.
    # `stratify=y` ensures the proportion of personality/neutral tokens is the same in both sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f"Training set size: {len(X_train)} samples")
    print(f"Testing set size:  {len(X_test)} samples")

    # --- 5. Train the Logistic Regression Probe ---
    print("\nTraining a probe on the training data...")
    probe = LogisticRegression(class_weight='balanced', solver='liblinear', random_state=42)
    
    # Train ONLY on the training data
    probe.fit(X_train, y_train)
    
    # --- 6. Evaluate the Probe's Performance ---
    train_accuracy = probe.score(X_train, y_train)
    test_accuracy = probe.score(X_test, y_test)
    
    print(f"  -> Training Accuracy: {train_accuracy:.2%}")
    print(f"  -> Testing Accuracy:  {test_accuracy:.2%} (This is the important one!)")

    if test_accuracy < 0.6:
        print("WARNING: Test accuracy is low. The learned signal may not generalize well.")

    # --- 7. Identify Top Features from the Trained Probe (Same as before) ---
    probe_weights = probe.coef_.squeeze()
    top_feature_indices = np.argsort(probe_weights)[-TOP_K_FEATURES:]

    print("\n" + "="*50)
    print(f"ANALYSIS COMPLETE: Top {TOP_K_FEATURES} candidate features for '{TARGET_PERSONALITY.upper()}' (found via probe):")
    # Reverse the list to show the highest weight (most important) feature first
    print(top_feature_indices[::-1].tolist())
    print("="*50)

except Exception as e:
    print(f"An error occurred during SAE analysis. Please check your setup. Error: {e}")
    traceback.print_exc()


--- SECTION 4.2: FINDING DIFFERENTIATING FEATURES WITH A PROPERLY VALIDATED PROBE ---
Successfully loaded SAE: layer_20/width_16k/canonical from release gemma-scope-2b-pt-res-canonical
Loaded 100 personality and 100 neutral activation tensors.
Created a dataset of 31215 personality tokens and 6915 neutral tokens.

Splitting data into training and testing sets...
Training set size: 30504 samples
Testing set size:  7626 samples

Training a probe on the training data...
  -> Training Accuracy: 99.20%
  -> Testing Accuracy:  99.25% (This is the important one!)

ANALYSIS COMPLETE: Top 10 candidate features for 'EXTRAVERSION' (found via probe):
[3645, 8366, 2351, 8573, 887, 9968, 14491, 6143, 53, 10883]


In [15]:
# ------------------------------------------------------------------------------
print("\n--- SECTION 4.4: DEMONSTRATING MANUAL STEERING ---")

STEERING_STRENGTH = 2.5 # Steering strength often needs to be higher for smaller models

layer_path = TARGET_LAYER_NAME.split('.')
target_layer_module = model
for part in layer_path:
    if part.endswith(']'):
        base, index = part.split('[')
        target_layer_module = getattr(target_layer_module, base)[int(index[:-1])]
    else:
        target_layer_module = getattr(target_layer_module, part)

avg_personality_acts = torch.stack([act[0, -1, :] for act in personality_acts_list]).mean(dim=0)
avg_neutral_acts = torch.stack([act[0, -1, :] for act in neutral_acts_list]).mean(dim=0)
steering_vector = (avg_personality_acts - avg_neutral_acts).to(model.device)

def steering_hook_function(module, input, output):
    output[0][:, -1, :] += steering_vector * STEERING_STRENGTH
    return output

test_prompt = "My plan for the weekend is"
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

print(f"\n--- Generating with STEERING (Strength: {STEERING_STRENGTH}) ---")
steering_hook_handle = target_layer_module.register_forward_hook(steering_hook_function)
with torch.no_grad():
    steered_output = model.generate(**inputs, max_new_tokens=60, repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id)
steering_hook_handle.remove()
print(f"Steered Output: {tokenizer.decode(steered_output[0], skip_special_tokens=True)}")

print("\n--- Generating VANILLA (for comparison) ---")
with torch.no_grad():
    vanilla_output = model.generate(**inputs, max_new_tokens=60, repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id)
print(f"Vanilla Output: {tokenizer.decode(vanilla_output[0], skip_special_tokens=True)}")

print("\n\n--- End of Experiment ---")


--- SECTION 4.4: DEMONSTRATING MANUAL STEERING ---

--- Generating with STEERING (Strength: 2.5) ---
Steered Output: My plan for the weekend is to get out and explore some of our local trails. I've been itching to hit up a few new ones, but with my ever-growing list of projects (and an impending move), it has always fallen behind in priority as me being constantly on the go!

I have had such a

--- Generating VANILLA (for comparison) ---
Vanilla Output: My plan for the weekend is to go out and get some fresh air. I’m not sure where yet, but it will be somewhere in nature that has a lot of trees or grass so we can have fun playing outside!

I love going on walks with my family because they are always full of laughter and joy – even


--- End of Experiment ---


In [27]:
# ==============================================================================
# SECTION 4.4 (REVISED): WEIGHTED STEERING WITH PROBE WEIGHTS
# ==============================================================================
print("\n--- SECTION 4.4: DEMONSTRATING WEIGHTED MANUAL STEERING ---")

# This code block assumes that `probe_weights` (a numpy array) from the 
# extraversion-vs-neutral probe analysis in Section 4.2 is available.
# If you ran it in a previous cell, it should be in memory. Otherwise, you
# would need to load it from a file where you saved it.

# Let's add a check to make sure the variable exists.
try:
    _ = probe_weights
    print("Found 'probe_weights' in memory. Proceeding to build steering vector.")
except NameError:
    print("FATAL: The 'probe_weights' array was not found. Please re-run the probe analysis (Section 4.2) first.")
    # You might want to exit() here in a real script.

# --- 1. Create the Weighted Steering Vector ---

# Convert the numpy weights from the probe to a PyTorch tensor on the correct device.
weights_tensor = torch.tensor(probe_weights, dtype=torch.float32, device=model.device)

# A crucial step: We only care about features that positively indicate extraversion.
# We set all negative weights (features that indicate 'neutral') to zero, so they don't
# contribute to the vector and pull it in the wrong direction.
weights_tensor[weights_tensor < 0] = 0

# Now, we create the steering vector by performing a weighted sum of the SAE's decoder vectors.
# Each feature's direction (from sae.W_dec) is scaled by its importance (the probe weight).
# This is a single, clean matrix multiplication: [d_sae] @ [d_sae, d_model] -> [d_model]
steering_vector = weights_tensor @ sae.W_dec

# --- 2. Sanity Check and Strength Calibration ---

# Check the magnitude. It should be a healthy, non-zero number.
print(f"DEBUG: Weighted steering vector magnitude (norm): {torch.norm(steering_vector)}")

# This vector will be much more potent. Start with a moderate strength and adjust.
# A value between 5 and 20 is often a good range to start testing.
STEERING_STRENGTH = 15.0 

# --- 3. Generate Steered and Vanilla Outputs (Same as before) ---

# This part of the code remains the same, but it will now use our new, powerful vector.
def steering_hook_function(module, input, output):
    # This adds the vector to the last token's activation in the sequence
    output[0][:, -1, :] += steering_vector * STEERING_STRENGTH
    return output

test_prompt = "My plan for the weekend is"
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

print(f"\n--- Generating with WEIGHTED STEERING (Strength: {STEERING_STRENGTH}) ---")
steering_hook_handle = target_layer_module.register_forward_hook(steering_hook_function)
with torch.no_grad():
    steered_output = model.generate(**inputs, max_new_tokens=60, repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id)
steering_hook_handle.remove()
# Use textwrap for cleaner printing of the output
steered_text = tokenizer.decode(steered_output[0], skip_special_tokens=True)
print("Steered Output:")
print(textwrap.fill(steered_text, width=80))


print("\n--- Generating VANILLA (for comparison) ---")
with torch.no_grad():
    vanilla_output = model.generate(**inputs, max_new_tokens=60, repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id)
# Use textwrap for cleaner printing of the output
vanilla_text = tokenizer.decode(vanilla_output[0], skip_special_tokens=True)
print("Vanilla Output:")
print(textwrap.fill(vanilla_text, width=80))

print("\n\n--- End of Experiment ---")


--- SECTION 4.4: DEMONSTRATING WEIGHTED MANUAL STEERING ---
Found 'probe_weights' in memory. Proceeding to build steering vector.
DEBUG: Weighted steering vector magnitude (norm): 1.206115484237671

--- Generating with WEIGHTED STEERING (Strength: 15.0) ---
Steered Output:
My plan for the weekend is to go out and get a new camera. I've been looking at
some of these digital cameras that are coming on line, but they all seem so
expensive!  I was thinking about getting one with 10 megapixels or more (the
higher number means better quality), an LCD screen instead

--- Generating VANILLA (for comparison) ---
Vanilla Output:
My plan for the weekend is to go out and get some fresh air. I’m not sure what
that will look like, but it sounds good!  I have a few things on my mind this
week: 1) The weather has been so nice lately – we are having our first 80 degree
day of spring


--- End of Experiment ---


In [17]:
# ==============================================================================
# SECTION 5: QUANTITATIVE VALIDATION OF STEERING WITH OPINIONQA
# ==============================================================================
print("--- Starting Steering Validation with OpinionQA ---")

from datasets import load_dataset
import re

# --- Configuration for this validation step ---
NUM_SAMPLES_TO_VALIDATE = 50 # How many OpinionQA questions to test
# Ensure this strength is the calibrated one you found from the last step
STEERING_STRENGTH = 2.5
# This should match the personality of the steering vector you have in memory
TARGET_PERSONALITY_TO_TEST = "extraversion" 

# --- Load the Personality Classifier ---
try:
    print("\n--- Loading Hugging Face personality classifier... ---")
    personality_classifier = pipeline("text-classification", model="holistic-ai/personality_classifier")
    print("Classifier loaded successfully.")
except Exception as e:
    print(f"FATAL: Could not load classifier. Error: {e}")
    exit()

# --- Load and prepare the OpinionQA Dataset ---
try:
    opinionqa_dataset = load_dataset("RiverDong/OpinionQA", split="test")
    df_opinionqa_sample = opinionqa_dataset.to_pandas().sample(NUM_SAMPLES_TO_VALIDATE, random_state=42)
    print(f"Loaded and sampled {len(df_opinionqa_sample)} questions from OpinionQA.")
except Exception as e:
    print(f"FATAL: Could not load OpinionQA dataset. Error: {e}")
    exit()

# --- Helper function for this cell ---
def extract_question_and_choices(prompt_str):
    q_match = re.search(r'<question>(.*?)</question>', prompt_str, re.DOTALL)
    c_match = re.search(r'<choices>(.*?)</choices>', prompt_str, re.DOTALL)
    return (q_match.group(1).strip() if q_match else ""), (c_match.group(1).strip() if c_match else "")


# --- Main Generation Loop ---
# This assumes 'model', 'tokenizer', 'target_layer_module', and 'steering_vector' 
# are already defined and in memory from your previous cells.

all_results = []
conditions_to_run = ["vanilla", f"steered_{TARGET_PERSONALITY_TO_TEST}"]

for condition in conditions_to_run:
    print(f"\n--- Generating responses for condition: {condition.upper()} ---")
    
    hook_handle = None
    if condition.startswith("steered"):
        # The steering_hook_function is already defined in the previous cell
        hook_handle = target_layer_module.register_forward_hook(steering_hook_function)
        print(f"Steering hook ATTACHED with strength {STEERING_STRENGTH}.")

    for _, row in df_opinionqa_sample.iterrows():
        question, choices = extract_question_and_choices(row['prompt'])
        if not question: continue
        
        # We use a neutral prompt that asks for an explanation, which is needed for the classifier
        prompt = f"Question: {question}\nChoices: {choices}\nPlease state your choice and explain your reasoning."
        
        inputs = tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True).to(model.device)
        
        with torch.no_grad():
            output_ids = model.generate(
                **inputs, 
                max_new_tokens=80, 
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id
            )
        
        llm_response = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        
        all_results.append({
            "condition": condition,
            "llm_raw_response": llm_response.strip()
        })

    if hook_handle:
        hook_handle.remove()
        print("Steering hook REMOVED.")

# --- Classification and Analysis ---
print("\n--- Classifying responses and analyzing results... ---")
df_results = pd.DataFrame(all_results)
responses_to_classify = df_results["llm_raw_response"].tolist()

if responses_to_classify:
    print(f"Classifying {len(responses_to_classify)} total responses...")
    classifier_results = personality_classifier(responses_to_classify)

    df_results["predicted_trait"] = [res['label'] for res in classifier_results]
    
    print("\n" + "="*80)
    print(" STEERING VALIDATION RESULTS")
    print("="*80)
    
    distribution = pd.crosstab(df_results['condition'], df_results['predicted_trait'], normalize='index')
    print("\nDistribution of Predicted Personalities (%):")
    print((distribution * 100).round(1))
    
    steered_df = df_results[df_results["condition"] == f"steered_{TARGET_PERSONALITY_TO_TEST}"]
    correct_predictions = steered_df[steered_df["predicted_trait"] == TARGET_PERSONALITY_TO_TEST]
    
    if not steered_df.empty:
        accuracy = len(correct_predictions) / len(steered_df)
        print(f"\nSteering Alignment Score for '{TARGET_PERSONALITY_TO_TEST.upper()}': {accuracy:.2%}")
        print("(This measures how often the classifier agreed that the steered text matched the target personality)")
else:
    print("No responses were generated to analyze.")

print("\n--- End of Validation ---")

--- Starting Steering Validation with OpinionQA ---

--- Loading Hugging Face personality classifier... ---


Device set to use cpu


Classifier loaded successfully.
Loaded and sampled 50 questions from OpinionQA.

--- Generating responses for condition: VANILLA ---

--- Generating responses for condition: STEERED_EXTRAVERSION ---
Steering hook ATTACHED with strength 2.5.
Steering hook REMOVED.

--- Classifying responses and analyzing results... ---
Classifying 100 total responses...

 STEERING VALIDATION RESULTS

Distribution of Predicted Personalities (%):
predicted_trait       agreeableness  extraversion  neuroticism  openness
condition                                                               
steered_extraversion            0.0          34.0         66.0       0.0
vanilla                        12.0          18.0         68.0       2.0

Steering Alignment Score for 'EXTRAVERSION': 34.00%
(This measures how often the classifier agreed that the steered text matched the target personality)

--- End of Validation ---
