# Multi-Source Confidence Probe

This notebook tests the **MultiSourceConfidenceNetwork** which combines:
- Hidden states from **k quartile layers** (internal uncertainty)
- Output **logits** for answer choices (expressed confidence)

The hypothesis: Internal uncertainty (hidden states) may differ from expressed confidence (logits). By combining both, we can detect miscalibration.

**Comparison architectures:**
1. Linear (single layer baseline)
2. LayerEnsemble (multi-layer, no logits)
3. MultiSource (multi-layer + logits) ← NEW

**Setup:** Run on Google Colab with GPU (Runtime > Change runtime type > T4 GPU)

In [None]:
# Work from /content (this notebook is set up for Google Colab) so that the relative
# paths used by the following setup cells resolve there.
%cd /content

In [None]:
# Remove any checkout left over from a previous run so the clone below starts clean.
!rm -rf /content/deep-learning
# Relative glob: resolved against the current working directory.
!rm -rf deep-learning*

In [None]:
# Install dependencies (Colab)
!pip install -q transformers accelerate bitsandbytes datasets tqdm matplotlib seaborn scikit-learn

In [None]:
# Clone repo and setup path
!git clone https://github.com/joshcliu/deep-learning.git
%cd deep-learning

import sys
# Put the current directory (the repo root after the %cd above) first on the import
# path so that the `src` package imported by later cells resolves to this checkout.
sys.path.insert(0, '.')

In [None]:
# Patch for bfloat16 compatibility (8-bit quantized models)
import numpy as np
import torch
from src.models import extractor as extractor_module

def _patched_extract_batch(self, texts, layers, max_length, token_position):
    """Extract one hidden-state vector per text and per layer as float32.

    Replacement for ``HiddenStateExtractor._extract_batch`` that casts the
    selected activations to float32 before converting to NumPy, so models
    running in bfloat16 can be handled.

    Args:
        self: Object exposing ``tokenizer``, ``model`` and ``device``.
        texts: List of input strings (tokenized with padding and truncation).
        layers: Iterable of layer indices; index ``i`` reads
            ``outputs.hidden_states[i + 1]`` (entry 0 is skipped).
        max_length: Truncation length passed to the tokenizer.
        token_position: ``"last"`` (last attended token of each sequence),
            ``"cls"`` (position 0) or ``"mean"`` (attention-mask-weighted mean).

    Returns:
        np.ndarray of shape ``(len(texts), len(layers), hidden_dim)``, float32.

    Raises:
        ValueError: If ``token_position`` is not one of the supported values.
    """
    encodings = self.tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encodings = {k: v.to(self.device) for k, v in encodings.items()}

    with torch.no_grad():
        outputs = self.model(
            **encodings,
            output_hidden_states=True,
            return_dict=True,
        )

    hidden_states = outputs.hidden_states
    attention_mask = encodings["attention_mask"]

    last_token_idx = None
    if token_position == "last":
        # Index of the last attended token in each row. Weighting the mask by the
        # 1-based position and taking argmax works for both right- and left-padded
        # batches, whereas `mask.sum() - 1` is only correct for right padding.
        positions = torch.arange(
            1, attention_mask.size(1) + 1, device=attention_mask.device
        )
        last_token_idx = (attention_mask * positions).argmax(dim=1)

    batch_hiddens = []

    for layer_idx in layers:
        layer_hiddens = hidden_states[layer_idx + 1]
        # With a sharded model the layer output may live on a different device
        # than the inputs, so index tensors are moved to the layer's device.
        layer_device = layer_hiddens.device

        if token_position == "last":
            rows = torch.arange(layer_hiddens.size(0), device=layer_device)
            token_hiddens = layer_hiddens[rows, last_token_idx.to(layer_device)]
        elif token_position == "cls":
            # NOTE: position 0 is a padding slot for left-padded rows.
            token_hiddens = layer_hiddens[:, 0, :]
        elif token_position == "mean":
            mask = attention_mask.unsqueeze(-1).to(layer_device)
            masked_hiddens = layer_hiddens * mask
            token_hiddens = masked_hiddens.sum(dim=1) / mask.sum(dim=1)
        else:
            raise ValueError(f"Unknown token_position: {token_position}")

        # Cast to float32 before .numpy(): NumPy cannot represent bfloat16.
        token_hiddens = token_hiddens.detach().cpu().to(torch.float32).numpy()
        batch_hiddens.append(token_hiddens)

    return np.stack(batch_hiddens, axis=1)

# Swap the class-level batch extraction method for the patched function defined in this cell.
setattr(extractor_module.HiddenStateExtractor, "_extract_batch", _patched_extract_batch)
print("Patched HiddenStateExtractor for bfloat16 compatibility.")

## 1. Load Model

In [None]:
from src.models import ModelLoader, HiddenStateExtractor

model_name = "Qwen/Qwen2.5-7B"
loader = ModelLoader(model_name)
model, tokenizer = loader.load(quantization="8bit", device_map="auto")

print(f"Loaded {model_name}")

# Model dimensions, read from the loader config and reused by later cells.
num_layers = loader.config.num_layers
hidden_dim = loader.config.hidden_dim
print(f"Layers: {num_layers}, Hidden dim: {hidden_dim}")

# Define quartile layers: Q1 (early), Q2 (middle), Q3 (late) at 1/4, 2/4 and 3/4
# of the depth, followed by Q4 (the final layer).
QUARTILE_LAYERS = [(quarter * num_layers) // 4 for quarter in (1, 2, 3)] + [num_layers - 1]
print(f"Quartile layers: {QUARTILE_LAYERS}")

## 2. Load Dataset

In [None]:
from src.data import MMLUDataset

# IMPORTANT: load the 'test' split (14K examples) rather than 'validation' (1.5K).
# A count of 1531 below means the validation split was loaded by mistake.
dataset = MMLUDataset(split="test")  # <- must be 'test'
print(f"Loaded {len(dataset)} MMLU examples")
print("  (If this shows 1531, change split to 'test' instead of 'validation')")

# Subsample 5000 of the 14K examples for a robust evaluation.
NUM_SAMPLES = 5000
examples = dataset.sample(NUM_SAMPLES, seed=42)
print(f"Using {len(examples)} examples for training")

## 3. Extract Hidden States + Logits

This is the key difference: we extract BOTH hidden states from multiple layers AND the logits for each answer choice.

In [None]:
from tqdm import tqdm

def get_answer_token_ids(tokenizer, choices=('A', 'B', 'C', 'D')):
    """Map each answer-choice label to a single vocabulary token ID.

    For every label the bare string is encoded first. If that does not yield
    exactly one token, the label is re-encoded with a leading space and the
    last token of that encoding is used.

    Args:
        tokenizer: Object with ``encode(text, add_special_tokens=False)``
            returning a sequence of token IDs.
        choices: Iterable of answer labels. The default is an immutable tuple
            so the shared default can never be mutated between calls.

    Returns:
        List of token IDs, one per entry in ``choices``, in the same order.

    Raises:
        ValueError: If the space-prefixed fallback encoding for a label is empty.
    """
    token_ids = []
    for choice in choices:
        # Try different formats
        ids = tokenizer.encode(choice, add_special_tokens=False)
        if len(ids) != 1:
            # Bare label is not a single token: retry with a space prefix
            ids = tokenizer.encode(f" {choice}", add_special_tokens=False)
            if not ids:
                raise ValueError(
                    f"Tokenizer produced no tokens for answer choice {choice!r}"
                )
        # Single-token encodings have ids[-1] == ids[0]; for the fallback this
        # takes the last token of the space-prefixed encoding.
        token_ids.append(ids[-1])
    return token_ids

ANSWER_TOKEN_IDS = get_answer_token_ids(tokenizer)
print(f"Answer token IDs: {ANSWER_TOKEN_IDS}")
# Decode each ID on its own to check that it maps back to the intended answer letter.
decoded_choices = list(map(lambda token_id: tokenizer.decode([token_id]), ANSWER_TOKEN_IDS))
print(f"Decoded: {decoded_choices}")

In [None]:
def extract_hidden_states_and_logits(model, tokenizer, extractor, examples, layers, answer_token_ids, batch_size=4):
    """
    Extract hidden states from multiple layers AND logits for answer choices.

    Each batch is run through the model once. For every example, the hidden
    state at the last attended token is read from each requested layer, and
    the output logits at that same position are restricted to the answer tokens.

    Args:
        model: Callable causal LM exposing ``.device``; called with
            ``output_hidden_states=True`` and ``return_dict=True``.
        tokenizer: Callable tokenizer returning an object with ``.to(device)``
            and an ``"attention_mask"`` entry.
        extractor: Unused; kept so existing call sites keep working.
        examples: Sequence of examples providing ``format_prompt(style=...)``
            and ``.answer``.
        layers: Layer indices; index ``i`` reads ``outputs.hidden_states[i + 1]``.
        answer_token_ids: Vocabulary IDs of the answer choices, in choice order.
        batch_size: Number of prompts per forward pass.

    Returns:
        hidden_states: (num_examples, num_layers, hidden_dim) float32
        logits: (num_examples, num_choices) float32
        correctness: (num_examples,) float32 binary labels, 1.0 where the
            argmax over the answer logits equals ``ex.answer``

    Raises:
        ValueError: If ``examples`` is empty.
    """
    if len(examples) == 0:
        raise ValueError("examples must contain at least one item")

    all_hidden_states = []
    all_logits = []
    all_correctness = []

    # Process in batches
    for i in tqdm(range(0, len(examples), batch_size), desc="Extracting"):
        batch_examples = examples[i:i+batch_size]
        prompts = [ex.format_prompt(style="multiple_choice") for ex in batch_examples]

        # Tokenize
        encodings = tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(model.device)

        with torch.no_grad():
            outputs = model(
                **encodings,
                output_hidden_states=True,
                return_dict=True,
            )

        # Index of the last attended token per example. Weighting the mask by the
        # 1-based position and taking argmax is correct for both right- and
        # left-padded batches; `mask.sum() - 1` is only correct for right padding.
        attention_mask = encodings["attention_mask"]
        positions = torch.arange(
            1, attention_mask.size(1) + 1, device=attention_mask.device
        )
        last_token_idx = (attention_mask * positions).argmax(dim=1)
        num_rows = attention_mask.size(0)

        # Extract hidden states from each layer
        batch_hiddens = []
        for layer_idx in layers:
            layer_hiddens = outputs.hidden_states[layer_idx + 1]  # +1 for embeddings
            # A sharded model may place this layer on another device than the
            # inputs, so build the index tensors on the layer's own device.
            layer_device = layer_hiddens.device
            token_hiddens = layer_hiddens[
                torch.arange(num_rows, device=layer_device),
                last_token_idx.to(layer_device),
            ]
            # float32 cast: NumPy cannot represent bfloat16 activations.
            batch_hiddens.append(token_hiddens.cpu().to(torch.float32).numpy())

        # Stack: (batch, num_layers, hidden_dim)
        all_hidden_states.append(np.stack(batch_hiddens, axis=1))

        # Extract logits for answer choices at last token position
        logits_device = outputs.logits.device
        last_logits = outputs.logits[
            torch.arange(num_rows, device=logits_device),
            last_token_idx.to(logits_device),
        ]  # (batch, vocab_size)

        # Get logits for answer token IDs only
        answer_logits = last_logits[:, answer_token_ids]  # (batch, num_choices)
        all_logits.append(answer_logits.cpu().to(torch.float32).numpy())

        # Determine correctness based on which answer has highest logit.
        # NOTE(review): this compares an index into answer_token_ids with ex.answer,
        # so ex.answer is presumably an integer choice index - confirm against the dataset class.
        predicted_answers = answer_logits.argmax(dim=-1).cpu().numpy()
        correct_answers = np.array([ex.answer for ex in batch_examples])
        batch_correctness = (predicted_answers == correct_answers).astype(np.float32)
        all_correctness.append(batch_correctness)

    # Concatenate all batches
    hidden_states = np.concatenate(all_hidden_states, axis=0)
    logits = np.concatenate(all_logits, axis=0)
    correctness = np.concatenate(all_correctness, axis=0)

    return hidden_states, logits, correctness

# Run the extraction over every sampled example.
extractor = HiddenStateExtractor(model, tokenizer)
hidden_states, logits, correctness = extract_hidden_states_and_logits(
    model=model,
    tokenizer=tokenizer,
    extractor=extractor,
    examples=examples,
    layers=QUARTILE_LAYERS,
    answer_token_ids=ANSWER_TOKEN_IDS,
    batch_size=4,
)

# Report array shapes and the base accuracy obtained from the answer logits.
print("\nExtracted data shapes:")
print(f"  Hidden states: {hidden_states.shape}  (examples, layers, hidden_dim)")
print(f"  Logits: {logits.shape}  (examples, num_choices)")
print(f"  Correctness: {correctness.shape}")
print(f"\nAccuracy (argmax logits): {correctness.mean():.1%}")

## 4. Prepare Data for Different Architectures

In [None]:
from sklearn.model_selection import train_test_split

# For Linear probe: use only middle layer
middle_layer_idx = 1  # Index 1 = Q2 (middle layer) in our 4-layer extraction
X_linear = hidden_states[:, middle_layer_idx, :]  # (examples, hidden_dim)

# For LayerEnsemble: flatten all layers
X_ensemble = hidden_states.reshape(hidden_states.shape[0], -1)  # (examples, num_layers * hidden_dim)

# For MultiSource: flattened hidden states + concatenated logits
X_multisource = np.concatenate([X_ensemble, logits], axis=1)  # (examples, num_layers * hidden_dim + num_choices)

y = correctness

print("Data shapes for each architecture:")
print(f"  Linear (single layer): {X_linear.shape}")
print(f"  LayerEnsemble (multi-layer): {X_ensemble.shape}")
print(f"  MultiSource (multi-layer + logits): {X_multisource.shape}")

# Split data (60% train, 20% val, 20% test).
# The split is computed ONCE on row indices and then applied to every feature
# matrix, so all three architectures are guaranteed to see exactly the same
# examples in each partition (instead of relying on three separate
# train_test_split calls happening to agree).
all_idx = np.arange(len(y))
train_idx, temp_idx = train_test_split(
    all_idx, test_size=0.4, random_state=42, stratify=y
)
val_idx, test_idx = train_test_split(
    temp_idx, test_size=0.5, random_state=42, stratify=y[temp_idx]
)

# The three partitions must cover every example exactly once.
assert len(train_idx) + len(val_idx) + len(test_idx) == len(y)
assert len(np.unique(np.concatenate([train_idx, val_idx, test_idx]))) == len(y)

y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]

# Linear features
X_train_lin, X_val_lin, X_test_lin = X_linear[train_idx], X_linear[val_idx], X_linear[test_idx]

# Ensemble features (same indices)
X_train_ens, X_val_ens, X_test_ens = X_ensemble[train_idx], X_ensemble[val_idx], X_ensemble[test_idx]

# MultiSource features (same indices)
X_train_ms, X_val_ms, X_test_ms = X_multisource[train_idx], X_multisource[val_idx], X_multisource[test_idx]

print(f"\nSplit sizes:")
print(f"  Train: {len(X_train_lin)} (acc: {y_train.mean():.1%})")
print(f"  Val:   {len(X_val_lin)} (acc: {y_val.mean():.1%})")
print(f"  Test:  {len(X_test_lin)} (acc: {y_test.mean():.1%})")

## 5. Define Architectures

In [None]:
from src.probes import (
    CalibratedProbe,
    build_default_network,
    build_layer_ensemble_network,
    build_multi_source_network,
)

# Number of hidden-state layers extracted per example (one per entry in QUARTILE_LAYERS).
NUM_LAYERS = len(QUARTILE_LAYERS)
# Number of answer-choice logits appended to the MultiSource features.
NUM_CHOICES = 4

# Define architectures with their data
# Each entry pairs a zero-argument network factory ("build_fn") with the
# train/val/test feature matrices prepared for that architecture.
ARCHITECTURES = {
    "Linear (middle layer)": {
        # NOTE(review): `hidden_dim` is passed positionally AND as keyword `hidden_dim=None`;
        # this only works if the builder's first parameter has a different name - confirm in src.probes.
        "build_fn": lambda: build_default_network(hidden_dim, hidden_dim=None),
        "X_train": X_train_lin,
        "X_val": X_val_lin,
        "X_test": X_test_lin,
    },
    "LayerEnsemble (4 layers)": {
        # Input is the flattened (num_layers * hidden_dim) hidden-state vector.
        "build_fn": lambda: build_layer_ensemble_network(
            input_dim=NUM_LAYERS * hidden_dim,
            num_layers=NUM_LAYERS,
            layer_probe_hidden=64,
        ),
        "X_train": X_train_ens,
        "X_val": X_val_ens,
        "X_test": X_test_ens,
    },
    "MultiSource (4 layers + logits)": {
        # Features are the flattened hidden states followed by the answer logits;
        # presumably the network splits them back apart using hidden_dim/num_layers/num_choices - verify in src.probes.
        "build_fn": lambda: build_multi_source_network(
            hidden_dim=hidden_dim,
            num_layers=NUM_LAYERS,
            num_choices=NUM_CHOICES,
            use_logits=True,
        ),
        "X_train": X_train_ms,
        "X_val": X_val_ms,
        "X_test": X_test_ms,
    },
}

print(f"Testing {len(ARCHITECTURES)} architectures")

## 6. Train All Architectures

In [None]:
from sklearn.metrics import roc_auc_score, brier_score_loss

def compute_ece(confidences, labels, num_bins=10):
    """Compute Expected Calibration Error over equal-width confidence bins.

    Predictions are grouped into ``num_bins`` bins on [0, 1]. ECE is the
    sample-weighted average of |mean confidence - mean label| per bin.

    Bins are half-open ``(lower, upper]``, except the first bin, which also
    includes its lower edge so that a confidence of exactly 0.0 is counted
    rather than silently dropped from every bin.

    Args:
        confidences: Array-like of predicted probabilities in [0, 1].
        labels: Array-like of binary outcomes aligned with ``confidences``.
        num_bins: Number of equal-width bins.

    Returns:
        The ECE as a float in [0, 1].

    Raises:
        ValueError: If ``confidences`` is empty.
    """
    confidences = np.asarray(confidences, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    if len(confidences) == 0:
        raise ValueError("compute_ece requires at least one prediction")

    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    weighted_gap = 0.0
    for i in range(num_bins):
        lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
        # Close the first bin on the left so confidence == 0.0 is not lost.
        above_lower = (confidences >= lower) if i == 0 else (confidences > lower)
        mask = above_lower & (confidences <= upper)
        count = mask.sum()
        if count > 0:
            bin_conf = confidences[mask].mean()
            bin_acc = labels[mask].mean()
            weighted_gap += count * abs(bin_conf - bin_acc)
    return weighted_gap / len(confidences)

# Training settings
NUM_EPOCHS = 200
PATIENCE = None  # No early stopping

# Per-architecture metrics, test-set confidences and the fitted probe, keyed by architecture name.
# Later cells (tables, plots, summary) read from this dict.
results = {}

for name, config in ARCHITECTURES.items():
    print(f"\n{'='*60}")
    print(f"Training: {name}")
    print('='*60)
    
    # Build network and probe
    # A fresh network is built for every architecture via its zero-argument factory.
    network = config["build_fn"]()
    probe = CalibratedProbe(network=network)
    
    # Count parameters
    num_params = sum(p.numel() for p in probe.parameters())
    print(f"Parameters: {num_params:,}")
    
    # Train
    # The same y_train / y_val labels are used for every architecture; only the
    # feature matrices differ.
    history = probe.fit(
        config["X_train"], y_train,
        config["X_val"], y_val,
        batch_size=32,
        num_epochs=NUM_EPOCHS,
        patience=PATIENCE,
        use_scheduler=True,
        verbose=True,
    )
    
    # Evaluate on test set
    # NOTE(review): assumes probe.predict returns a 1-D array of probabilities
    # aligned with y_test - confirm against CalibratedProbe.
    confidences = probe.predict(config["X_test"])
    # Threshold at 0.5 to turn the predicted confidence into a correct/incorrect guess.
    predictions = (confidences > 0.5).astype(int)
    
    # Compute metrics
    accuracy = (predictions == y_test).mean()
    auroc = roc_auc_score(y_test, confidences)
    brier = brier_score_loss(y_test, confidences)
    ece = compute_ece(confidences, y_test)
    
    results[name] = {
        "accuracy": accuracy,
        "auroc": auroc,
        "brier": brier,
        "ece": ece,
        "num_params": num_params,
        # presumably `history` is a mapping that contains a "best_epoch" entry - verify against CalibratedProbe.fit
        "best_epoch": history["best_epoch"],
        "confidences": confidences,
        "probe": probe,
    }
    
    print(f"\nTest Results:")
    print(f"  Accuracy: {accuracy:.3f}")
    print(f"  AUROC:    {auroc:.3f}")
    print(f"  Brier:    {brier:.4f}")
    print(f"  ECE:      {ece:.4f}")

## 7. Results Comparison

In [None]:
import pandas as pd

# Build one column of pre-formatted metric strings per architecture, then
# transpose so each architecture becomes a row.
formatted_metrics = {}
for arch_name, arch_result in results.items():
    formatted_metrics[arch_name] = {
        "Accuracy": format(arch_result['accuracy'], ".3f"),
        "AUROC": format(arch_result['auroc'], ".3f"),
        "Brier Score": format(arch_result['brier'], ".4f"),
        "ECE": format(arch_result['ece'], ".4f"),
        "Parameters": format(arch_result['num_params'], ","),
    }
df = pd.DataFrame(formatted_metrics).T

table_rule = "=" * 70
print("\n" + table_rule)
print("ARCHITECTURE COMPARISON")
print(table_rule)
print(df.to_string())

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('seaborn-v0_8-whitegrid')
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

names = list(results.keys())
colors = ['#3498db', '#2ecc71', '#e74c3c']  # Blue, Green, Red

# One entry per panel, replacing three copy-pasted plotting stanzas:
# (results key, x-axis label, title, value format, label x-offset, x-limits or None)
panel_specs = [
    ("auroc", "AUROC (higher is better)", "Discrimination: AUROC", ".3f", 0.01, (0.5, 1.0)),
    ("brier", "Brier Score (lower is better)", "Calibration: Brier Score", ".4f", 0.005, None),
    ("ece", "ECE (lower is better)", "Calibration: ECE", ".4f", 0.005, None),
]

for panel_ax, (metric_key, xlabel, title, value_fmt, text_offset, xlim) in zip(axes, panel_specs):
    values = [results[n][metric_key] for n in names]
    bars = panel_ax.barh(names, values, color=colors)
    panel_ax.set_xlabel(xlabel)
    panel_ax.set_title(title)
    if xlim is not None:
        # AUROC is only meaningful between chance (0.5) and perfect (1.0).
        panel_ax.set_xlim(*xlim)
    # Annotate each bar with its value, just past the bar's end.
    for bar, val in zip(bars, values):
        panel_ax.text(
            val + text_offset,
            bar.get_y() + bar.get_height()/2,
            format(val, value_fmt),
            va='center',
        )

plt.tight_layout()
plt.savefig("multi_source_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: multi_source_comparison.png")

## 8. Reliability Diagrams

In [None]:
def plot_reliability_diagram(confidences, labels, title, ax, num_bins=10):
    """Plot a reliability diagram (per-bin accuracy vs. confidence) on ``ax``.

    Predictions are grouped into ``num_bins`` equal-width confidence bins on
    [0, 1]. Each non-empty bin is drawn as a bar whose height is the mean
    label in that bin; empty bins get a NaN height. The diagonal marks perfect
    calibration.

    Bins are half-open ``(lower, upper]``, except the first bin, which also
    includes its lower edge so that a confidence of exactly 0.0 is plotted
    instead of being dropped.

    Args:
        confidences: Array-like of predicted probabilities in [0, 1].
        labels: Array-like of binary outcomes aligned with ``confidences``.
        title: Title for the axes.
        ax: Matplotlib axes to draw on.
        num_bins: Number of equal-width bins.
    """
    confidences = np.asarray(confidences)
    labels = np.asarray(labels)

    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2

    bin_accs = []
    for i in range(num_bins):
        lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
        # Close the first bin on the left so confidence == 0.0 is not lost.
        above_lower = (confidences >= lower) if i == 0 else (confidences > lower)
        mask = above_lower & (confidences <= upper)
        # NaN marks an empty bin.
        bin_accs.append(labels[mask].mean() if mask.sum() > 0 else np.nan)

    ax.bar(bin_centers, bin_accs, width=0.08, alpha=0.7, label='Accuracy')
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    ax.set_xlabel('Confidence')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(loc='lower right')

# One panel per trained architecture: size the figure from `results` instead of
# hard-coding three axes, so adding or removing an architecture cannot raise an
# IndexError or leave an empty panel.
num_panels = max(1, len(results))
fig, axes = plt.subplots(1, num_panels, figsize=(5 * num_panels, 4), squeeze=False)
axes = axes[0]  # squeeze=False always returns a 2-D array; take the single row

for i, (name, r) in enumerate(results.items()):
    plot_reliability_diagram(
        r["confidences"], y_test,
        f"{name}\nECE={r['ece']:.4f}",
        axes[i]
    )

plt.tight_layout()
plt.savefig("multi_source_reliability.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: multi_source_reliability.png")

## 9. Analyze MultiSource Layer Weights

In [None]:
# Get learned layer weights from MultiSource probe
ms_probe = results["MultiSource (4 layers + logits)"]["probe"]
# Detach and move to CPU before converting: Tensor.numpy() raises for tensors
# that live on the GPU or that require grad.
layer_weights = ms_probe.network.get_layer_weights().detach().cpu().numpy()

print("Learned Layer Weights (MultiSource):")
print("="*40)
layer_names = ["Early (Q1)", "Middle (Q2)", "Late (Q3)", "Final (Q4)"]
for layer_idx, layer_name, weight in zip(QUARTILE_LAYERS, layer_names, layer_weights):
    print(f"  Layer {layer_idx:2d} ({layer_name:12s}): {weight:.4f}")

# Visualize
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

bar_labels = [f"Layer {idx}\n({name})" for idx, name in zip(QUARTILE_LAYERS, layer_names)]
# Separate name so the `colors` palette used by the comparison plots is not overwritten.
layer_colors = ['lightblue', 'blue', 'darkblue', 'navy']

ax.bar(bar_labels, layer_weights, color=layer_colors, edgecolor='black', linewidth=1.5)
ax.set_ylabel('Learned Weight', fontsize=12)
ax.set_title('MultiSource: Learned Layer Importance', fontsize=14, fontweight='bold')
ax.set_ylim([0, max(layer_weights) * 1.2])
ax.grid(axis='y', alpha=0.3)

# Print each weight just above its bar.
for i, weight in enumerate(layer_weights):
    ax.text(i, weight + 0.01, f'{weight:.3f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig("multi_source_layer_weights.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: multi_source_layer_weights.png")

## 10. Summary

In [None]:
summary_rule = "=" * 70

print(summary_rule)
print("SUMMARY: Multi-Source Confidence Probe")
print(summary_rule)

# Pick the winning architecture per metric (highest AUROC, lowest Brier / ECE).
best_auroc_name = max(results, key=lambda arch: results[arch]['auroc'])
best_brier_name = min(results, key=lambda arch: results[arch]['brier'])
best_ece_name = min(results, key=lambda arch: results[arch]['ece'])

best_auroc_value = results[best_auroc_name]['auroc']
print(f"\nBest AUROC:  {best_auroc_name}")
print(f"  AUROC = {best_auroc_value:.4f}")

best_brier_value = results[best_brier_name]['brier']
print(f"\nBest Brier:  {best_brier_name}")
print(f"  Brier = {best_brier_value:.4f}")

best_ece_value = results[best_ece_name]['ece']
print(f"\nBest ECE:    {best_ece_name}")
print(f"  ECE = {best_ece_value:.4f}")

# Compare MultiSource to the two baselines.
ms_results = results["MultiSource (4 layers + logits)"]
lin_results = results["Linear (middle layer)"]
ens_results = results["LayerEnsemble (4 layers)"]

print("\n" + summary_rule)
print("MultiSource vs Baselines:")
print(summary_rule)

# Deltas are oriented so that a positive number always favours MultiSource.
auroc_gain_vs_linear = ms_results['auroc'] - lin_results['auroc']
auroc_gain_vs_ensemble = ms_results['auroc'] - ens_results['auroc']
brier_gain_vs_linear = lin_results['brier'] - ms_results['brier']
brier_gain_vs_ensemble = ens_results['brier'] - ms_results['brier']

print(f"\nAUROC improvement over Linear:      {auroc_gain_vs_linear:+.4f}")
print(f"AUROC improvement over LayerEnsemble: {auroc_gain_vs_ensemble:+.4f}")

print(f"\nBrier improvement over Linear:      {brier_gain_vs_linear:+.4f} (lower is better)")
print(f"Brier improvement over LayerEnsemble: {brier_gain_vs_ensemble:+.4f}")

print("\n" + summary_rule)
print("KEY INSIGHT:")
print(summary_rule)

# Choose the two-line takeaway based on where MultiSource ranks by AUROC.
best_baseline_auroc = max(lin_results['auroc'], ens_results['auroc'])
if ms_results['auroc'] > best_baseline_auroc:
    insight_lines = (
        "MultiSource (hidden states + logits) outperforms both baselines!",
        "→ Combining internal uncertainty with expressed confidence helps.",
    )
elif ms_results['auroc'] > lin_results['auroc']:
    insight_lines = (
        "MultiSource outperforms single-layer Linear but not LayerEnsemble.",
        "→ Multi-layer information helps, but logits may not add much.",
    )
else:
    insight_lines = (
        "Simple Linear probe performs competitively.",
        "→ Task may be linearly separable in hidden space.",
    )
for insight_line in insight_lines:
    print(insight_line)