[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nawidayima/IPHR_Direction/blob/main/notebooks/04_train_probes.ipynb)

# Train Linear Probes for Rationalization Detection

**Goal:** Find a direction in activation space that separates "rationalization" from "honest" reasoning.

**Project Plan Reference:** Phase 2, Hours 8-14

**Key methods:**
1. **Difference-in-Means (DiM):** Simple baseline - find the direction between class centroids
2. **Logistic Regression:** Learn the optimal separating direction with regularization

**Success criteria:**
- ROC-AUC > 0.7 = weak signal
- ROC-AUC > 0.8 = good signal (target)

**Setup:** Run Cell 0 once to install dependencies, then restart runtime and run from Cell 1.

In [None]:
# Cell 0: Setup - Clone repo and install dependencies
# NOTE: After running this cell, RESTART RUNTIME (Runtime > Restart runtime)
#       Then skip this cell and run from Cell 1 onwards

import os

# Clone repo (only if not already cloned)
if not os.path.exists('/content/IPHR_Direction'):
    !git clone https://github.com/nawidayima/IPHR_Direction.git
    %cd /content/IPHR_Direction
else:
    %cd /content/IPHR_Direction
    !git pull  # Get latest changes

# Install dependencies
!pip install torch numpy pandas scikit-learn matplotlib -q

# Install package in editable mode
!pip install -e . -q

print("="*60)
print("Setup complete! Restart runtime and run from Cell 1.")
print("="*60)

In [None]:
# Cell 1: Imports
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime

# Scikit-learn for ML
from sklearn.model_selection import GroupShuffleSplit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.decomposition import PCA

# Visualization
import matplotlib.pyplot as plt

print("Imports complete!")

In [None]:
# Cell 2: Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# For this notebook, we mainly use CPU since probe training is fast
# The heavy lifting (activation extraction) was done in notebook 03

## Load Activations

We load the pre-extracted activations from notebook 03. The data structure is:

```
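{
    'model_name': ..., 'layers': [...], 'd_model': ...,   # run summary fields
    'n_pairs': ..., 'n_samples': ...,                     # (printed in Cell 3)
    'activations': {layer_idx: Tensor[n_samples, d_model]},
    'labels': Tensor[n_samples],  # 1=contradiction, 0=honest
    'metadata': [{'pair_id', 'domain', 'question_type'}, ...]
}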
```

Each question pair contributes 2 samples (question A and question B), both with the same label.
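To dry-run the cells below without a GPU or the real run directory, one can build a synthetic dictionary with the same layout. This is a minimal sketch: `make_toy_activation_data` is a hypothetical helper, and the domain names, `question_type` values, dimensions, and the injected "contradiction" signal are all placeholder assumptions, not properties of the real data.

```python
import numpy as np
import torch


def make_toy_activation_data(n_pairs=100, d_model=64, layers=(24, 28, 31), seed=0):
    """Hypothetical helper: synthetic data in the same layout as notebook 03's output.

    Each pair contributes two rows (question A and B) that share one label.
    Label-1 rows are shifted along a fixed random direction to create a signal.
    """
    rng = np.random.default_rng(seed)
    pair_labels = rng.integers(0, 2, size=n_pairs)
    labels = np.repeat(pair_labels, 2)  # A and B rows share the pair's label
    n_samples = 2 * n_pairs

    signal_dir = rng.normal(size=d_model)
    signal_dir /= np.linalg.norm(signal_dir)

    activations = {}
    for layer in layers:
        acts = rng.normal(size=(n_samples, d_model))
        acts += 2.0 * labels[:, None] * signal_dir  # shift class 1 along signal_dir
        activations[layer] = torch.tensor(acts, dtype=torch.float32)

    domains = ["domain_a", "domain_b"]  # placeholder domain names
    metadata = [
        {
            "pair_id": f"pair_{i // 2}",
            "domain": domains[(i // 2) % 2],
            "question_type": "A" if i % 2 == 0 else "B",  # placeholder values
        }
        for i in range(n_samples)
    ]
    return {
        "model_name": "toy-model",
        "layers": list(layers),
        "d_model": d_model,
        "n_pairs": n_pairs,
        "n_samples": n_samples,
        "activations": activations,
        "labels": torch.tensor(labels, dtype=torch.long),
        "metadata": metadata,
    }


toy = make_toy_activation_data()
print(toy["activations"][24].shape)  # torch.Size([200, 64])
```

Assigning `data = make_toy_activation_data()` in place of the `torch.load` call in Cell 3 should let the remaining cells run end to end on synthetic inputs.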

In [None]:
# Cell 3: Load activations
# Ensure we're in the right directory
%cd /content/IPHR_Direction

RUN_DIR = Path("experiments/run_20251228_204835_expand_dataset")
ACTIVATIONS_PATH = RUN_DIR / "activations/residual_stream_activations.pt"

# map_location="cpu" loads tensors on CPU even if they were saved from a GPU
data = torch.load(ACTIVATIONS_PATH, map_location="cpu")

print("Loaded activation data:")
print(f"  Model: {data['model_name']}")
print(f"  Layers: {data['layers']}")
print(f"  d_model: {data['d_model']}")
print(f"  n_pairs: {data['n_pairs']}")
print(f"  n_samples: {data['n_samples']}")
print(f"\nActivation shapes:")
for layer, acts in data['activations'].items():
    print(f"  Layer {layer}: {acts.shape}")
print(f"\nLabels: {data['labels'].shape}")
print(f"  Contradictions (1): {data['labels'].sum().item()}")
print(f"  Honest (0): {(~data['labels'].bool()).sum().item()}")

### Methodological Choices (per project-plan.md, Hours 6-8)

**Token position:** Last token of prompt (decision point before generation)
- This is the moment the model "commits" to its reasoning strategy
- Matches Arditi et al.'s most common choice (Table 5: `i* = -1` for 10/13 models)

**Layers:** Upper third [24, 28, 31] of 32 total (relative depth `l/L` ≈ 0.75, 0.88, 0.97)
- High-level semantic features tend to emerge in later layers
- Only partly matches Arditi et al.: their directions typically sit at `l*/L ≈ 0.5-0.8`, which covers layer 24 but not layers 28 and 31

**Rationale:** Start simple. If ROC-AUC < 0.7, *then* sweep positions/layers (Hour 12-14 contingency).

In [None]:
# Cell 4: Display label distribution by domain
metadata = data['metadata']
labels = data['labels'].numpy()

# Count by domain and label
domain_counts = {}
for m, label in zip(metadata, labels):
    domain = m['domain']
    if domain not in domain_counts:
        domain_counts[domain] = {'contradiction': 0, 'honest': 0}
    if label == 1:
        domain_counts[domain]['contradiction'] += 1
    else:
        domain_counts[domain]['honest'] += 1

print("Distribution by domain:")
print("-" * 40)
for domain, counts in sorted(domain_counts.items()):
    total = counts['contradiction'] + counts['honest']
    print(f"{domain:12s}: {counts['contradiction']:3d} contradiction, {counts['honest']:3d} honest (total: {total})")

## Train/Test Split

### Why split by pair_id?

Each question pair generates 2 samples (A and B) with the **same label** and **correlated activations**. If we split randomly:
- Question A might end up in train
- Question B (same pair) might end up in test
- The probe could "cheat" by memorizing pair-specific patterns

**Solution:** Split by `pair_id` so both A and B from the same pair stay together.

```
80% of pairs → train (both A and B samples)
20% of pairs → test (both A and B samples)
```

This gives a more honest estimate of generalization.
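The pair-level split can also be stratified by label so that train and test keep the same class balance. A minimal sketch on synthetic pair IDs: it swaps in `sklearn.model_selection.train_test_split` with `stratify=` applied to the pair-level labels (Cell 5 uses `GroupShuffleSplit`, which does not stratify), and `split_by_pair` is a hypothetical helper.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_by_pair(pair_ids, labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) sample indices with no pair on both sides.

    Pairs are split first (stratified by each pair's label), then expanded
    back to the indices of their A/B samples.
    """
    pair_ids = np.asarray(pair_ids)
    labels = np.asarray(labels)
    unique_pairs, first_idx = np.unique(pair_ids, return_index=True)
    pair_labels = labels[first_idx]  # both samples of a pair share a label
    train_pairs, test_pairs = train_test_split(
        unique_pairs, test_size=test_size, stratify=pair_labels, random_state=seed
    )
    train_idx = np.flatnonzero(np.isin(pair_ids, train_pairs))
    test_idx = np.flatnonzero(np.isin(pair_ids, test_pairs))
    return train_idx, test_idx


# Toy example: 50 pairs, each contributing two samples (A and B)
pair_ids = np.repeat(np.arange(50), 2)
labels = np.repeat(np.arange(50) % 2, 2)
train_idx, test_idx = split_by_pair(pair_ids, labels)
assert not set(pair_ids[train_idx]) & set(pair_ids[test_idx])  # no pair leaks across
print(len(train_idx), len(test_idx))  # 80 20
```

With very few pairs per class, stratification can fail (each class needs at least one pair on each side), in which case an unstratified group split like Cell 5's is the fallback.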

In [None]:
# Cell 5: Train/test split by pair_id

# Extract pair_ids and domains for each sample
pair_ids = np.array([m['pair_id'] for m in metadata])
domains = np.array([m['domain'] for m in metadata])

# Get unique pairs and their properties
unique_pairs = []
pair_to_indices = {}
pair_to_label = {}
pair_to_domain = {}

for i, (pid, domain, label) in enumerate(zip(pair_ids, domains, labels)):
    if pid not in pair_to_indices:
        unique_pairs.append(pid)
        pair_to_indices[pid] = []
        pair_to_label[pid] = label
        pair_to_domain[pid] = domain
    pair_to_indices[pid].append(i)

print(f"Total unique pairs: {len(unique_pairs)}")

# Create arrays for splitting
unique_pairs = np.array(unique_pairs)
pair_labels = np.array([pair_to_label[pid] for pid in unique_pairs])
pair_domains = np.array([pair_to_domain[pid] for pid in unique_pairs])

# GroupShuffleSplit keeps each group intact. With one group per pair, this is a
# random (NOT label-stratified) split of pairs; class balance is printed below.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# Split based on pairs
train_pair_idx, test_pair_idx = next(splitter.split(unique_pairs, pair_labels, groups=unique_pairs))

train_pairs = set(unique_pairs[train_pair_idx])
test_pairs = set(unique_pairs[test_pair_idx])

# Map back to sample indices
train_indices = []
test_indices = []

for pid in unique_pairs:
    if pid in train_pairs:
        train_indices.extend(pair_to_indices[pid])
    else:
        test_indices.extend(pair_to_indices[pid])

train_indices = np.array(train_indices)
test_indices = np.array(test_indices)

print(f"\nSplit results:")
print(f"  Train pairs: {len(train_pairs)} → {len(train_indices)} samples")
print(f"  Test pairs: {len(test_pairs)} → {len(test_indices)} samples")
print(f"\nTrain label distribution:")
print(f"  Contradiction: {labels[train_indices].sum()}")
print(f"  Honest: {(1 - labels[train_indices]).sum()}")
print(f"\nTest label distribution:")
print(f"  Contradiction: {labels[test_indices].sum()}")
print(f"  Honest: {(1 - labels[test_indices]).sum()}")

## Difference-in-Means (DiM) Direction

### Mathematical Foundation

The simplest way to find a separating direction is **difference-in-means**:

$$\vec{v}_{\text{rationalization}} = \text{mean}(X_{\text{contradiction}}) - \text{mean}(X_{\text{honest}})$$

**Intuition:** This vector points from the "centroid" of honest activations toward the "centroid" of contradiction activations in 4096-dimensional space.

### How to classify with DiM

To score a new activation $\vec{x}$, we compute the **dot product** (projection):

$$\text{score} = \vec{x} \cdot \vec{v}_{\text{rationalization}}$$

**What the dot product means:**
- It measures "how much of $\vec{x}$ lies in the direction of $\vec{v}$"
- Higher score → activation is more "contradiction-like"
- Lower score → activation is more "honest-like"

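**Why this works:** If the two classes form separate clusters that differ mainly in their means, the line connecting their centers is a natural separating direction, and projections onto it will cluster by class. DiM ignores the shape of the within-class spread, though, so when that spread is elongated or correlated the centroid direction can be suboptimal; logistic regression (below) can adjust for this.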
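The centroid-difference recipe fits in a few lines of NumPy. A minimal sketch on synthetic two-class data (the dimension, shift, and sample counts are arbitrary assumptions) that mirrors Cell 6: estimate the direction on training rows only, then score held-out rows by their dot product with the unit direction.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 32
true_dir = np.zeros(d)
true_dir[0] = 1.0  # synthetic setup: classes differ only along axis 0


def sample(n, label):
    """Draw n isotropic Gaussian points; class 1 is shifted along true_dir."""
    return rng.normal(size=(n, d)) + 1.5 * label * true_dir


X_train = np.vstack([sample(100, 1), sample(100, 0)])
y_train = np.array([1] * 100 + [0] * 100)
X_test = np.vstack([sample(50, 1), sample(50, 0)])
y_test = np.array([1] * 50 + [0] * 50)

# Difference-in-means direction, estimated on training rows only
v = X_train[y_train == 1].mean(axis=0) - X_train[y_train == 0].mean(axis=0)
v_unit = v / np.linalg.norm(v)

# Score = projection onto the unit direction; higher = more "contradiction-like"
scores = X_test @ v_unit
print(f"mean score, class 1: {scores[y_test == 1].mean():.2f}")
print(f"mean score, class 0: {scores[y_test == 0].mean():.2f}")
```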

In [None]:
# Cell 6: Compute DiM direction for each layer

# Prepare train/test data
train_labels = labels[train_indices]
test_labels = labels[test_indices]

# Store results
dim_directions = {}  # layer -> direction vector
dim_scores_test = {}  # layer -> scores on test set

print("Computing Difference-in-Means directions...\n")

for layer in data['layers']:
    # Get activations for this layer
    acts = data['activations'][layer].numpy()
    
    # Split into train/test
    train_acts = acts[train_indices]
    test_acts = acts[test_indices]
    
    # Compute class means on TRAINING data only
    train_contradiction_mask = train_labels == 1
    train_honest_mask = train_labels == 0
    
    mean_contradiction = train_acts[train_contradiction_mask].mean(axis=0)
    mean_honest = train_acts[train_honest_mask].mean(axis=0)
    
    # DiM direction: points from honest toward contradiction
    dim_direction = mean_contradiction - mean_honest
    
    # Normalize to unit vector (optional but helps interpretation)
    dim_direction_norm = np.linalg.norm(dim_direction)
    dim_direction_unit = dim_direction / dim_direction_norm
    
    dim_directions[layer] = dim_direction_unit
    
    # Score test samples by projecting onto DiM direction
    # Higher score = more contradiction-like
    test_scores = test_acts @ dim_direction_unit
    dim_scores_test[layer] = test_scores
    
    print(f"Layer {layer}:")
    print(f"  DiM direction norm (before normalizing): {dim_direction_norm:.4f}")
    print(f"  Mean score (contradiction): {test_scores[test_labels == 1].mean():.4f}")
    print(f"  Mean score (honest): {test_scores[test_labels == 0].mean():.4f}")
    print()

## ROC-AUC Evaluation

### What is ROC-AUC?

**ROC** = Receiver Operating Characteristic curve  
**AUC** = Area Under the Curve

ROC-AUC measures how well our scores **rank** samples by class, regardless of threshold.

### Interpretation

| ROC-AUC | Meaning |
|---------|----------|
| 0.5 | Random guessing (no signal) |
| 0.7 | Weak signal |
| 0.8 | Good signal (our target) |
| 1.0 | Perfect separation |

### Intuitive explanation

ROC-AUC answers: **"If I pick one contradiction sample and one honest sample at random, how often does the probe correctly rank the contradiction higher?"**

- AUC = 0.5 → 50% of the time (coin flip)
- AUC = 0.8 → 80% of the time
- AUC = 1.0 → 100% of the time
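The pairwise reading of ROC-AUC can be checked directly. A minimal sketch: compare every (contradiction, honest) pair of scores, count ties as half, and compare the result with `sklearn.metrics.roc_auc_score` on random synthetic data.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def pairwise_auc(y_true, scores):
    """Fraction of (positive, negative) pairs where the positive scores higher.

    Ties count as 0.5, matching the Mann-Whitney U statistic that ROC-AUC equals.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive-negative pair
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size


rng = np.random.default_rng(1)
y = rng.integers(0, 2, size=200)
s = np.round(rng.normal(size=200) + 0.8 * y, 1)  # rounding creates some ties
print(pairwise_auc(y, s), roc_auc_score(y, s))   # the two values agree
```

The pairwise version costs O(n_pos × n_neg) memory, which is fine at this notebook's scale; `roc_auc_score` sorts instead and scales better.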

### Why ROC-AUC instead of accuracy?

Accuracy requires choosing a threshold. ROC-AUC is **threshold-free**: it evaluates the entire ranking. This is better when:
- We don't know the optimal threshold yet
- We want to compare methods fairly
- The class distribution might differ at deployment time

### Detection vs. Intervention: Why We Don't Need Arditi's Full Sweep

**Arditi et al.'s approach:** Exhaustive search over all layers × all post-instruction token positions, selecting the best direction via validation metrics (`bypass_score`, `induce_score`, `kl_score`).

**Why that level of rigor?** Their goal is *intervention*—ablating a direction must reliably disable refusal. The optimal direction matters for causal effect.

**Our goal is *detection*** (classification). A strong signal (AUC > 0.8) at *any* reasonable position is sufficient evidence for H1. We don't need the globally optimal direction, just one that separates the classes.

**Decision rule:**
- AUC ≥ 0.8 → Proceed to confounder check (Hour 10-12)
- 0.7 ≤ AUC < 0.8 → Try logistic regression (this notebook)
- AUC < 0.7 → Consider position sweep (first generated token) or layer sweep

In [None]:
# Cell 7: Compute ROC-AUC for DiM on each layer

print("ROC-AUC for Difference-in-Means Direction")
print("=" * 45)
print()

dim_aucs = {}

for layer in data['layers']:
    scores = dim_scores_test[layer]
    auc = roc_auc_score(test_labels, scores)
    dim_aucs[layer] = auc
    
    # Interpret the result
    if auc >= 0.8:
        status = "GOOD SIGNAL"
    elif auc >= 0.7:
        status = "WEAK SIGNAL"
    elif auc >= 0.5:
        status = "minimal signal"
    else:
        status = "inverted (flip direction)"
    
    print(f"Layer {layer}: ROC-AUC = {auc:.4f}  [{status}]")

print()
best_layer = max(dim_aucs, key=dim_aucs.get)
print(f"Best layer: {best_layer} (AUC = {dim_aucs[best_layer]:.4f})")

In [None]:
# Cell 8: Plot ROC curves for all layers

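n_layers = len(data['layers'])
# squeeze=False keeps `axes` 2-D, so this also works when there is only one layer
fig, axes = plt.subplots(1, n_layers, figsize=(4 * n_layers, 4), squeeze=False)

for ax, layer in zip(axes[0], data['layers']):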
    scores = dim_scores_test[layer]
    fpr, tpr, thresholds = roc_curve(test_labels, scores)
    auc = dim_aucs[layer]
    
    ax.plot(fpr, tpr, 'b-', linewidth=2, label=f'DiM (AUC={auc:.3f})')
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random (AUC=0.5)')
    ax.fill_between(fpr, tpr, alpha=0.2)
    
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'Layer {layer}')
    ax.legend(loc='lower right')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

plt.suptitle('ROC Curves: Difference-in-Means Probe', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## Logistic Regression Probe

### Why logistic regression?

Difference-in-Means finds **a** separating direction (the line between centroids), but not necessarily the **optimal** one.

Logistic regression instead fits the weight vector $\vec{w}$ and bias $b$ that maximize the likelihood of the training labels under the model:

$$P(\text{contradiction} | \vec{x}) = \sigma(\vec{w} \cdot \vec{x} + b)$$

where $\sigma(z) = \frac{1}{1 + e^{-z}}$ is the sigmoid function.
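The formula maps directly onto a fitted scikit-learn model: `coef_[0]` plays the role of $\vec{w}$ and `intercept_[0]` of $b$. A minimal sketch on synthetic data checking that applying the sigmoid by hand reproduces `predict_proba(...)[:, 1]`, the quantity Cell 9 feeds to ROC-AUC:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(3)
X = rng.normal(size=(200, 10))
y = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)  # synthetic labels

clf = LogisticRegression(C=0.1, max_iter=1000).fit(X, y)
w, b = clf.coef_[0], clf.intercept_[0]

# P(contradiction | x) = sigmoid(w . x + b)
manual = 1.0 / (1.0 + np.exp(-(X @ w + b)))
sklearn_probs = clf.predict_proba(X)[:, 1]
print(np.abs(manual - sklearn_probs).max())  # floating-point error only
```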

### The overfitting problem

Our activations are 4096-dimensional, but we only have ~150 training samples. This is a classic "p >> n" problem where overfitting is a serious risk.

### L2 Regularization

We add a penalty term that shrinks weights toward zero:

$$\text{Loss} = \text{CrossEntropy} + \lambda ||\vec{w}||^2$$

- **High $\lambda$** → Stronger regularization → Weights stay small → More like DiM direction
- **Low $\lambda$** → Weaker regularization → Weights can grow → Risk of overfitting

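In sklearn, `C` is the *inverse* regularization strength. sklearn minimizes $\frac{1}{2}||\vec{w}||^2 + C \sum_i \text{CrossEntropy}_i$, which is the loss above (with cross-entropy summed over training samples) multiplied by $C$, with $\lambda = \frac{1}{2C}$. The bias $b$ is not penalized with the `lbfgs` solver. So: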
- **Small C** (e.g., 0.01) → Strong regularization
- **Large C** (e.g., 100) → Weak regularization
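The claim that strong regularization pulls logistic regression toward the DiM direction can be checked numerically. As $C \to 0$ the optimal $\vec{w}$ becomes proportional to the log-likelihood gradient at $\vec{w} = 0$ (with the unpenalized bias fitted), and that gradient is $\frac{n_1 n_0}{n}(\bar{x}_{\text{contradiction}} - \bar{x}_{\text{honest}})$, which is exactly the DiM vector. A minimal sketch on synthetic, deliberately anisotropic data (all sizes and scales are arbitrary assumptions):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(4)
n, d = 200, 20
scales = np.linspace(0.5, 3.0, d)     # anisotropic feature scales
y = rng.integers(0, 2, size=n)
shift = rng.normal(size=d)            # synthetic class-mean offset
X = rng.normal(size=(n, d)) * scales + 0.5 * y[:, None] * shift

dim = X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0)
dim_unit = dim / np.linalg.norm(dim)


def cosine_to_dim(C):
    """Cosine similarity between L2-regularized LR weights and the DiM direction."""
    w = LogisticRegression(C=C, max_iter=5000).fit(X, y).coef_[0]
    return float(np.dot(w, dim_unit) / np.linalg.norm(w))


for C in [1e-4, 1e-2, 1.0, 100.0]:
    print(f"C={C:g}: cosine(LR, DiM) = {cosine_to_dim(C):.3f}")
```

At small `C` the cosine should sit close to 1; as `C` grows, the weights are free to account for the unequal feature scales and drift away from the centroid direction. This is one way to read the cosine similarity printed in Cell 9.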

In [None]:
# Cell 9: Train logistic regression probe for each layer

# Use strong regularization (C=0.1) to prevent overfitting on high-dimensional data
# This is especially important when n_samples << n_features
C_VALUE = 0.1  # Regularization strength (smaller = stronger regularization)

lr_probes = {}  # layer -> trained LogisticRegression model
lr_aucs = {}    # layer -> ROC-AUC on test set

print(f"Training Logistic Regression probes (C={C_VALUE})...\n")

for layer in data['layers']:
    # Get activations
    acts = data['activations'][layer].numpy()
    train_acts = acts[train_indices]
    test_acts = acts[test_indices]
    
    # Train logistic regression
    lr = LogisticRegression(
        C=C_VALUE,
        penalty='l2',
        solver='lbfgs',
        max_iter=1000,
        random_state=42,
    )
    lr.fit(train_acts, train_labels)
    lr_probes[layer] = lr
    
    # Predict probabilities on test set
    test_probs = lr.predict_proba(test_acts)[:, 1]  # P(contradiction)
    
    # Compute ROC-AUC
    auc = roc_auc_score(test_labels, test_probs)
    lr_aucs[layer] = auc
    
    # Compare DiM direction with LR weights
    dim_dir = dim_directions[layer]
    lr_dir = lr.coef_[0] / np.linalg.norm(lr.coef_[0])
    cosine_sim = np.dot(dim_dir, lr_dir)
    
    print(f"Layer {layer}:")
    print(f"  LR ROC-AUC: {auc:.4f} (DiM was {dim_aucs[layer]:.4f})")
    print(f"  Cosine similarity (DiM vs LR direction): {cosine_sim:.4f}")
    print()

print("=" * 50)
best_layer_lr = max(lr_aucs, key=lr_aucs.get)
print(f"Best layer (LR): {best_layer_lr} (AUC = {lr_aucs[best_layer_lr]:.4f})")

In [None]:
# Cell 10: Compare DiM vs Logistic Regression

print("Comparison: Difference-in-Means vs Logistic Regression")
print("=" * 55)
print(f"{'Layer':<10} {'DiM AUC':<12} {'LR AUC':<12} {'Improvement':<12}")
print("-" * 55)

for layer in data['layers']:
    dim_auc = dim_aucs[layer]
    lr_auc = lr_aucs[layer]
    improvement = lr_auc - dim_auc
    
    print(f"{layer:<10} {dim_auc:<12.4f} {lr_auc:<12.4f} {improvement:+.4f}")

print()
print("Note: If improvement is small, DiM direction is already near-optimal.")
print("Large improvement suggests the optimal direction differs from class centroids.")

## PCA Visualization

### What is PCA?

**Principal Component Analysis** projects high-dimensional data (4096D) to low dimensions (2D) while preserving as much variance as possible.

The first principal component (PC1) is the direction of maximum variance.  
The second (PC2) is perpendicular to PC1 and captures the next most variance.

### Important caveat

PCA is for **visualization only**, not analysis!

If classes overlap in 2D PCA, they may still be perfectly separable in the full 4096D space. The separating direction might be orthogonal to the high-variance directions that PCA finds.

**Rule of thumb:**
- Classes separate in PCA → Good (probe should work well)
- Classes overlap in PCA → Inconclusive (probe might still work)
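A small synthetic illustration of the caveat (the dimensions, scales, and sample sizes are arbitrary assumptions): two high-variance nuisance features dominate the top two principal components, while the class signal lives in a low-variance feature. Scores along PC1 and PC2 rank the classes near chance, yet the DiM projection separates them.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(5)
n_per_class, d = 1000, 10
y = np.repeat([1, 0], n_per_class)

# Features 0-1: high-variance nuisance (std 5); features 2-8: unit noise;
# feature 9: unit noise plus the class signal (means +1.5 vs -1.5).
X = rng.normal(size=(2 * n_per_class, d))
X[:, :2] *= 5.0
X[:, 9] += np.where(y == 1, 1.5, -1.5)

pcs = PCA(n_components=2).fit_transform(X)
auc_pc1 = roc_auc_score(y, pcs[:, 0])
auc_pc2 = roc_auc_score(y, pcs[:, 1])

v = X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0)
auc_dim = roc_auc_score(y, X @ (v / np.linalg.norm(v)))

print(f"AUC along PC1: {auc_pc1:.2f}, PC2: {auc_pc2:.2f}, DiM: {auc_dim:.2f}")
```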

In [None]:
# Cell 11: PCA visualization for best layer

# Use the layer with best DiM AUC
viz_layer = best_layer
print(f"Visualizing layer {viz_layer} (best DiM AUC = {dim_aucs[viz_layer]:.4f})\n")

# Get all activations for this layer
acts = data['activations'][viz_layer].numpy()

# Fit PCA on all data
pca = PCA(n_components=2, random_state=42)
acts_2d = pca.fit_transform(acts)

print(f"Variance explained by PC1: {pca.explained_variance_ratio_[0]*100:.1f}%")
print(f"Variance explained by PC2: {pca.explained_variance_ratio_[1]*100:.1f}%")
print(f"Total variance explained: {pca.explained_variance_ratio_.sum()*100:.1f}%")

# Plot
fig, ax = plt.subplots(figsize=(10, 8))

# Separate by label and by split (test_indices come from Cell 5)
contradiction_mask = labels == 1
honest_mask = labels == 0
test_mask = np.zeros(len(labels), dtype=bool)
test_mask[test_indices] = True
train_mask = ~test_mask

# Train samples: circles
ax.scatter(
    acts_2d[train_mask & honest_mask, 0], acts_2d[train_mask & honest_mask, 1],
    c='blue', alpha=0.6, label='Honest (train)', s=50, edgecolors='white', linewidth=0.5
)
ax.scatter(
    acts_2d[train_mask & contradiction_mask, 0], acts_2d[train_mask & contradiction_mask, 1],
    c='red', alpha=0.6, label='Contradiction (train)', s=50, edgecolors='white', linewidth=0.5
)

# Test samples: squares
ax.scatter(
    acts_2d[test_mask & honest_mask, 0], acts_2d[test_mask & honest_mask, 1],
    c='blue', marker='s', s=80, edgecolors='black', linewidth=1.5, label='Honest (test)'
)
ax.scatter(
    acts_2d[test_mask & contradiction_mask, 0], acts_2d[test_mask & contradiction_mask, 1],
    c='red', marker='s', s=80, edgecolors='black', linewidth=1.5, label='Contradiction (test)'
)

ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)')
ax.set_title(f'PCA of Layer {viz_layer} Activations\n(Circles=train, Squares=test)')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Cell 12: Visualize DiM direction in PCA space

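# Express the unit DiM direction in PC1/PC2 coordinates. A direction has no
# offset, so use the PCA loadings directly; pca.transform() would also subtract
# the data mean, which is only appropriate for points.
dim_dir = dim_directions[viz_layer]
dim_dir_2d = pca.components_ @ dim_dir
print(f"Share of the DiM direction lying in the PC1-PC2 plane: {np.sum(dim_dir_2d**2):.3f}")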

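# Class centroids over all samples (train + test), so the arrow below can differ
# slightly from the train-only DiM direction computed in Cell 6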
mean_honest = acts[honest_mask].mean(axis=0)
mean_contradiction = acts[contradiction_mask].mean(axis=0)
mean_honest_2d = pca.transform(mean_honest.reshape(1, -1))[0]
mean_contradiction_2d = pca.transform(mean_contradiction.reshape(1, -1))[0]

fig, ax = plt.subplots(figsize=(10, 8))

# Plot data points (faded)
ax.scatter(
    acts_2d[honest_mask, 0], acts_2d[honest_mask, 1],
    c='blue', alpha=0.3, s=30, label='Honest'
)
ax.scatter(
    acts_2d[contradiction_mask, 0], acts_2d[contradiction_mask, 1],
    c='red', alpha=0.3, s=30, label='Contradiction'
)

# Plot centroids
ax.scatter(
    mean_honest_2d[0], mean_honest_2d[1],
    c='blue', s=200, marker='*', edgecolors='black', linewidth=2,
    label='Honest centroid', zorder=5
)
ax.scatter(
    mean_contradiction_2d[0], mean_contradiction_2d[1],
    c='red', s=200, marker='*', edgecolors='black', linewidth=2,
    label='Contradiction centroid', zorder=5
)

# Draw arrow for DiM direction (from honest to contradiction centroid)
ax.annotate(
    '', xy=mean_contradiction_2d, xytext=mean_honest_2d,
    arrowprops=dict(arrowstyle='->', color='green', lw=3),
)
ax.text(
    (mean_honest_2d[0] + mean_contradiction_2d[0]) / 2,
    (mean_honest_2d[1] + mean_contradiction_2d[1]) / 2 + 0.5,
    'DiM direction', fontsize=12, color='green', ha='center'
)

ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)')
ax.set_title(f'Layer {viz_layer}: DiM Direction (centroid to centroid)')
ax.legend(loc='best')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Save Results

Save the trained probes and evaluation metrics for later use (steering experiments in Phase 3).
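Cell 14 below reloads the probes from a `save_path` variable and reads a fixed set of keys. A minimal sketch of a save step that would produce such a file, assuming the keys Cell 14 reads and the contents listed in the summary; `save_probes` is a hypothetical helper, not the repository's code. Values are stored as tensors and plain Python numbers so that `torch.load` also works with `weights_only=True`, the default since PyTorch 2.6.

```python
import tempfile
from pathlib import Path

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression


def save_probes(save_path, model_name, layers, dim_directions, dim_aucs,
                lr_probes, lr_aucs, n_train, n_test):
    """Hypothetical helper: bundle per-layer probe results into one .pt file.

    Arrays become float32 tensors and scores become plain Python floats/ints,
    so the file only contains types that torch.load(..., weights_only=True) accepts.
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    layers = [int(layer) for layer in layers]
    results = {
        "model_name": str(model_name),
        "layers": layers,
        "dim_directions": {layer: torch.tensor(np.asarray(dim_directions[layer]), dtype=torch.float32)
                           for layer in layers},
        "dim_aucs": {layer: float(dim_aucs[layer]) for layer in layers},
        "lr_weights": {layer: torch.tensor(lr_probes[layer].coef_[0], dtype=torch.float32)
                       for layer in layers},
        "lr_biases": {layer: float(lr_probes[layer].intercept_[0]) for layer in layers},
        "lr_aucs": {layer: float(lr_aucs[layer]) for layer in layers},
        "best_layer_dim": max(layers, key=lambda layer: dim_aucs[layer]),
        "best_layer_lr": max(layers, key=lambda layer: lr_aucs[layer]),
        "n_train": int(n_train),
        "n_test": int(n_test),
    }
    torch.save(results, save_path)
    return save_path


# Demo with made-up placeholder values (not results). In the notebook one would pass
# data['model_name'], data['layers'], dim_directions, dim_aucs, lr_probes, lr_aucs,
# len(train_indices), len(test_indices), and
# save_path = RUN_DIR / "probes" / "rationalization_probes.pt".
rng = np.random.default_rng(0)
demo_probe = LogisticRegression().fit(rng.normal(size=(40, 8)), np.tile([0, 1], 20))
unit = np.ones(8) / np.sqrt(8)
with tempfile.TemporaryDirectory() as tmp:
    path = save_probes(Path(tmp) / "probes" / "demo.pt", "toy-model", [24, 28],
                       {24: unit, 28: unit}, {24: 0.81, 28: 0.74},
                       {24: demo_probe, 28: demo_probe}, {24: 0.83, 28: 0.77},
                       n_train=32, n_test=8)
    reloaded = torch.load(path, weights_only=True)
print(reloaded["best_layer_dim"], reloaded["best_layer_lr"])  # 24 24
```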

## Summary

Probe training complete! Results saved to:
```
experiments/run_20251228_204835_expand_dataset/probes/rationalization_probes.pt
```

**Contents:**
- `dim_directions`: DiM direction vectors (unit normalized) per layer
- `dim_aucs`: ROC-AUC scores for DiM method
- `lr_weights`, `lr_biases`: Logistic regression probe weights
- `lr_aucs`: ROC-AUC scores for LR method

---

**Next steps (per project-plan.md):**

1. **Hours 10-12: Confounder Check**
   - Does the probe fire on incorrect but non-contradictory answers?
   - If yes → detecting "wrongness", not rationalization (confounded)
   - If no → direction is specific to rationalization

2. **Hours 14-16: H2 Generalization Test (ICRL)**
   - Generate ICRL trajectories (correct answer → fake negative feedback)
   - Test if the same probe detects ICRL-induced rationalization

---

**Deferred to Phase 3 (Hours 16-18): Steering Intervention**

Once detection is validated, test *causality* via directional ablation:

$$\mathbf{h}' = \mathbf{h} - (\mathbf{h} \cdot \hat{\mathbf{v}}) \hat{\mathbf{v}}$$

where $\hat{\mathbf{v}}$ is the unit rationalization direction (`dim_directions[layer]`).
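The ablation formula removes the component of $\mathbf{h}$ along $\hat{\mathbf{v}}$ and leaves everything orthogonal to it untouched. A minimal NumPy sketch; `ablate_direction` is a hypothetical helper, and a real intervention would apply the same operation to residual-stream activations during the forward pass (for example in a hook), which is not shown here.

```python
import numpy as np


def ablate_direction(h, v):
    """Directional ablation: h' = h - (h . v_hat) v_hat, applied to the last axis.

    h: array of shape (..., d_model); v: direction of shape (d_model,).
    v is normalized here so the formula's unit-vector assumption holds.
    """
    v_hat = v / np.linalg.norm(v)
    proj = h @ v_hat                    # (...,) component of h along v_hat
    return h - proj[..., None] * v_hat


rng = np.random.default_rng(6)
h = rng.normal(size=(4, 16))            # e.g. 4 token positions, d_model = 16
v = rng.normal(size=16)

h_ablated = ablate_direction(h, v)
v_hat = v / np.linalg.norm(v)
print(np.abs(h_ablated @ v_hat).max())  # ~0: nothing left along v_hat
```

Applying the ablation twice gives the same result as applying it once, so it is safe to hook it at several layers that read the same direction.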

*Expected result:* Ablating the rationalization direction reduces contradiction rate on held-out pairs.

*Reference:* Arditi et al. Eq. 4 (directional ablation), Eq. 5 (weight orthogonalization—optional for permanent edit).

In [None]:
# Cell 14: Validation - load and verify

loaded = torch.load(save_path)

print("Validation - loaded probe data:")
print(f"  Model: {loaded['model_name']}")
print(f"  Layers: {loaded['layers']}")
print(f"  Best layer (DiM): {loaded['best_layer_dim']} (AUC={loaded['dim_aucs'][loaded['best_layer_dim']]:.4f})")
print(f"  Best layer (LR): {loaded['best_layer_lr']} (AUC={loaded['lr_aucs'][loaded['best_layer_lr']]:.4f})")
print(f"  Train/test split: {loaded['n_train']}/{loaded['n_test']}")
print(f"\nDiM direction shapes:")
for layer, direction in loaded['dim_directions'].items():
    print(f"  Layer {layer}: {direction.shape}")

In [None]:
# Cell 15: (Optional) Push to GitHub
# Uncomment to save probes to repo

# !git add experiments/
# !git commit -m "Add trained probes from Phase 2"
# !git push