# Notebook 02: Mechanistic Probes

**Research Question:** Can we detect ground truth action-taking from model activations?

This notebook:
1. Extracts activations at multiple positions and layers
2. Trains reality (tool_used) and narrative (claims_action) probes
3. **Critical:** Position analysis (first_assistant vs before_tool)
4. Analyzes probe behavior on fake action cases

**Key hypothesis:** If `first_assistant` accuracy > 80%, probe is detecting action-grounding, not just tool syntax.

## Setup

In [9]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from src.utils.logging import setup_logging
from src.config import get_config
from src.data.io import load_episodes, save_activations, load_activations
from src.extraction import extract_activations_batch
from src.analysis.probes import train_and_evaluate, analyze_probe_on_category, save_probe
from src.analysis.visualization import (
    plot_confusion_matrix,
    plot_roc_curve,
    plot_position_accuracy,
    plot_layer_analysis,
)
from src.analysis.statistics import compute_roc_auc, bootstrap_metrics

setup_logging(level="INFO")
config = get_config()

print(f"Extraction config:")
print(f"  Positions: {config.extraction.positions}")
print(f"  Layers: {config.extraction.layers}")

Extraction config:
  Positions: ['first_assistant', 'mid_response', 'before_tool']
  Layers: [0, 8, 16, 24, 31]


## 1. Load Episodes

In [10]:
from src.data.episode import EpisodeCollection

# episodes_collection = load_episodes(config.data.processed_dir / "episodes.parquet")
new_episodes = load_episodes("data/processed/new_episodes.parquet")
old_episodes = load_episodes("data/processed/episodes.parquet")

# Combine all episodes
all_episodes = EpisodeCollection(
    episodes=new_episodes.episodes + old_episodes.episodes,
    description="Combined episodes"
)

episodes = all_episodes.episodes


print(f"Loaded {len(episodes)} episodes")
print(f"\nCategory breakdown:")
summary = episodes_collection.summary()
for cat, count in summary['categories'].items():
    print(f"  {cat}: {count}")

2025-12-24 06:22:21,272 - src.data.io - INFO - Loading episodes from: data/processed/new_episodes.parquet
2025-12-24 06:22:21,334 - src.data.io - INFO - Loaded 360 episodes
2025-12-24 06:22:21,335 - src.data.io - INFO - Loading episodes from: data/processed/episodes.parquet
2025-12-24 06:22:21,390 - src.data.io - INFO - Loaded 360 episodes
Loaded 720 episodes

Category breakdown:
  honest_no_action: 80
  silent_action: 241
  true_action: 32
  fake_action: 7


## 2. Extract Activations

Extract at 3 positions × 5 layers = 15 samples per episode.

**WARNING:** This takes 1-2 hours on GPU depending on model size.

In [11]:
# Extract activations
dataset = extract_activations_batch(
    episodes=episodes,
    positions=config.extraction.positions,
    layers=config.extraction.layers,
    model_id=config.model.id,
    save_path=config.data.processed_dir / "activations.parquet",
    verbose=True,
)

print(f"\nExtracted {len(dataset)} activation samples")
print(f"Activation shape: {dataset.activations.shape}")
print(f"Hidden size: {dataset.hidden_size}")

2025-12-24 06:22:23,697 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:22:23,697 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:22:23,698 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:22:23,699 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:22:23,699 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:22:24,507 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:22:37,140 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096
2025-12-24 06:22:37,141 - src.extraction.activations - INFO - Initialized ActivationExtractor:
2025-12-24 06:22:37,142 - src.extraction.activations - INFO -   Model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:22:37,143 - src.extraction.activations - INFO -   Hidden size: 4096
2025-12-24 06:22:37,145 - src.extraction.activations - INFO -   Layers: 32
2025-12-24 06:22:37,145 - src.extraction.activations - INFO - Extracting activations from 720 episodes
2025-12-24 06:22:37,146 - src.extraction.activations - INFO -   Positions: ['first_assistant', 'mid_response', 'before_tool']
2025-12-24 06:22:37,147 - src.extraction.activations - INFO -   Layers: [0, 8, 16, 24, 31]


Extracting:   0%|          | 0/720 [00:00<?, ?it/s]

2025-12-24 06:22:37,151 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:22:37,151 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:22:37,153 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:22:37,153 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:22:37,154 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:22:37,324 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:22:49,184 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096
2025-12-24 06:22:49,214 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:22:49,215 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:22:49,215 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:22:49,216 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:22:49,216 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:22:49,363 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:23:01,085 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096


Extracting:   0%|          | 1/720 [00:24<4:50:41, 24.26s/it]

2025-12-24 06:23:01,410 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:23:01,410 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:23:01,411 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:23:01,411 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:23:01,412 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:23:01,548 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:23:13,347 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096
2025-12-24 06:23:13,355 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:23:13,355 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:23:13,356 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:23:13,357 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:23:13,358 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:23:13,548 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:23:25,083 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096


Extracting:   0%|          | 2/720 [00:48<4:48:23, 24.10s/it]

2025-12-24 06:23:25,399 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:23:25,399 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:23:25,400 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:23:25,400 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:23:25,400 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:23:25,568 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:23:36,668 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096
2025-12-24 06:23:36,676 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:23:36,676 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:23:36,676 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:23:36,677 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:23:36,677 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:23:36,833 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:23:48,095 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096


Extracting:   0%|          | 3/720 [01:11<4:42:05, 23.61s/it]

2025-12-24 06:23:48,417 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:23:48,417 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:23:48,418 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:23:48,418 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:23:48,419 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:23:48,579 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:23:59,932 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096
2025-12-24 06:23:59,938 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:23:59,939 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:23:59,939 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:23:59,939 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:23:59,940 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:24:00,066 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:24:10,976 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096


Extracting:   1%|          | 4/720 [01:34<4:38:10, 23.31s/it]

2025-12-24 06:24:11,274 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:24:11,274 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:24:11,275 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:24:11,275 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:24:11,276 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:24:11,428 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:24:22,880 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096
2025-12-24 06:24:22,888 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:24:22,890 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:24:22,891 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:24:22,892 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:24:22,893 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:24:23,273 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-12-24 06:24:34,537 - src.backends.pytorch - INFO - Model loaded. Parameters: 7,241,732,096


Extracting:   1%|          | 5/720 [01:57<4:39:04, 23.42s/it]

2025-12-24 06:24:34,886 - src.backends.pytorch - INFO - Loading model: mistralai/Mistral-7B-Instruct-v0.2
2025-12-24 06:24:34,887 - src.backends.pytorch - INFO -   Quantization: 8bit
2025-12-24 06:24:34,888 - src.backends.pytorch - INFO -   Device map: auto
2025-12-24 06:24:34,889 - src.backends.pytorch - INFO -   Dtype: float16
2025-12-24 06:24:34,889 - src.backends.pytorch - INFO -   Using 8-bit quantization
2025-12-24 06:24:35,147 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting:   1%|          | 5/720 [02:09<5:08:05, 25.85s/it]


KeyboardInterrupt: 

In [None]:
# Load activations (if already extracted)
# dataset = load_activations(config.data.processed_dir / "activations.parquet")
# print(f"Loaded {len(dataset)} activation samples")

In [None]:
# Summary
summary = dataset.summary()
print("\nDataset summary:")
print(f"  Samples: {summary['n_samples']}")
print(f"  Positions: {summary['positions']}")
print(f"  Layers: {summary['layers']}")
print(f"  Categories: {summary['categories']}")

## 3. Train Probes

Train two probes:
- **Reality probe:** Predicts `tool_used` (ground truth)
- **Narrative probe:** Predicts `claims_action` (model's claim)

In [None]:
# Filter to single position and layer for initial probe
# Use mid_response at layer 16 (middle layer)
probe_dataset = dataset.filter_by_position("mid_response").filter_by_layer(16)

print(f"Probe training dataset: {len(probe_dataset)} samples")

In [None]:
# Train reality probe
reality_probe, reality_train_metrics, reality_test_metrics = train_and_evaluate(
    probe_dataset,
    label_type="reality",
    test_size=0.2,
    random_state=42,
)

print("\nReality Probe (predicts tool_used):")
print("Train Metrics:")
print(reality_train_metrics)
print("\nTest Metrics:")
print(reality_test_metrics)

In [None]:
# Train narrative probe
narrative_probe, narrative_train_metrics, narrative_test_metrics = train_and_evaluate(
    probe_dataset,
    label_type="narrative",
    test_size=0.2,
    random_state=42,
)

print("\nNarrative Probe (predicts claims_action):")
print("Train Metrics:")
print(narrative_train_metrics)
print("\nTest Metrics:")
print(narrative_test_metrics)

In [None]:
# Save probes
save_probe(reality_probe, config.data.processed_dir / "reality_probe.pkl")
save_probe(narrative_probe, config.data.processed_dir / "narrative_probe.pkl")

print("Probes saved.")

## 4. CRITICAL: Position Analysis

**Key test:** Does the probe work at `first_assistant` (before any tool tokens)?

If yes → probe detects action-grounding, not syntax  
If no → probe might just be detecting `<<CALL` tokens

In [None]:
# Train probes at each position (using layer 16)
position_accuracies = {}

for position in config.extraction.positions:
    print(f"\nTraining probe at position: {position}")
    
    pos_dataset = dataset.filter_by_position(position).filter_by_layer(16)
    
    if len(pos_dataset) == 0:
        print(f"  No samples found for {position}")
        continue
    
    probe, _, test_metrics = train_and_evaluate(
        pos_dataset,
        label_type="reality",
        random_state=42,
    )
    
    position_accuracies[position] = test_metrics.accuracy
    print(f"  Test accuracy: {test_metrics.accuracy:.1%}")

print("\nPosition accuracies:")
for pos, acc in position_accuracies.items():
    status = "✓ PASS" if acc > 0.80 else "✗ FAIL"
    print(f"  {pos}: {acc:.1%} {status}")

In [None]:
# Visualize position analysis
fig = plot_position_accuracy(
    position_accuracies,
    title="Reality Probe Accuracy by Token Position",
    save_path=config.figures_dir / "figure2_position_accuracy",
)

plt.show()

# Check critical result
first_assistant_acc = position_accuracies.get('first_assistant', 0)
if first_assistant_acc > 0.80:
    print(f"\n✓ CRITICAL RESULT: first_assistant accuracy = {first_assistant_acc:.1%} > 80%")
    print("  → Probe detects action-grounding, not just syntax!")
else:
    print(f"\n✗ WARNING: first_assistant accuracy = {first_assistant_acc:.1%} < 80%")
    print("  → May be detecting syntax, not action-grounding")

## 5. Analyze Fake Action Cases

**Critical test:** Does the probe correctly identify fake action episodes?

In [None]:
# Analyze probe on fake action cases
fake_analysis = analyze_probe_on_category(
    reality_probe,
    probe_dataset,
    category="fake_action",
    label_type="reality",
)

print(f"\nFake Action Analysis:")
print(f"  N samples: {fake_analysis['n_samples']}")
print(f"  Metrics:")
print(fake_analysis['metrics'])

# Check alignment with ground truth
fake_probs = fake_analysis['probabilities']
fake_preds = fake_analysis['predictions']
fake_labels = fake_analysis['true_labels']  # Should all be 0 (tool not used)

# How many does probe correctly identify as "tool not used"?
correct_on_fakes = np.mean(fake_preds == 0)
mean_prob_tool_used = np.mean(fake_probs)

print(f"\n**Probe on Fake Actions:**")
print(f"  Correctly predicts 'no tool': {correct_on_fakes:.1%}")
print(f"  Mean P(tool_used): {mean_prob_tool_used:.3f}")

if correct_on_fakes > 0.95:
    print("  ✓ Probe aligns with reality, not narrative!")

In [None]:
# Histogram of P(tool_used) on fake vs true cases
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Get predictions on true action cases for comparison
true_analysis = analyze_probe_on_category(
    reality_probe,
    probe_dataset,
    category="true_action",
    label_type="reality",
)

# Plot fake action probabilities
axes[0].hist(fake_probs, bins=20, alpha=0.7, color='red', edgecolor='black')
axes[0].axvline(x=0.5, color='k', linestyle='--', linewidth=1)
axes[0].set_xlabel('P(tool_used)')
axes[0].set_ylabel('Count')
axes[0].set_title(f'Fake Actions (should cluster near 0)\nMean = {mean_prob_tool_used:.3f}')
axes[0].set_xlim(0, 1)

# Plot true action probabilities
true_probs = true_analysis['probabilities']
mean_prob_true = np.mean(true_probs)
axes[1].hist(true_probs, bins=20, alpha=0.7, color='green', edgecolor='black')
axes[1].axvline(x=0.5, color='k', linestyle='--', linewidth=1)
axes[1].set_xlabel('P(tool_used)')
axes[1].set_ylabel('Count')
axes[1].set_title(f'True Actions (should cluster near 1)\nMean = {mean_prob_true:.3f}')
axes[1].set_xlim(0, 1)

plt.tight_layout()
plt.savefig(config.figures_dir / "figure3_fake_vs_true_probs.png", dpi=300, bbox_inches='tight')
plt.show()

## 6. Layer Analysis

Which layers encode action-grounding information?

In [None]:
# Train probes at each layer (using mid_response position)
layer_accuracies = {}

for layer in config.extraction.layers:
    print(f"\nTraining probe at layer: {layer}")
    
    layer_dataset = dataset.filter_by_position("mid_response").filter_by_layer(layer)
    
    if len(layer_dataset) == 0:
        print(f"  No samples found for layer {layer}")
        continue
    
    probe, _, test_metrics = train_and_evaluate(
        layer_dataset,
        label_type="reality",
        random_state=42,
    )
    
    layer_accuracies[layer] = test_metrics.accuracy
    print(f"  Test accuracy: {test_metrics.accuracy:.1%}")

In [None]:
# Visualize layer analysis
fig = plot_layer_analysis(
    layer_accuracies,
    title="Reality Probe Accuracy by Layer",
    save_path=config.figures_dir / "figure5_layer_accuracy",
)

plt.show()

# Best layer
best_layer = max(layer_accuracies.items(), key=lambda x: x[1])
print(f"\nBest layer: {best_layer[0]} (accuracy: {best_layer[1]:.1%})")

## 7. Probe Direction Analysis

Are reality and narrative probes learning the same representation?

In [None]:
# Compare probe directions
from src.analysis.probes import compare_probes

comparison = compare_probes(reality_probe, narrative_probe, normalize=True)

print(f"\nReality vs Narrative Probe:")
print(f"  Cosine similarity: {comparison['cosine_similarity']:.3f}")
print(f"  L2 distance: {comparison['l2_distance']:.3f}")

if abs(comparison['cosine_similarity']) > 0.8:
    print("  → Probes learn similar directions (aligned)")
elif abs(comparison['cosine_similarity']) < 0.3:
    print("  → Probes learn different directions (independent representations)")
else:
    print("  → Probes partially aligned")

## 8. Visualizations

In [None]:
# Get test data for reality probe
train_dataset, test_dataset = probe_dataset.train_test_split(test_size=0.2, random_state=42)
X_test, y_test = test_dataset.to_sklearn_format("reality")

# Confusion matrix
y_pred = reality_probe.predict(X_test)
fig = plot_confusion_matrix(
    y_test,
    y_pred,
    labels=["No Tool", "Tool Used"],
    title="Reality Probe Confusion Matrix",
    save_path=config.figures_dir / "reality_probe_confusion",
)
plt.show()

In [None]:
# ROC curve
y_proba = reality_probe.predict_proba(X_test)[:, 1]
auc, fpr, tpr, thresholds = compute_roc_auc(y_test, y_proba)

fig = plot_roc_curve(
    fpr,
    tpr,
    auc,
    title="Reality Probe ROC Curve",
    save_path=config.figures_dir / "reality_probe_roc",
)
plt.show()

## Summary

In [None]:
print("=" * 60)
print("PHASE 2 RESULTS: MECHANISTIC PROBES")
print("=" * 60)

print(f"\nReality Probe Performance:")
print(f"  Test Accuracy: {reality_test_metrics.accuracy:.1%}")
print(f"  ROC-AUC: {reality_test_metrics.roc_auc:.3f}")

print(f"\nPosition Analysis:")
for pos, acc in position_accuracies.items():
    print(f"  {pos}: {acc:.1%}")

print(f"\nFake Action Analysis:")
print(f"  Correct on fakes: {correct_on_fakes:.1%}")
print(f"  Mean P(tool_used) on fakes: {mean_prob_tool_used:.3f}")

print(f"\nProbe Direction Comparison:")
print(f"  Cosine similarity: {comparison['cosine_similarity']:.3f}")

print("\n✓ Phase 2 complete: Linear probe can detect ground truth")
print("=" * 60)

## Next Steps

→ **Notebook 03:** Test cross-tool generalization