In [3]:
import os
import sys
import json
import numpy as np
import pickle
from typing import List, Tuple, Dict, Any
from pathlib import Path
from scipy import stats
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Global plotting defaults: seaborn "whitegrid" theme and a 12x6 default figure size.
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)
# Model identifier forwarded as `model_name` to the attention helpers in the analysis loop.
MODEL_NAME = "qwen-15b"

In [4]:
# Make the thought-anchors white-box analysis package importable. The path is
# relative, so this cell assumes the notebook runs from its own directory.
thought_anchors_path = Path("..") / "thought-anchors" / "whitebox-analyses"
sys.path.insert(0, str(thought_anchors_path))

# These imports only resolve after the sys.path insertion above.
from attention_analysis.attn_funcs import (
    get_avg_attention_matrix,
    get_vertical_scores,
    get_sentence_token_boundaries,
)
from attention_analysis.receiver_head_funcs import get_3d_ar_kurtosis, get_top_k_layer_head_kurts, get_all_heads_vert_scores
from pytorch_models.model_config import model2layers_heads

# NOTE(review): this path is defined but never added to sys.path in this cell,
# so the `utils` import below is resolved from whatever is already on sys.path
# (presumably the whitebox-analyses directory or the working directory) — confirm
# that the intended `utils` module is the one being picked up.
thought_branches_path = Path("..") / "thought-branches" / "blackmail"

from utils import split_solution_into_chunks

print("Imports successful!")

Imports successful!


## Load Data

In [5]:
# Load the pre-filtered trace file. The JSON is expected to hold three top-level
# keys: 'blackmail_traces', 'acceptance_traces' and 'metadata'.
data_path = "../data/filtered_traces_50.json"
with open(data_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# One collection of traces per condition; each trace is later indexed with ["trace"].
blackmail_traces = data['blackmail_traces']
acceptance_traces = data['acceptance_traces']
metadata = data['metadata']

# Quick sanity check of the condition sizes and the recorded data source.
print(f"Blackmail traces: {len(blackmail_traces)}")
print(f"Acceptance traces: {len(acceptance_traces)}")
print(f"Source: {metadata['source']}")

Blackmail traces: 50
Acceptance traces: 50
Source: uzaymacar/blackmail-rollouts (base_solutions.json only)


## Helper Functions

In [6]:
def extract_prompt_and_scratchpad(text):
    """Split a raw trace into the text before the reasoning block and the block itself.

    The first ``</SCRATCHPAD_REASONING>`` closing tag is located; if the text has
    none, the first ``</think>`` closing tag is used instead. The reasoning block
    starts at the last matching opening tag that occurs before that closing tag.

    Args:
        text: Full trace text.

    Returns:
        ``(prompt, content)`` where ``prompt`` is everything before the opening
        tag and ``content`` is the text between the tags, both stripped of
        surrounding whitespace. ``(None, None)`` when no closing tag is found or
        when no matching opening tag precedes it.
    """
    opener, closer = "<SCRATCHPAD_REASONING>", "</SCRATCHPAD_REASONING>"
    close_at = text.find(closer)

    if close_at < 0:
        # Fall back to <think> tags only when the scratchpad closer is absent.
        opener, closer = "<think>", "</think>"
        close_at = text.find(closer)
        if close_at < 0:
            return None, None  # No reasoning tags found

    # Search backwards from the closer so the innermost/latest opener wins.
    open_at = text.rfind(opener, 0, close_at)
    if open_at < 0:
        return None, None

    before = text[:open_at]
    inside = text[open_at + len(opener):close_at]
    return before.strip(), inside.strip()

def prepare_inputs_for_analysis(full_text):
    """Turn a raw trace into whitespace-normalised reasoning text plus its chunks.

    Args:
        full_text: Full trace text, passed to ``extract_prompt_and_scratchpad``.

    Returns:
        ``(text, chunks)`` where ``text`` is the reasoning content with every run
        of whitespace collapsed to a single space and ``chunks`` is whatever
        ``split_solution_into_chunks`` returns for that text. ``(None, None)``
        (after printing a warning) when either the prompt or the reasoning
        content is missing or empty. Note that the prompt itself is not used
        beyond this check.
    """
    prompt, scratchpad = extract_prompt_and_scratchpad(full_text)

    if prompt and scratchpad:
        # str.split() with no argument drops newlines/tabs/repeated spaces.
        collapsed = ' '.join(scratchpad.split())
        return collapsed, split_solution_into_chunks(collapsed)

    print("Warning: no prompt or no scratchpad found!")
    return None, None

## Quantitative Attention Metrics Across Traces

In [9]:
import numpy as np
from scipy import stats
from scipy.stats import linregress
import pandas as pd

def get_attention_statistics_for_head(trace_dict, layer, head, model_name):
    """Get attention statistics for a single trace and head.

    The trace text is prepared with ``prepare_inputs_for_analysis``, an averaged
    attention matrix is requested from ``get_avg_attention_matrix`` and reduced
    with ``get_vertical_scores``. All statistics are computed on the non-NaN
    scores only, and anchor indices always refer to positions in the original
    score vector.

    Args:
        trace_dict: Mapping with the raw trace text under the ``"trace"`` key.
        layer: Layer index forwarded to ``get_avg_attention_matrix``.
        head: Head index forwarded to ``get_avg_attention_matrix``.
        model_name: Model identifier forwarded to ``get_avg_attention_matrix``.

    Returns:
        ``None`` when the trace cannot be prepared, has fewer than 5 sentences,
        or yields fewer than 3 non-NaN scores. Otherwise a dict with keys
        ``n_sentences``, ``kurtosis``, ``max_score``, ``top_3_concentration``,
        ``entropy``, ``top_anchor_indices``, ``top_anchor_relative_pos`` and
        ``mean_anchor_position``.
    """
    text, sentences = prepare_inputs_for_analysis(trace_dict["trace"])

    if text is None or sentences is None or len(sentences) < 5:
        return None

    # Get attention matrix
    avg_mat = get_avg_attention_matrix(
        text,
        model_name=model_name,
        layer=layer,
        head=head,
        cache_dir="attention_cache",
        sentences=sentences
    )

    # Compute vertical scores
    # assumes a 1-D result with one score per sentence — TODO confirm
    vert_scores = np.asarray(
        get_vertical_scores(
            avg_mat,
            proximity_ignore=2,
            control_depth=False,
            score_type="mean"
        ),
        dtype=float,
    )

    # Keep the original positions of the non-NaN scores so that ranking can be
    # done on valid values only and still be reported in original indices.
    valid_idx = np.flatnonzero(~np.isnan(vert_scores))
    if len(valid_idx) < 3:
        return None
    valid_scores = vert_scores[valid_idx]

    # Rank valid scores from highest to lowest. Sorting the raw vector would put
    # NaNs last, i.e. report NaN positions as the "top" anchors.
    descending = np.argsort(valid_scores)[::-1]
    top_k = min(5, len(valid_scores))
    top_indices = valid_idx[descending[:top_k]]

    # Position analysis: where are the anchors located?
    relative_positions = [idx / len(sentences) for idx in top_indices]

    total = np.sum(valid_scores)
    top_3_sum = np.sum(np.sort(valid_scores)[-3:])

    stats_dict = {
        'n_sentences': len(sentences),
        'kurtosis': stats.kurtosis(valid_scores, fisher=True),
        'max_score': np.max(valid_scores),
        # Share of the total score mass held by the three largest valid scores.
        'top_3_concentration': top_3_sum / total if total != 0 else np.nan,
        # NOTE(review): scores are not renormalised here, so this is a Shannon
        # entropy only if they already sum to 1 — confirm.
        'entropy': -np.nansum(valid_scores * np.log(valid_scores + 1e-10)),
        'top_anchor_indices': top_indices,
        'top_anchor_relative_pos': relative_positions,
        'mean_anchor_position': np.mean(relative_positions)
    }

    return stats_dict


def aggregate_attention_statistics(traces, layer, head, model_name, condition_name):
    """Collect per-trace statistics for one head and summarise them.

    Each trace is passed to ``get_attention_statistics_for_head``; traces for
    which that returns ``None`` are dropped.

    Args:
        traces: Sized collection of trace dicts.
        layer: Layer index of the head being analysed.
        head: Head index within the layer.
        model_name: Model identifier forwarded to the per-trace function.
        condition_name: Label used only in the progress messages.

    Returns:
        ``None`` when no trace produced statistics. Otherwise a dict holding the
        number of usable traces, mean (and for kurtosis / anchor position also
        the ``np.std`` population standard deviation) of the per-trace metrics,
        and the raw per-trace dicts under ``'all_stats'``.
    """
    print(f"\nProcessing {condition_name} traces for Layer {layer}, Head {head}...")

    candidates = (
        get_attention_statistics_for_head(trace, layer, head, model_name)
        for trace in traces
    )
    per_trace = [entry for entry in candidates if entry is not None]

    print(f"  Processed {len(per_trace)}/{len(traces)} valid traces")

    if not per_trace:
        return None

    def column(key):
        # Values of one metric across all usable traces, in trace order.
        return [entry[key] for entry in per_trace]

    return {
        'n_traces': len(per_trace),
        'kurtosis_mean': np.mean(column('kurtosis')),
        'kurtosis_std': np.std(column('kurtosis')),
        'max_score_mean': np.mean(column('max_score')),
        'top_3_concentration_mean': np.mean(column('top_3_concentration')),
        'entropy_mean': np.mean(column('entropy')),
        'mean_anchor_position_mean': np.mean(column('mean_anchor_position')),
        'mean_anchor_position_std': np.std(column('mean_anchor_position')),
        'all_stats': per_trace
    }


def _pearsonr_or_nan(x, y):
    """Return Pearson (r, p), or (nan, nan) when the correlation is undefined."""
    if len(x) < 2 or np.max(x) == np.min(x) or np.max(y) == np.min(y):
        return np.nan, np.nan
    r, p = stats.pearsonr(x, y)
    return r, p


def length_controlled_comparison(blackmail_agg, acceptance_agg, metric_name):
    """
    Compare a metric between conditions, raw and after regressing out trace length.

    Approach:
    1. Fit ONE linear regression of the metric on sentence count over the pooled
       traces of both conditions.
    2. Subtract that fit from every trace's value to get length-adjusted residuals.
    3. Compare the residuals between conditions with an independent two-sample
       t-test, and report the raw (unadjusted) t-test alongside it.

    NOTE(review): residualising on a pooled fit and then t-testing can understate
    a condition effect when trace length itself differs between conditions; a
    model with both length and condition as predictors is the usual alternative.

    Args:
        blackmail_agg: Aggregate dict with an ``'all_stats'`` list, or ``None``.
        acceptance_agg: Aggregate dict with an ``'all_stats'`` list, or ``None``.
        metric_name: Key of the per-trace metric to compare.

    Returns:
        ``None`` when either aggregate is ``None`` or has no per-trace entries.
        Otherwise a dict with raw and adjusted means, raw and controlled t-test
        results, the pooled regression slope and its p-value, the per-condition
        Pearson correlation of the metric with length, and mean lengths.
        When every trace has the same length the slope is reported as ``0.0``
        with a NaN p-value; undefined correlations are reported as NaN.
    """
    if blackmail_agg is None or acceptance_agg is None:
        return None

    # Extract data
    bm_stats = blackmail_agg['all_stats']
    acc_stats = acceptance_agg['all_stats']
    if len(bm_stats) == 0 or len(acc_stats) == 0:
        return None

    bm_lengths = np.array([s['n_sentences'] for s in bm_stats])
    acc_lengths = np.array([s['n_sentences'] for s in acc_stats])

    bm_values = np.array([s[metric_name] for s in bm_stats])
    acc_values = np.array([s[metric_name] for s in acc_stats])

    # Pool both conditions for the single length regression
    all_lengths = np.concatenate([bm_lengths, acc_lengths])
    all_values = np.concatenate([bm_values, acc_values])

    if np.max(all_lengths) == np.min(all_lengths):
        # Length is constant, so it cannot explain any variance: the "fit" is
        # just the pooled mean and linregress would be unable to run.
        slope, intercept, p_value = 0.0, float(np.mean(all_values)), np.nan
    else:
        slope, intercept, _, p_value, _ = linregress(all_lengths, all_values)

    # Compute residuals (length-adjusted values)
    bm_residuals = bm_values - (slope * bm_lengths + intercept)
    acc_residuals = acc_values - (slope * acc_lengths + intercept)

    # Compare residuals
    t_stat_controlled, p_val_controlled = stats.ttest_ind(bm_residuals, acc_residuals)

    # Also do raw comparison for reference
    t_stat_raw, p_val_raw = stats.ttest_ind(bm_values, acc_values)

    # Correlation between length and metric, separately per condition
    bm_corr, bm_corr_p = _pearsonr_or_nan(bm_lengths, bm_values)
    acc_corr, acc_corr_p = _pearsonr_or_nan(acc_lengths, acc_values)

    results = {
        'metric': metric_name,
        'bm_mean_raw': np.mean(bm_values),
        'acc_mean_raw': np.mean(acc_values),
        'bm_mean_adjusted': np.mean(bm_residuals),
        'acc_mean_adjusted': np.mean(acc_residuals),
        't_stat_raw': t_stat_raw,
        'p_val_raw': p_val_raw,
        't_stat_controlled': t_stat_controlled,
        'p_val_controlled': p_val_controlled,
        'length_effect_slope': slope,
        'length_effect_p': p_value,
        'bm_length_corr': bm_corr,
        'bm_length_corr_p': bm_corr_p,
        'acc_length_corr': acc_corr,
        'acc_length_corr_p': acc_corr_p,
        'bm_mean_length': np.mean(bm_lengths),
        'acc_mean_length': np.mean(acc_lengths)
    }

    return results


def print_length_controlled_analysis(blackmail_agg, acceptance_agg):
    """Print the length distribution and a raw vs. length-adjusted report per metric.

    First reports mean ± std sentence counts for both conditions and a t-test on
    the lengths. Then, for each of kurtosis, max_score, top_3_concentration and
    mean_anchor_position, prints the values returned by
    ``length_controlled_comparison``: correlation with length per condition, the
    raw comparison, the length-adjusted comparison, and a one-line
    interpretation. Results with p < 0.05 are flagged.

    Args:
        blackmail_agg: Aggregate dict with an ``'all_stats'`` list.
        acceptance_agg: Aggregate dict with an ``'all_stats'`` list.
    """
    alpha = 0.05
    heavy_rule = '=' * 70
    light_rule = '-' * 70

    print(f"\n{heavy_rule}")
    print("LENGTH-CONTROLLED ANALYSIS")
    print(heavy_rule)

    # Sentence counts per condition
    lengths_bm = [entry['n_sentences'] for entry in blackmail_agg['all_stats']]
    lengths_acc = [entry['n_sentences'] for entry in acceptance_agg['all_stats']]

    print("\nTrace Length Distribution:")
    print(f"  Blackmail: {np.mean(lengths_bm):.1f} ± {np.std(lengths_bm):.1f} sentences")
    print(f"  Acceptance: {np.mean(lengths_acc):.1f} ± {np.std(lengths_acc):.1f} sentences")

    length_t, length_p = stats.ttest_ind(lengths_bm, lengths_acc)
    print(f"  Length difference: t = {length_t:.3f}, p = {length_p:.4f}")
    if length_p < alpha:
        print("  ✓ SIGNIFICANT length difference")

    correlation_rows = (
        ("Blackmail", 'bm_length_corr', 'bm_length_corr_p'),
        ("Acceptance", 'acc_length_corr', 'acc_length_corr_p'),
    )

    for metric in ('kurtosis', 'max_score', 'top_3_concentration', 'mean_anchor_position'):
        print(f"\n{light_rule}")
        print(f"Metric: {metric.upper()}")
        print(light_rule)

        res = length_controlled_comparison(blackmail_agg, acceptance_agg, metric)
        if res is None:
            print("  ⚠️  Insufficient data")
            continue

        # How strongly the metric tracks trace length within each condition
        print("\nCorrelation with trace length:")
        for label, r_key, p_key in correlation_rows:
            print(f"  {label}: r = {res[r_key]:.3f}, p = {res[p_key]:.4f}")
            if res[p_key] < alpha:
                print("    ✓ SIGNIFICANT correlation")

        # Unadjusted comparison
        print("\nRAW comparison (not controlling for length):")
        print(f"  Blackmail: {res['bm_mean_raw']:.4f}")
        print(f"  Acceptance: {res['acc_mean_raw']:.4f}")
        p_raw = res['p_val_raw']
        print(f"  t = {res['t_stat_raw']:.3f}, p = {p_raw:.4f}")
        if p_raw < alpha:
            print("  ✓ SIGNIFICANT")

        # Comparison of length-adjusted residuals
        print("\nLENGTH-CONTROLLED comparison:")
        print(f"  Blackmail (adjusted): {res['bm_mean_adjusted']:.4f}")
        print(f"  Acceptance (adjusted): {res['acc_mean_adjusted']:.4f}")
        p_ctl = res['p_val_controlled']
        print(f"  t = {res['t_stat_controlled']:.3f}, p = {p_ctl:.4f}")
        if p_ctl < alpha:
            print("  ✓ SIGNIFICANT after controlling for length")

        # Interpretation. Explicit >= comparisons are kept so that NaN p-values
        # fall through every branch and print nothing.
        if p_raw < alpha and p_ctl < alpha:
            print("\n  → Effect remains significant after controlling for length")
        elif p_raw < alpha and p_ctl >= alpha:
            print("\n  → Effect disappears after controlling for length (likely confounded)")
        elif p_raw >= alpha and p_ctl < alpha:
            print("\n  → Effect emerges after controlling for length (suppressed by length)")



In [10]:
# Run the length-controlled comparison for a fixed list of (layer, head) pairs.
# NOTE(review): presumably these heads were selected in an earlier analysis that
# is not part of this cell — record where the list comes from.
shared_heads = [(6, 5), (5, 8), (14, 7), (6, 8), (22, 0), (5, 9)]

print(f"\n{'='*80}")
print("AGGREGATE ATTENTION ANALYSIS WITH LENGTH CONTROLS")
print(f"{'='*80}")

for layer, head in shared_heads:
    print(f"\n{'='*80}")
    print(f"Layer {layer}, Head {head}")
    print(f"{'='*80}")
    
    # Per-condition aggregates; each is None when no trace yielded statistics.
    bm_agg = aggregate_attention_statistics(
        blackmail_traces, layer, head, MODEL_NAME, "Blackmail"
    )
    
    acc_agg = aggregate_attention_statistics(
        acceptance_traces, layer, head, MODEL_NAME, "Acceptance"
    )
    
    # Skip this head if either condition has no usable traces.
    if bm_agg is None or acc_agg is None:
        print("  ⚠️  Insufficient data")
        continue
    
    # Report raw and length-adjusted comparisons for this head.
    print_length_controlled_analysis(bm_agg, acc_agg)


AGGREGATE ATTENTION ANALYSIS WITH LENGTH CONTROLS

Layer 6, Head 5

Processing Blackmail traces for Layer 6, Head 5...
  Processed 49/50 valid traces

Processing Acceptance traces for Layer 6, Head 5...
  Processed 45/50 valid traces

LENGTH-CONTROLLED ANALYSIS

Trace Length Distribution:
  Blackmail: 22.5 ± 20.3 sentences
  Acceptance: 56.2 ± 49.8 sentences
  Length difference: t = -4.316, p = 0.0000
  ✓ SIGNIFICANT length difference

----------------------------------------------------------------------
Metric: KURTOSIS
----------------------------------------------------------------------

Correlation with trace length:
  Blackmail: r = 0.963, p = 0.0000
    ✓ SIGNIFICANT correlation
  Acceptance: r = 1.000, p = 0.0000
    ✓ SIGNIFICANT correlation

RAW comparison (not controlling for length):
  Blackmail: 14.5522
  Acceptance: 48.9325
  t = -4.514, p = 0.0000
  ✓ SIGNIFICANT

LENGTH-CONTROLLED comparison:
  Blackmail (adjusted): -0.6789
  Acceptance (adjusted): 0.7392
  t = -1.718