# Probe Architecture Comparison

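This notebook compares a default MLP baseline against 8 novel probe architectures (9 probes in total). Each probe predicts model confidence, i.e. whether the model answers a question correctly, from its hidden states.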

**Architectures tested:**
1. Default MLP (baseline)
2. AttentionProbe - Self-attention over hidden state chunks
3. ResidualProbe - Deep MLP with skip connections
4. BottleneckProbe - Compression to low-rank representation
5. MultiHeadProbe - Multiple experts with learned aggregation
6. GatedProbe - GLU-style gating
7. SparseProbe - Top-k dimension selection
8. HeteroscedasticProbe - Per-example uncertainty
9. BilinearProbe - Explicit feature interactions

**Setup:** Run on Google Colab with GPU (Runtime > Change runtime type > T4 GPU)

In [None]:
# Install dependencies (Colab)
!pip install -q transformers accelerate bitsandbytes datasets tqdm matplotlib seaborn scikit-learn

In [None]:
# Clone the repo (first run only), move into it, and make `src` importable.
# Guarded so that re-running this cell neither re-clones, nor changes into a
# nested deep-learning/deep-learning directory, nor duplicates the sys.path entry.
import os
import subprocess
import sys

REPO_URL = "https://github.com/joshcliu/deep-learning.git"
REPO_DIR = "deep-learning"

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)

if '.' not in sys.path:
    sys.path.insert(0, '.')

In [None]:
# Patch for bfloat16 compatibility (8-bit quantized models)
import numpy as np
import torch

# Monkey-patch the extractor to handle bfloat16
from src.models import extractor as extractor_module

# Keep a handle on the library's own implementation before it is replaced.
# Unused in the visible notebook; presumably kept so the patch can be undone via
# `extractor_module.HiddenStateExtractor._extract_batch = _original_extract_batch`.
# NOTE(review): if this cell is re-run after the patch has been applied, this
# captures the patched function, not the original.
_original_extract_batch = extractor_module.HiddenStateExtractor._extract_batch

def _patched_extract_batch(self, texts, layers, max_length):
    """Patched to handle bfloat16 tensors."""
    inputs = self.tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(self.model.device)

    with torch.no_grad():
        outputs = self.model(**inputs, output_hidden_states=True)

    hidden_states = outputs.hidden_states
    batch_hiddens = []

    for layer_idx in layers:
        layer_hidden = hidden_states[layer_idx + 1]
        token_hiddens = layer_hidden[:, -1, :]
        # Fix: Convert via .tolist() to handle bfloat16
        batch_hiddens.append(np.array(token_hiddens.detach().cpu().tolist(), dtype=np.float32))

    return np.stack(batch_hiddens, axis=1)

# Swap the method on the class itself; attribute lookup goes through the class,
# so every HiddenStateExtractor instance (existing or future) picks up the patch.
setattr(extractor_module.HiddenStateExtractor, "_extract_batch", _patched_extract_batch)
print("Patched HiddenStateExtractor for bfloat16 compatibility")

## 1. Load Model

In [None]:
from src.models import ModelLoader, HiddenStateExtractor

# Mistral-7B is ungated on the Hugging Face Hub, so no access approval is needed.
model_name = "mistralai/Mistral-7B-v0.1"
loader = ModelLoader(model_name)
# Load 8-bit quantized weights; device_map="auto" lets the loader place layers
# across the available devices.
model, tokenizer = loader.load(device_map="auto", quantization="8bit")

print(f"Loaded {model_name}")
print(f"Layers: {loader.config.num_layers}, Hidden dim: {loader.config.hidden_dim}")

## 2. Load Dataset

In [None]:
from src.data import MMLUDataset

dataset = MMLUDataset(split="validation")  # MMLU validation split
print(f"Loaded {len(dataset)} examples")

# Fixed-seed subsample keeps the run reproducible. Raise NUM_SAMPLES for tighter
# metric estimates, at the cost of longer answer generation.
NUM_SAMPLES = 300
examples = dataset.sample(NUM_SAMPLES, seed=42)
print(f"Using {len(examples)} examples")

## 3. Generate Answers & Check Correctness

In [None]:
from tqdm import tqdm

def generate_answer(model, tokenizer, prompt, max_new_tokens=32):
    """Greedily decode a continuation of ``prompt`` and return only the new text.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Matching tokenizer; its EOS id is used as the pad id.
        prompt: Question text to continue.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation, with special tokens removed and surrounding
        whitespace stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded.input_ids.shape[1]

    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding -> deterministic answers
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens; keep only what follows
    new_tokens = sequences[0][prompt_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return text.strip()

def check_correctness(generated: str, example) -> bool:
    """Decide whether a generated answer selects the example's correct choice.

    The response is graded in two stages:

    1. If it *opens* with a standalone choice letter that is valid for this
       question ("B", "B.", "(B)", "b:", "B)" ...), it is graded on that letter
       alone. A plain prefix test would wrongly read "Because ..." as choice B.
       Text that follows an explicit letter (e.g. the model continuing to list
       the other options) must not override a wrong answer.
    2. Otherwise it is correct if the correct choice's text appears
       (case-insensitively) in the first line of the response.

    Args:
        generated: Raw model continuation.
        example: Object with ``choices`` (list of option strings) and
            ``answer`` (0-based index of the correct option).

    Returns:
        True if the response picks the correct option.
    """
    import re

    text = generated.strip()
    if not text:
        return False

    correct_letter = chr(65 + example.answer)  # A, B, C, D, ...
    valid_letters = {chr(65 + i) for i in range(len(example.choices))}

    # A single letter, optionally wrapped in parentheses, followed by
    # whitespace/punctuation or the end of the text.
    letter_match = re.match(r"\(?([A-Za-z])\)?(?=[\s.:,)]|$)", text)
    if letter_match and letter_match.group(1).upper() in valid_letters:
        return letter_match.group(1).upper() == correct_letter

    correct_answer = example.choices[example.answer].strip().lower()
    first_line = text.splitlines()[0].lower()
    # An empty choice string would otherwise match every response
    return bool(correct_answer) and correct_answer in first_line

# Greedy-decode and grade an answer for every sampled question. The prompt
# strings are kept because the hidden-state extractor consumes the same texts.
print("Generating model answers...")
prompts = []
correctness = []

for example in tqdm(examples):
    prompt = example.format_prompt(style="multiple_choice")
    prompts.append(prompt)
    answer_text = generate_answer(model, tokenizer, prompt)
    correctness.append(int(check_correctness(answer_text, example)))

correctness = np.array(correctness)
print(f"\nAccuracy: {correctness.mean():.1%} ({correctness.sum()}/{len(correctness)})")

## 4. Extract Hidden States

In [None]:
# Probe a middle layer of the network (the original rationale: the uncertainty
# signal is strongest there -- not verified in this notebook). The layer is
# derived from the loaded model's layer count rather than hard-coded, so it stays
# correct if `model_name` changes; for Mistral-7B (32 layers) this is layer 16.
LAYER = loader.config.num_layers // 2

extractor = HiddenStateExtractor(model, tokenizer)
hidden_states = extractor.extract(
    texts=prompts,
    layers=[LAYER],
    batch_size=8,
    show_progress=True,
)

# Shape: (num_examples, 1, hidden_dim) -> (num_examples, hidden_dim)
X = hidden_states[:, 0, :]
y = correctness
assert X.shape[0] == y.shape[0], "hidden states and correctness labels are misaligned"

print(f"Hidden states shape: {X.shape}")
print(f"Labels shape: {y.shape}")

## 5. Train/Val/Test Split

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 60/20/20 split: hold out 40%, then halve the holdout into
# validation and test. Stratifying on the labels keeps the correct/incorrect
# ratio equal across splits; the fixed random_state makes the split repeatable.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y,
    test_size=0.4,
    stratify=y,
    random_state=42,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=0.5,
    stratify=y_temp,
    random_state=42,
)

print(f"Train: {len(X_train)} (acc: {y_train.mean():.1%})")
print(f"Val:   {len(X_val)} (acc: {y_val.mean():.1%})")
print(f"Test:  {len(X_test)} (acc: {y_test.mean():.1%})")

## 6. Define Architectures

In [None]:
from src.probes import (
    CalibratedProbe,
    build_default_network,
    build_attention_network,
    build_residual_network,
    build_bottleneck_network,
    build_multihead_network,
    build_gated_network,
    build_sparse_network,
    build_heteroscedastic_network,
    build_bilinear_network,
)

INPUT_DIM = X.shape[1]  # width of the probed hidden state (4096 for Mistral)

# Name -> zero-argument factory. Storing factories rather than built modules
# means every call to a factory yields a freshly initialised network.
ARCHITECTURES = dict([
    ("Default MLP", lambda: build_default_network(INPUT_DIM, hidden_dim=256)),
    ("Attention", lambda: build_attention_network(INPUT_DIM, num_chunks=16, num_heads=4)),
    ("Residual", lambda: build_residual_network(INPUT_DIM, hidden_dim=256, num_blocks=3)),
    ("Bottleneck", lambda: build_bottleneck_network(INPUT_DIM, bottleneck_dim=64)),
    ("MultiHead", lambda: build_multihead_network(INPUT_DIM, num_heads=4, head_dim=128)),
    ("Gated", lambda: build_gated_network(INPUT_DIM, hidden_dim=256, num_layers=2)),
    ("Sparse (k=256)", lambda: build_sparse_network(INPUT_DIM, k=256, hidden_dim=128)),
    ("Heteroscedastic", lambda: build_heteroscedastic_network(INPUT_DIM, hidden_dim=256)),
    ("Bilinear", lambda: build_bilinear_network(INPUT_DIM, num_factors=32, hidden_dim=128)),
])

print(f"Testing {len(ARCHITECTURES)} architectures")

## 7. Train All Architectures

In [None]:
from sklearn.metrics import roc_auc_score, brier_score_loss

def compute_ece(confidences, labels, num_bins=10):
    """Compute Expected Calibration Error (ECE).

    ECE = sum over bins b of (|b| / N) * |mean confidence in b - accuracy in b|,
    using ``num_bins`` equal-width bins over [0, 1]. Bins are right-closed,
    ``(lo, hi]``, except the first, which is ``[0, hi]``. Predictions of exactly
    0.0 (e.g. from a saturated sigmoid) are therefore counted rather than
    silently dropped from every bin.

    Args:
        confidences: Predicted probabilities of being correct (array-like).
        labels: Binary outcomes (1 = correct), same length as ``confidences``.
        num_bins: Number of equal-width bins.

    Returns:
        The ECE as a float; 0.0 for empty input.
    """
    confidences = np.asarray(confidences, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()
    if confidences.size == 0:
        return 0.0

    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    ece = 0.0
    for i in range(num_bins):
        lo, hi = bin_boundaries[i], bin_boundaries[i + 1]
        above_lo = confidences >= lo if i == 0 else confidences > lo
        mask = above_lo & (confidences <= hi)
        count = mask.sum()
        if count > 0:
            ece += count * abs(confidences[mask].mean() - labels[mask].mean())

    return ece / confidences.size

# Store results
results = {}

# Training settings - NO early stopping, use scheduler for better convergence
NUM_EPOCHS = 200
PATIENCE = None  # Disable early stopping - train for full epochs
# Every architecture starts from the same RNG state, so differences between
# rows reflect the architecture rather than the luck of initialisation.
SEED = 42


def _evaluate_probe(confidences, labels):
    """Return accuracy / AUROC / Brier / ECE for one probe's test-set confidences.

    ``accuracy`` thresholds the confidences at 0.5 and measures agreement with
    ``labels`` (1 = the LLM answered correctly).
    """
    predictions = (confidences > 0.5).astype(int)
    return {
        "accuracy": (predictions == labels).mean(),
        "auroc": roc_auc_score(labels, confidences),
        "brier": brier_score_loss(labels, confidences),
        "ece": compute_ece(confidences, labels),
    }


for name, build_fn in ARCHITECTURES.items():
    print(f"\n{'='*50}")
    print(f"Training: {name}")
    print('='*50)

    # Seed torch (all devices) and NumPy before building the network, so weight
    # init -- and presumably minibatch shuffling inside fit() -- is reproducible.
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # Build network and probe
    network = build_fn()
    probe = CalibratedProbe(network=network)

    # Count parameters
    num_params = sum(p.numel() for p in probe.parameters())
    print(f"Parameters: {num_params:,}")

    # Train with NO early stopping
    history = probe.fit(
        X_train, y_train,
        X_val, y_val,
        batch_size=32,
        num_epochs=NUM_EPOCHS,
        patience=PATIENCE,  # None = no early stopping
        use_scheduler=True,  # Cosine annealing LR
        verbose=True,
    )

    # Evaluate on the held-out test set
    confidences = probe.predict(X_test)
    metrics = _evaluate_probe(confidences, y_test)

    results[name] = {
        **metrics,
        "num_params": num_params,
        "best_epoch": history["best_epoch"],
        "confidences": confidences,
    }

    print(f"\nTest Results:")
    print(f"  Accuracy: {metrics['accuracy']:.3f}")
    print(f"  AUROC:    {metrics['auroc']:.3f}")
    print(f"  Brier:    {metrics['brier']:.4f}")
    print(f"  ECE:      {metrics['ece']:.4f}")

## 8. Results Comparison

In [None]:
import pandas as pd

# One row per architecture; numbers are pre-formatted strings for printing.
df = pd.DataFrame.from_dict(
    {
        arch: {
            "Accuracy": f"{metrics['accuracy']:.3f}",
            "AUROC": f"{metrics['auroc']:.3f}",
            "Brier Score": f"{metrics['brier']:.4f}",
            "ECE": f"{metrics['ece']:.4f}",
            "Parameters": f"{metrics['num_params']:,}",
        }
        for arch, metrics in results.items()
    },
    orient="index",
)

print("\n" + "=" * 70)
print("ARCHITECTURE COMPARISON")
print("=" * 70)
print(df.to_string())

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

names = list(results.keys())
colors = sns.color_palette("husl", len(names))


def _metric_barh(ax, labels, values, bar_colors, xlabel, title, fmt, text_offset):
    """Draw one horizontal bar per architecture and annotate each bar with its value.

    ``fmt`` is a format spec (e.g. ".3f") for the annotation; ``text_offset`` is
    how far right of the bar end (in data units) the annotation is placed.
    """
    bars = ax.barh(labels, values, color=bar_colors)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    for bar, val in zip(bars, values):
        ax.text(val + text_offset, bar.get_y() + bar.get_height() / 2, format(val, fmt), va='center')
    return bars


aurocs = [results[n]["auroc"] for n in names]
briers = [results[n]["brier"] for n in names]
eces = [results[n]["ece"] for n in names]
params = [results[n]["num_params"] for n in names]

# 1. AUROC comparison. The axis starts at chance level (0.5) unless some probe
#    scores below chance, in which case it is extended so that bar stays visible.
ax1 = axes[0, 0]
_metric_barh(ax1, names, aurocs, colors,
             "AUROC (higher is better)", "Discrimination: AUROC", ".3f", 0.01)
ax1.set_xlim(min([0.5] + [a - 0.05 for a in aurocs]), 1.0)

# 2. Brier Score comparison
ax2 = axes[0, 1]
_metric_barh(ax2, names, briers, colors,
             "Brier Score (lower is better)", "Calibration: Brier Score", ".4f", 0.005)

# 3. ECE comparison
ax3 = axes[1, 0]
_metric_barh(ax3, names, eces, colors,
             "ECE (lower is better)", "Calibration: Expected Calibration Error", ".4f", 0.005)

# 4. Parameters vs AUROC (efficiency)
ax4 = axes[1, 1]
for name, param_count, auroc_val, color in zip(names, params, aurocs, colors):
    ax4.scatter(param_count, auroc_val, s=150, c=[color], label=name, edgecolors='black')
ax4.set_xlabel("Number of Parameters")
ax4.set_ylabel("AUROC")
ax4.set_title("Efficiency: Parameters vs Performance")
ax4.set_xscale('log')
ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

plt.tight_layout()
plt.savefig("architecture_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: architecture_comparison.png")

## 9. Reliability Diagrams

In [None]:
def plot_reliability_diagram(confidences, labels, title, ax, num_bins=10):
    """Draw a reliability diagram (per-bin empirical accuracy vs. confidence) on ``ax``.

    Bins are equal-width over [0, 1] and right-closed, ``(lo, hi]``, except the
    first, which is ``[0, hi]``. Predictions of exactly 0.0 are therefore counted
    rather than dropped. Empty bins get NaN height, so no bar is drawn for them.
    The dashed diagonal marks perfect calibration.

    Args:
        confidences: Predicted probabilities of being correct (array-like).
        labels: Binary outcomes (1 = correct), same length as ``confidences``.
        title: Axes title.
        ax: Matplotlib Axes to draw on.
        num_bins: Number of equal-width confidence bins; bar width scales with it.
    """
    confidences = np.asarray(confidences, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()

    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2

    bin_accs = []
    for i in range(num_bins):
        lo, hi = bin_boundaries[i], bin_boundaries[i + 1]
        above_lo = confidences >= lo if i == 0 else confidences > lo
        mask = above_lo & (confidences <= hi)
        bin_accs.append(labels[mask].mean() if mask.any() else np.nan)

    # Bars fill 80% of each bin so neighbouring bars never touch or overlap
    ax.bar(bin_centers, bin_accs, width=0.8 / num_bins, alpha=0.7, label='Accuracy')
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    ax.set_xlabel('Confidence')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(loc='lower right')

# Plot a reliability diagram for every architecture, three per row. The grid is
# sized from the number of results, so none are silently dropped.
n_plots = len(results)
n_cols = 3
n_rows = max(1, -(-n_plots // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, (name, r) in zip(axes, results.items()):
    plot_reliability_diagram(
        r["confidences"], y_test,
        f"{name}\nECE={r['ece']:.4f}",
        ax
    )

# Hide grid cells left over when the count is not a multiple of n_cols
for ax in axes[n_plots:]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig("reliability_diagrams.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: reliability_diagrams.png")

## 10. Summary & Recommendations

In [None]:
# Winner per criterion, each kept as an (architecture name, metrics dict) pair.
# On ties max/min return the first extreme, i.e. the architecture trained first.
best_auroc = max(results.items(), key=lambda item: item[1]["auroc"])
best_brier = min(results.items(), key=lambda item: item[1]["brier"])
best_ece = min(results.items(), key=lambda item: item[1]["ece"])
# "Efficiency": AUROC per order of magnitude (log10) of the parameter count
best_efficiency = max(
    results.items(),
    key=lambda item: item[1]["auroc"] / np.log10(item[1]["num_params"]),
)

print("=" * 60)
print("RECOMMENDATIONS")
print("=" * 60)

print(f"\nBest Discrimination (AUROC): {best_auroc[0]}")
print(f"  AUROC = {best_auroc[1]['auroc']:.4f}")

print(f"\nBest Calibration (Brier): {best_brier[0]}")
print(f"  Brier = {best_brier[1]['brier']:.4f}")

print(f"\nBest Calibration (ECE): {best_ece[0]}")
print(f"  ECE = {best_ece[1]['ece']:.4f}")

print(f"\nMost Efficient (AUROC/log(params)): {best_efficiency[0]}")
print(f"  AUROC = {best_efficiency[1]['auroc']:.4f}, Params = {best_efficiency[1]['num_params']:,}")

print("\n" + "=" * 60)
print("NOTES")
print("=" * 60)
print("""
- Lower Brier/ECE = better calibration (confidence matches accuracy)
- Higher AUROC = better discrimination (separating correct/incorrect)
- SparseProbe is most interpretable (shows which dimensions matter)
- HeteroscedasticProbe handles varying example difficulty
- For production: prioritize calibration (Brier/ECE)
- For research: prioritize AUROC and interpretability
""")

## 11. Sparse Probe Analysis (Interpretability)

In [None]:
# If SparseProbe performed well, analyze which dimensions it selected
from src.probes.architectures import TopKSparseNetwork

# Retrain a sparse probe to access its learned importance weights. It uses the
# same k, hidden size, and fit() settings (batch size, epochs, patience,
# scheduler) as the "Sparse (k=256)" comparison run, so the analysed importances
# come from a comparably trained model. Seeded for reproducibility.
SPARSE_K = 256  # keep in sync with the "Sparse (k=256)" entry in ARCHITECTURES

torch.manual_seed(42)
np.random.seed(42)
sparse_network = build_sparse_network(INPUT_DIM, k=SPARSE_K, hidden_dim=128)
sparse_probe = CalibratedProbe(network=sparse_network)
sparse_probe.fit(
    X_train, y_train,
    X_val, y_val,
    batch_size=32,
    num_epochs=NUM_EPOCHS,
    patience=PATIENCE,
    use_scheduler=True,
    verbose=False,
)

# Per-dimension importance scores. detach() guards against scores derived from
# trainable parameters: .numpy() raises on tensors that require grad.
importance = sparse_probe.network.get_importance_scores().detach().cpu().numpy()

# Find top dimensions
top_k = min(20, importance.size)
top_indices = np.argsort(importance)[-top_k:][::-1]
top_scores = importance[top_indices]

print(f"Top {top_k} most important dimensions for confidence:")
print("="*40)
for idx, score in zip(top_indices, top_scores):
    print(f"  Dim {idx:4d}: importance = {score:.4f}")

# Plot importance distribution
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.bar(range(top_k), top_scores)
plt.xlabel("Rank")
plt.ylabel("Importance Score")
plt.title(f"Top {top_k} Dimensions")
plt.xticks(range(top_k), top_indices, rotation=45)

plt.subplot(1, 2, 2)
plt.hist(importance, bins=50, edgecolor='black')
plt.xlabel("Importance Score")
plt.ylabel("Count")
plt.title("Importance Distribution (All Dimensions)")
# Mark the importance of the SPARSE_K-th ranked dimension (top-SPARSE_K threshold)
sorted_importance = np.sort(importance)[::-1]
if len(sorted_importance) > SPARSE_K:
    threshold = sorted_importance[SPARSE_K - 1]
    plt.axvline(threshold, color='r', linestyle='--', label=f'Top {SPARSE_K} threshold ({threshold:.3f})')
    plt.legend()

plt.tight_layout()
plt.savefig("sparse_probe_importance.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: sparse_probe_importance.png")