[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nawidayima/IPHR_Direction/blob/main/notebooks/05_probe_geography_only.ipynb)

# Geography-Only Probe (Question B Focus)

**Goal:** Train a cleaner probe by focusing on:
1. **Geography domain only** - removes cross-domain confounds
2. **Question B only** - the "rationalization" moment, where the model should give the answer opposite to its Question A answer

**Why Question B?**
- Question A: "Is Paris south of Cairo?" → Model reasons and answers
- Question B: "Is Cairo south of Paris?" → Model should give the opposite answer to A
- In contradiction cases, the model says NO to both, "rationalizing" a wrong answer for B
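
For a reversed pair like this, a consistent answerer says YES to exactly one of the two questions (assuming the two places are not at the same latitude), so matching answers on A and B mark a contradiction. A minimal sketch of that labeling rule; `is_contradiction` is a hypothetical helper, and the repo's own labeling code may differ:

```python
def is_contradiction(answer_a: str, answer_b: str) -> bool:
    """Hypothetical rule: reversed comparative questions should get opposite answers.

    Assumes answers are normalized to "YES"/"NO" and that the two entities differ on
    the compared attribute (e.g. latitude), so exactly one question is truly YES.
    """
    a, b = answer_a.strip().upper(), answer_b.strip().upper()
    if a not in {"YES", "NO"} or b not in {"YES", "NO"}:
        raise ValueError(f"unparseable answers: {answer_a!r}, {answer_b!r}")
    # Same answer to both directions of the comparison is logically inconsistent
    return a == b


# "Is Paris south of Cairo?" -> NO and "Is Cairo south of Paris?" -> NO: contradiction
print(is_contradiction("NO", "NO"))   # True
print(is_contradiction("NO", "YES"))  # False
```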

**Evaluation:** 5-fold stratified cross-validation (with a small sample, a single train/test split would give a noisy estimate)

In [None]:
# Cell 0: Setup - Clone repo and install dependencies
# NOTE: After running this cell, RESTART RUNTIME then run from Cell 1

import os

if not os.path.exists('/content/IPHR_Direction'):
    !git clone https://github.com/nawidayima/IPHR_Direction.git
    %cd /content/IPHR_Direction
else:
    %cd /content/IPHR_Direction
    !git pull

!pip install torch numpy pandas scikit-learn matplotlib -q
!pip install -e . -q

print("Setup complete! Restart runtime and run from Cell 1.")

In [None]:
# Cell 1: Imports
import torch
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt

print("Imports complete!")

In [None]:
# Cell 2: Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Note: Probe training uses CPU (fast enough for small data)")

## Load and Filter Data

We'll load the full activation dataset, then filter to:
- Geography domain only
- Question B only (where rationalization occurs)
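
Cell 4 builds this filter with a loop. An equivalent vectorized version is sketched below, assuming (as Cell 4 does) that each metadata entry is a dict with `domain` and `question_type` keys; the toy metadata is invented for illustration. One side benefit: `np.flatnonzero` returns an integer array even when nothing matches, whereas `np.array([])` from an empty loop is float and fails as an index.

```python
import numpy as np
import pandas as pd

# Toy metadata in the shape Cell 4 expects (values are made up)
metadata = [
    {"domain": "geography", "question_type": "A"},
    {"domain": "geography", "question_type": "B"},
    {"domain": "history", "question_type": "B"},
    {"domain": "geography", "question_type": "B"},
]

meta_df = pd.DataFrame(metadata)
mask = (meta_df["domain"] == "geography") & (meta_df["question_type"] == "B")
# Integer positions of matching rows, usable to index activations and labels
geo_b_indices = np.flatnonzero(mask.to_numpy())
print(geo_b_indices)  # [1 3]
```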

In [None]:
# Cell 3: Load activations
%cd /content/IPHR_Direction

RUN_DIR = Path("experiments/run_20251228_204835_expand_dataset")
ACTIVATIONS_PATH = RUN_DIR / "activations/residual_stream_activations.pt"

data = torch.load(ACTIVATIONS_PATH, map_location="cpu")  # probes run on CPU; also handles tensors saved from GPU

print("Full dataset:")
print(f"  Total samples: {data['n_samples']}")
print(f"  Layers: {data['layers']}")
print(f"  d_model: {data['d_model']}")

In [None]:
# Cell 4: Filter to Geography, Question B only
metadata = data['metadata']
labels = data['labels'].numpy()

# Find indices for geography domain, question B only
geo_b_indices = []
for i, m in enumerate(metadata):
    if m['domain'] == 'geography' and m['question_type'] == 'B':
        geo_b_indices.append(i)

geo_b_indices = np.array(geo_b_indices)

print("Filtered to Geography, Question B:")
print(f"  Total samples: {len(geo_b_indices)}")

# Get filtered labels
filtered_labels = labels[geo_b_indices]
n_contradiction = filtered_labels.sum()
n_honest = len(filtered_labels) - n_contradiction

print(f"  Contradictions: {n_contradiction}")
print(f"  Honest: {n_honest}")
print(f"  Balance: {n_contradiction/len(filtered_labels)*100:.1f}% contradiction")

In [None]:
# Cell 5: Extract filtered activations per layer
filtered_activations = {}
for layer in data['layers']:
    acts = data['activations'][layer].float().numpy()  # cast first: NumPy has no bfloat16 dtype
    filtered_activations[layer] = acts[geo_b_indices]
    print(f"Layer {layer}: {filtered_activations[layer].shape}")

print(f"\nLabels: {filtered_labels.shape}")

## Cross-Validation Setup

With only ~50 samples, a single train/test split would leave ~10 test samples, too few for a reliable AUC estimate.

**5-fold stratified cross-validation:**
- Split the data into 5 roughly equal folds, each keeping the overall contradiction/honest ratio
- Train on 4 folds and test on the held-out fold, rotating through all 5
- Each sample is tested exactly once
- Report: mean AUC ± standard deviation across folds

This gives us a more stable estimate of probe performance.
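
The fold mechanics can be checked directly on synthetic labels (50 samples with 15 positives, a balance chosen only for illustration): every sample appears in exactly one test fold, and stratification puts the same number of positives in each.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic labels: 50 samples, 15 positives (illustrative, not the real data)
y = np.array([1] * 15 + [0] * 35)
X = np.zeros((len(y), 1))  # features don't affect how stratified folds are drawn

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
test_counts = np.zeros(len(y), dtype=int)
pos_per_fold = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    test_counts[test_idx] += 1
    pos_per_fold.append(int(y[test_idx].sum()))
    print(f"fold {fold}: {len(train_idx)} train, {len(test_idx)} test, "
          f"{pos_per_fold[-1]} positives in test")

# Every sample lands in exactly one test fold
print(np.all(test_counts == 1))  # True
```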

In [None]:
# Cell 6: Set up cross-validation
N_FOLDS = 5
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

print(f"Using {N_FOLDS}-fold stratified cross-validation")
print(f"Each fold: ~{len(filtered_labels)//N_FOLDS * (N_FOLDS-1)} train, ~{len(filtered_labels)//N_FOLDS} test")

## Difference-in-Means with Cross-Validation

For each fold:
1. Compute DiM direction on training data: `direction = mean(contradiction) - mean(honest)`, normalized to unit length
2. Project test data onto this direction
3. Compute ROC-AUC on test data

Report mean ± std across all folds.
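
A worked example of the DiM computation on four hand-made 2-D points (data invented for illustration; for brevity it scores the same points it was fit on, whereas the notebook scores only held-out folds). Because ROC-AUC depends only on how the scores rank the samples, no threshold or bias term is needed:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Label 1 = contradiction, 0 = honest
X = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 4.0], [4.0, 4.0]])
y = np.array([0, 0, 1, 1])

diff = X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0)  # [3.5, 4] - [0.5, 0] = [3, 4]
direction = diff / np.linalg.norm(diff)                  # [3, 4] / 5 = [0.6, 0.8]
scores = X @ direction                                   # [0.0, 0.6, 5.0, 5.6]
auc = roc_auc_score(y, scores)                           # 1.0: every contradiction outranks every honest sample
print(direction, scores, auc)
```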

In [None]:
# Cell 7: Difference-in-Means cross-validation

def compute_dim_direction(X, y):
    """Compute difference-in-means direction.
    
    Returns unit vector pointing from honest (y=0) toward contradiction (y=1).
    """
    mean_contradiction = X[y == 1].mean(axis=0)
    mean_honest = X[y == 0].mean(axis=0)
    direction = mean_contradiction - mean_honest
    # Normalize to unit vector
    direction = direction / np.linalg.norm(direction)
    return direction


def dim_cross_val(X, y, cv):
    """Cross-validated ROC-AUC for difference-in-means probe."""
    aucs = []
    for train_idx, test_idx in cv.split(X, y):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]
        
        # Compute direction on training data
        direction = compute_dim_direction(X_train, y_train)
        
        # Score test data
        scores = X_test @ direction
        
        # Compute AUC
        auc = roc_auc_score(y_test, scores)
        aucs.append(auc)
    
    return np.array(aucs)


# Run for each layer
print("Difference-in-Means Cross-Validation Results")
print("=" * 50)
print()

dim_results = {}
for layer in data['layers']:
    X = filtered_activations[layer]
    y = filtered_labels
    
    aucs = dim_cross_val(X, y, skf)
    dim_results[layer] = aucs
    
    mean_auc = aucs.mean()
    std_auc = aucs.std()
    
    # Interpret
    if mean_auc >= 0.8:
        status = "GOOD SIGNAL"
    elif mean_auc >= 0.7:
        status = "WEAK SIGNAL"
    elif mean_auc >= 0.55:
        status = "marginal"
    else:
        status = "no signal"
    
    print(f"Layer {layer}: AUC = {mean_auc:.3f} ± {std_auc:.3f}  [{status}]")
    print(f"           Fold AUCs: {aucs.round(3)}")
    print()

In [None]:
# Cell 8: Visualize DiM results across folds
fig, ax = plt.subplots(figsize=(10, 5))

layers = list(dim_results.keys())
x_pos = np.arange(len(layers))

means = [dim_results[l].mean() for l in layers]
stds = [dim_results[l].std() for l in layers]

bars = ax.bar(x_pos, means, yerr=stds, capsize=5, color='steelblue', alpha=0.7)
ax.axhline(y=0.5, color='red', linestyle='--', label='Random (0.5)')
ax.axhline(y=0.7, color='orange', linestyle='--', label='Weak signal (0.7)')
ax.axhline(y=0.8, color='green', linestyle='--', label='Good signal (0.8)')

ax.set_xlabel('Layer')
ax.set_ylabel('ROC-AUC')
ax.set_title('Difference-in-Means: Geography Question B Only\n(5-fold CV, error bars = std)')
ax.set_xticks(x_pos)
ax.set_xticklabels([f'Layer {l}' for l in layers])
ax.set_ylim(0, 1.1)  # headroom for the value labels drawn above the error bars
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)

# Add value labels on bars
for i, (m, s) in enumerate(zip(means, stds)):
    ax.text(i, m + s + 0.02, f'{m:.3f}', ha='center', fontsize=10)

plt.tight_layout()
plt.show()

## Logistic Regression with Cross-Validation

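Logistic regression fits a direction by minimizing log-loss, so unlike DiM it can account for correlations between features. With ~50 samples and `d_model` features it can also overfit badly, which is why the regularization strength matters.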

We try several values of `C`, the inverse regularization strength (smaller `C` = stronger L2 penalty):
- **C = 0.01**: Very strong regularization (as C → 0, the weight direction approaches the DiM direction)
- **C = 0.1**: Strong regularization
- **C = 1.0**: Moderate regularization (sklearn default)
- **C = 10.0**: Weak regularization (more flexible, risk of overfitting)
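
The claim about small `C` can be checked on synthetic data. With an unpenalized intercept, strongly regularized logistic regression keeps its weights near zero, where the log-loss gradient is proportional to the difference in class means; with weak regularization the weights move toward a covariance-adjusted direction instead. A sketch on anisotropic Gaussian classes (all data and settings here are illustrative assumptions, not the notebook's):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Two Gaussian classes with unequal feature variances (synthetic)
n_per_class = 100
scales = np.array([2.0, 1.0, 1.0, 0.5, 0.5])
delta = np.full(5, 0.5)
X0 = rng.normal(size=(n_per_class, 5)) * scales - delta / 2
X1 = rng.normal(size=(n_per_class, 5)) * scales + delta / 2
X = np.vstack([X0, X1])
y = np.array([0] * n_per_class + [1] * n_per_class)

dim = X1.mean(axis=0) - X0.mean(axis=0)  # difference-in-means direction (unnormalized)


def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


cos_by_c = {}
for C in [1e-4, 100.0]:
    lr = LogisticRegression(C=C, max_iter=1000).fit(X, y)
    cos_by_c[C] = cosine(lr.coef_.ravel(), dim)
    print(f"C={C:g}: cosine(LR weights, DiM) = {cos_by_c[C]:.3f}")
```

On real activations the gap between the two regimes depends on how anisotropic the residual stream is, so the cosine is worth printing per layer rather than assuming.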

In [None]:
# Cell 9: Logistic Regression cross-validation with multiple C values

C_VALUES = [0.01, 0.1, 1.0, 10.0]

print("Logistic Regression Cross-Validation Results")
print("=" * 60)
print()

lr_results = {}  # layer -> {C -> aucs}

for layer in data['layers']:
    X = filtered_activations[layer]
    y = filtered_labels
    
    lr_results[layer] = {}
    
    print(f"Layer {layer}:")
    for C in C_VALUES:
        lr = LogisticRegression(
            C=C,
            penalty='l2',
            solver='lbfgs',
            max_iter=1000,
            random_state=42
        )
        
        # Cross-validation
        aucs = cross_val_score(lr, X, y, cv=skf, scoring='roc_auc')
        lr_results[layer][C] = aucs
        
        print(f"  C={C:5.2f}: AUC = {aucs.mean():.3f} ± {aucs.std():.3f}")
    
    print()

In [None]:
# Cell 10: Compare DiM vs best LR for each layer

print("Comparison: Difference-in-Means vs Best Logistic Regression")
print("=" * 65)
print(f"{'Layer':<10} {'DiM AUC':<15} {'Best LR AUC':<15} {'Best C':<10}")
print("-" * 65)

best_overall = {'layer': None, 'auc': 0, 'method': None}

for layer in data['layers']:
    dim_auc = dim_results[layer].mean()
    
    # Find best C for this layer. Note: choosing C by its CV AUC and reporting that same AUC is optimistically biased
    best_c = None
    best_lr_auc = 0
    for C in C_VALUES:
        auc = lr_results[layer][C].mean()
        if auc > best_lr_auc:
            best_lr_auc = auc
            best_c = C
    
    print(f"{layer:<10} {dim_auc:<15.3f} {best_lr_auc:<15.3f} {best_c:<10.2f}")
    
    # Track best overall
    if dim_auc > best_overall['auc']:
        best_overall = {'layer': layer, 'auc': dim_auc, 'method': 'DiM'}
    if best_lr_auc > best_overall['auc']:
        best_overall = {'layer': layer, 'auc': best_lr_auc, 'method': f'LR(C={best_c})'}

print()
print(f"Best overall: Layer {best_overall['layer']}, {best_overall['method']}, AUC = {best_overall['auc']:.3f}")
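
Picking the best `C` (and the best layer and method) by cross-validated AUC and then reporting that same AUC is optimistically biased, which matters with ~50 samples. One standard remedy is nested cross-validation: an inner CV chooses `C` using only each outer training fold, and the outer CV scores the whole selection procedure. A sketch with scikit-learn's `GridSearchCV`, on synthetic data standing in for `filtered_activations[layer]` and `filtered_labels` (sizes and signal strength are assumptions):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score

rng = np.random.default_rng(0)
# Synthetic stand-in: 50 samples, 64 features, weak signal in feature 0
y = np.array([0] * 30 + [1] * 20)
X = rng.normal(size=(50, 64))
X[:, 0] += 1.0 * y

inner_cv = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Inner loop: choose C using only the outer training fold
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.01, 0.1, 1.0, 10.0]},
    scoring="roc_auc",
    cv=inner_cv,
)
# Outer loop: evaluate "fit + pick C" on held-out folds
nested_aucs = cross_val_score(search, X, y, cv=outer_cv, scoring="roc_auc")
print(f"Nested CV AUC = {nested_aucs.mean():.3f} ± {nested_aucs.std():.3f}")
```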

## PCA Visualization

Project activations to 2D to visualize class separation.

**Caveat:** Overlap in 2D doesn't mean the classes are inseparable. PCA keeps the directions of highest variance, and the class signal may lie in lower-variance directions it discards.
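
A small synthetic illustration of this caveat (the data is invented): two high-variance nuisance features dominate the principal components, while a low-variance third feature separates the classes perfectly, so the 2D PCA view shows no separation at all.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
n = 200
y = np.array([0] * (n // 2) + [1] * (n // 2))

# Features 0-1: label-independent noise with std 10; feature 2: label signal (+/-1) with std 0.1
X = np.column_stack([
    rng.normal(scale=10.0, size=n),
    rng.normal(scale=10.0, size=n),
    np.where(y == 1, 1.0, -1.0) + rng.normal(scale=0.1, size=n),
])

pca = PCA(n_components=2).fit(X)
X_2d = pca.transform(X)

print("weight of the label feature in PC1/PC2:", np.abs(pca.components_[:, 2]).round(3))
print("AUC of label feature alone:", roc_auc_score(y, X[:, 2]))
print("AUC of PC1:", round(roc_auc_score(y, X_2d[:, 0]), 3))
```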

In [None]:
# Cell 11: PCA visualization for best layer

best_layer = best_overall['layer']
X_best = filtered_activations[best_layer]

# Fit PCA
pca = PCA(n_components=2, random_state=42)
X_2d = pca.fit_transform(X_best)

print(f"PCA on Layer {best_layer}:")
print(f"  Variance explained: PC1={pca.explained_variance_ratio_[0]*100:.1f}%, PC2={pca.explained_variance_ratio_[1]*100:.1f}%")
print(f"  Total: {pca.explained_variance_ratio_.sum()*100:.1f}%")

# Plot
fig, ax = plt.subplots(figsize=(10, 8))

contradiction_mask = filtered_labels == 1
honest_mask = filtered_labels == 0

ax.scatter(
    X_2d[honest_mask, 0], X_2d[honest_mask, 1],
    c='blue', alpha=0.7, s=80, label=f'Honest (n={honest_mask.sum()})',
    edgecolors='white', linewidth=0.5
)
ax.scatter(
    X_2d[contradiction_mask, 0], X_2d[contradiction_mask, 1],
    c='red', alpha=0.7, s=80, label=f'Contradiction (n={contradiction_mask.sum()})',
    edgecolors='white', linewidth=0.5
)

# Add centroids
centroid_honest = X_2d[honest_mask].mean(axis=0)
centroid_contra = X_2d[contradiction_mask].mean(axis=0)

ax.scatter(*centroid_honest, c='blue', s=200, marker='*', edgecolors='black', linewidth=2, zorder=5)
ax.scatter(*centroid_contra, c='red', s=200, marker='*', edgecolors='black', linewidth=2, zorder=5)

# Draw arrow between centroids
ax.annotate('', xy=centroid_contra, xytext=centroid_honest,
            arrowprops=dict(arrowstyle='->', color='green', lw=2))

ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
ax.set_title(f'Geography Question B: Layer {best_layer}\n(Stars = centroids, Arrow = DiM direction projected onto PC1/PC2)')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Detailed Analysis

If the signal is still weak, let's analyze the data more carefully:
1. Do contradictions cluster by difficulty level?
2. Are there specific patterns in which pairs cause contradictions?

In [None]:
# Cell 12: Load geography CSV for detailed analysis

geo_df = pd.read_csv(RUN_DIR / "trajectories/geography.csv")

print("Geography data summary:")
print(f"  Total pairs: {len(geo_df)}")
print()
print("Contradiction rate by difficulty:")
difficulty_stats = geo_df.groupby('difficulty')['is_contradiction'].agg(['sum', 'count', 'mean'])
difficulty_stats.columns = ['n_contradiction', 'n_total', 'rate']
print(difficulty_stats.to_string())
print()

# Show which pairs caused contradictions
print("Contradiction cases:")
contradictions = geo_df[geo_df['is_contradiction'] == True][['pair_id', 'entity_x', 'entity_y', 'difficulty', 'answer_a', 'answer_b']]
print(contradictions.to_string())

In [None]:
# Cell 13: Analyze patterns in contradictions

# Expectation: most contradictions are NO/NO (the model says neither entity is south of the other)
answer_patterns = geo_df.groupby(['answer_a', 'answer_b', 'is_contradiction']).size().reset_index(name='count')
print("Answer patterns:")
print(answer_patterns.to_string())
print()

# Check if contradictions are systematically wrong
print("\nFor contradiction cases:")
contra_df = geo_df[geo_df['is_contradiction'] == True]
print(f"  Answer A distribution: {contra_df['answer_a'].value_counts().to_dict()}")
print(f"  Answer B distribution: {contra_df['answer_b'].value_counts().to_dict()}")

In [None]:
# Cell 14: Save results

results_dir = RUN_DIR / "probes"
results_dir.mkdir(exist_ok=True)

# Compile results
save_data = {
    'experiment': 'geography_question_b_only',
    'n_samples': len(filtered_labels),
    'n_contradiction': int(filtered_labels.sum()),
    'n_honest': int((1 - filtered_labels).sum()),
    'n_folds': N_FOLDS,
    'layers': data['layers'],
    
    # DiM results
    'dim_results': {layer: {
        'aucs': dim_results[layer].tolist(),
        'mean': float(dim_results[layer].mean()),
        'std': float(dim_results[layer].std())
    } for layer in data['layers']},
    
    # LR results
    'lr_results': {layer: {
        str(C): {
            'aucs': lr_results[layer][C].tolist(),
            'mean': float(lr_results[layer][C].mean()),
            'std': float(lr_results[layer][C].std())
        } for C in C_VALUES
    } for layer in data['layers']},
    
    'best_overall': best_overall,
}

save_path = results_dir / "geography_question_b_results.pt"
torch.save(save_data, save_path)
print(f"Results saved to: {save_path}")

## Summary

This notebook tested probes on a cleaner dataset:
- Geography domain only (no cross-domain confounds)
- Question B only (the "rationalization" moment)
- 5-fold cross-validation for robust estimation

**Key results:**
- Best layer and method: see the comparison printed by Cell 10
- If AUC is still ~0.5: the rationalization signal may not be linearly decodable from the residual stream at this token position
- If AUC improved: focusing on the right subset of data helped (with ~50 samples, check the per-fold spread before reading much into it)

**If the signal is still weak, possible next steps:**
1. Try different token positions (e.g., first generated token instead of last prompt token)
2. Look at attention patterns instead of the residual stream
3. Try non-linear probes
4. Increase dataset size
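
For step 3, a minimal sketch of a non-linear probe: a small scikit-learn `MLPClassifier` behind a `StandardScaler`, scored with the same stratified CV and ROC-AUC as the linear probes. The data is synthetic (an XOR-style label that no single linear direction separates), and the architecture and hyperparameters are arbitrary assumptions. With ~50 real samples an MLP can overfit easily, so compare it against the linear results with care.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Synthetic XOR-style data: label = whether features 0 and 1 share a sign
X = rng.normal(size=(200, 10))
y = (X[:, 0] * X[:, 1] > 0).astype(int)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
mlp_probe = make_pipeline(
    StandardScaler(),  # fitted on each training fold only, inside the pipeline
    MLPClassifier(hidden_layer_sizes=(32,), alpha=1e-3, max_iter=2000, random_state=0),
)
mlp_aucs = cross_val_score(mlp_probe, X, y, cv=cv, scoring="roc_auc")
print(f"MLP probe AUC = {mlp_aucs.mean():.3f} ± {mlp_aucs.std():.3f}")
```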