# Calibrated Confidence Probe - Model Self-Knowledge

This notebook trains a probe to predict whether the model's **own generated answer** is correct.

**Key differences from basic layer analysis:**
1. Model generates its own answers (not pre-defined correct/incorrect)
2. Probe predicts confidence in model's own answer
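3. Uses Brier score loss: `(confidence - correct)^2`, where `correct` is 1 if the model's answer was right and 0 otherwise
   - High confidence + wrong = **very bad** (high loss)
   - Low confidence + wrong = acceptable (low loss)
   - High confidence + correct = good (low loss)
   - Low confidence + correct = penalized (high loss)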

**Before running**: Runtime > Change runtime type > GPU (T4 or better)

In [None]:
# Check GPU: print the NVIDIA driver / device status table for this runtime.
# If this reports an error, the runtime likely has no GPU attached
# (see the "Before running" note at the top of the notebook).
!nvidia-smi

In [None]:
# Install dependencies
!pip install -q torch transformers accelerate bitsandbytes datasets scikit-learn tqdm loguru matplotlib seaborn

In [None]:
# Clone repository
# Start from /content so the checkout always lands in the same place.
%cd /content
# Remove any earlier checkout so re-running this cell starts from a fresh clone.
!rm -rf deep-learning
!git clone https://github.com/joshcliu/deep-learning.git
# Enter the repo root: the later `sys.path.insert(0, '.')` and `src.*` imports
# are resolved relative to this working directory.
%cd deep-learning

In [None]:
# Setup path and imports
import sys
sys.path.insert(0, '.')

import numpy as np
import torch
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

# Environment report: library version, CUDA visibility and, when a GPU is
# present, the name of device 0.
env_report = [
    f"PyTorch: {torch.__version__}",
    f"CUDA: {torch.cuda.is_available()}",
]
if torch.cuda.is_available():
    env_report.append(f"GPU: {torch.cuda.get_device_name(0)}")
for report_line in env_report:
    print(report_line)

In [None]:
# Configuration
MODEL_NAME = "mistralai/Mistral-7B-v0.1"  # Model id handed to ModelLoader
NUM_SAMPLES = 300  # More samples for better training
QUANTIZATION = "8bit"  # The string "none" disables quantization when the model is loaded
PROBE_LAYERS = [8, 16, 24, 31]  # Layers to probe
MAX_NEW_TOKENS = 10  # For answer generation
SEED = 42  # Same value the train/validation split passes as random_state

# Seed the global RNGs so a fresh "Restart & Run All" reproduces the same run:
# NumPy's global generator drives the question sampling (np.random.choice) and
# the scatter-plot jitter (np.random.randn); the torch seed covers any
# torch-based randomness in probe training (presumably weight init / shuffling).
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load model
from src.models import ModelLoader

# Announce the load, then build the loader for the configured checkpoint.
print(f"Loading {MODEL_NAME}...")
loader = ModelLoader(model_name=MODEL_NAME)

# The config uses the string "none" to mean "no quantization"; the loader
# receives None in that case and the configured string otherwise.
if QUANTIZATION == "none":
    quantization = None
else:
    quantization = QUANTIZATION
model, tokenizer = loader.load(quantization=quantization, device_map="auto")

# Keep the architecture sizes around: hidden_dim is reused later as the
# probe's input dimension.
model_info = loader.get_model_info()
num_layers, hidden_dim = model_info["num_layers"], model_info["hidden_dim"]
print(f"Loaded: {num_layers} layers, hidden_dim={hidden_dim}")

In [None]:
# Load MMLU dataset
from src.data import MMLUDataset

# Build the project's MMLU dataset wrapper on the "test" split and report how
# many questions it exposes; the answer-generation cell samples from it.
print("Loading MMLU...")
dataset = MMLUDataset(split="test")
print(f"Dataset: {len(dataset)} questions")

In [None]:
# Patch extractor for bfloat16 compatibility
import src.models.extractor as extractor_module

def patched_extract_batch(self, texts, layers, max_length, token_position):
    """Extract one hidden-state vector per text and layer as a float32 array.

    Drop-in replacement for ``HiddenStateExtractor._extract_batch`` that works
    when the model emits bfloat16/float16 activations, which NumPy cannot
    hold directly.

    Args:
        self: Object exposing ``tokenizer``, ``model`` and ``device``.
        texts: List of input strings forming one batch.
        layers: Iterable of layer indices. Layer ``i`` is read from
            ``outputs.hidden_states[i + 1]`` (entry 0 is skipped; in Hugging
            Face models that entry is the embedding output).
        max_length: Truncation length forwarded to the tokenizer.
        token_position: ``"last"`` (final attended token of each sequence),
            ``"cls"`` (position 0) or ``"mean"`` (attention-masked average).

    Returns:
        ``np.ndarray`` of dtype float32 with shape
        ``(len(texts), len(layers), hidden_size)``.

    Raises:
        ValueError: If ``token_position`` is not a supported value.
    """
    # Fail before the (expensive) forward pass rather than with an unbound
    # local variable afterwards.
    if token_position not in ("last", "cls", "mean"):
        raise ValueError(
            f"Unsupported token_position {token_position!r}; "
            "expected 'last', 'cls' or 'mean'"
        )

    encodings = self.tokenizer(
        texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt"
    )
    encodings = {k: v.to(self.device) for k, v in encodings.items()}

    # Inference only: no autograd bookkeeping is needed for probing.
    with torch.no_grad():
        outputs = self.model(**encodings, output_hidden_states=True, return_dict=True)

    attention_mask = encodings["attention_mask"]

    # Index of the last attended token in every row. Weighting the mask by
    # (position + 1) and taking the argmax is correct for both right- and
    # left-padded batches, unlike `mask.sum() - 1`, which assumes right padding.
    positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
    last_positions = ((positions + 1) * attention_mask).argmax(dim=1)

    batch_hiddens = []
    for layer_idx in layers:
        layer_hiddens = outputs.hidden_states[layer_idx + 1]
        # With sharded models a layer's output may live on another device
        # than the inputs, so index tensors follow the layer output.
        layer_device = layer_hiddens.device

        if token_position == "last":
            rows = torch.arange(layer_hiddens.size(0), device=layer_device)
            token_hiddens = layer_hiddens[rows, last_positions.to(layer_device)]
        elif token_position == "cls":
            token_hiddens = layer_hiddens[:, 0, :]
        else:  # "mean"
            # Average in float32 to avoid half-precision overflow/rounding;
            # the clamp protects rows whose mask is entirely zero.
            mask = attention_mask.to(layer_device).unsqueeze(-1).float()
            summed = (layer_hiddens.float() * mask).sum(dim=1)
            token_hiddens = summed / mask.sum(dim=1).clamp(min=1.0)

        # Cast to float32 on the tensor side: NumPy has no bfloat16, and this
        # avoids the slow round-trip through Python lists.
        batch_hiddens.append(token_hiddens.detach().float().cpu().numpy())

    return np.stack(batch_hiddens, axis=1)

# Install the replacement on the class itself so every HiddenStateExtractor
# instance created afterwards uses it.
setattr(extractor_module.HiddenStateExtractor, "_extract_batch", patched_extract_batch)
print("Extractor patched for bfloat16 compatibility")

In [None]:
def generate_model_answers(model, tokenizer, dataset, num_samples, max_new_tokens=10):
    """
    Have the model answer multiple-choice questions and grade each answer.

    A random subset of ``dataset`` (drawn without replacement from NumPy's
    global RNG) is formatted as lettered multiple-choice prompts, answered
    with greedy decoding, and graded against the ground truth.

    Grading: if the answer opens with a stand-alone choice letter (``B``,
    ``B.``, ``(B)`` ...), only that letter is compared with the correct one.
    Otherwise the answer counts as correct when it contains the correct
    choice text (case-insensitive).

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Matching tokenizer (callable, with ``decode`` and
            ``eos_token_id``).
        dataset: Indexable collection whose items expose ``question``,
            ``choices`` and ``answer`` (index of the correct choice).
        num_samples: Maximum number of questions to sample.
        max_new_tokens: Generation budget per question.

    Returns:
        prompts: List of prompts (question + model's answer)
        correctness: ``np.ndarray`` of binary labels (1 if model was correct, 0 otherwise)
        model_answers: The generated answers
        correct_answers: The ground truth answers
    """
    import re  # local import keeps this cell independent of the import cell

    sample_size = min(num_samples, len(dataset))
    indices = np.random.choice(len(dataset), sample_size, replace=False)

    # A leading choice letter, optionally parenthesised, that is not merely
    # the first letter of a longer word ("Carbon" must not count as "C").
    leading_letter = re.compile(r"\(?([A-Za-z])\)?(?!\w)")

    prompts = []
    correctness = []
    model_answers = []
    correct_answers = []

    for idx in tqdm(indices, desc="Generating answers"):
        # Plain int: not every __getitem__ accepts NumPy integer scalars.
        example = dataset[int(idx)]
        question = example.question
        choices = example.choices
        correct_idx = example.answer
        correct_answer = choices[correct_idx]

        # Format as multiple choice
        choice_str = "\n".join([f"{chr(65+i)}. {c}" for i, c in enumerate(choices)])
        prompt = f"Question: {question}\n{choice_str}\nAnswer:"

        # Generate model's answer
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_token_count = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # Greedy decoding
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Slicing the decoded string by
        # len(prompt) is unreliable because decoding does not always reproduce
        # the prompt text character-for-character.
        new_tokens = outputs[0][prompt_token_count:]
        model_answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Grade: an explicit letter choice wins; fall back to matching the
        # answer text only when no valid letter was given.
        correct_letter = chr(65 + correct_idx)  # A, B, C, D
        valid_letters = {chr(65 + i) for i in range(len(choices))}
        letter_match = leading_letter.match(model_answer)
        chosen_letter = letter_match.group(1).upper() if letter_match else None
        if chosen_letter in valid_letters:
            is_correct = chosen_letter == correct_letter
        else:
            is_correct = correct_answer.lower() in model_answer.lower()

        # Store the full prompt + answer for hidden state extraction
        full_text = f"{prompt} {model_answer}"

        prompts.append(full_text)
        correctness.append(1 if is_correct else 0)
        model_answers.append(model_answer)
        correct_answers.append(correct_answer)

    return prompts, np.array(correctness), model_answers, correct_answers

# Generate and grade answers for a random subset of the dataset, then report
# how often the model was right and the resulting label balance.
print("Generating model answers...")
generation_outputs = generate_model_answers(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    num_samples=NUM_SAMPLES,
    max_new_tokens=MAX_NEW_TOKENS,
)
prompts, correctness, model_answers, correct_answers = generation_outputs

accuracy = correctness.mean()
print(f"\nModel accuracy: {accuracy:.1%} ({correctness.sum()}/{len(correctness)} correct)")
print(f"Label distribution: {np.bincount(correctness)}")

In [None]:
# Preview the first few generations next to the ground truth.
separator = "=" * 60
print("\n" + separator)
print("SAMPLE GENERATIONS")
print(separator)

preview_count = min(5, len(prompts))
for i in range(preview_count):
    if correctness[i]:
        status = "CORRECT"
    else:
        status = "WRONG"
    print(f"\n[{status}] Model: '{model_answers[i]}' | Truth: '{correct_answers[i]}'")

In [None]:
# Extract hidden states
from src.models import HiddenStateExtractor

print(f"\nExtracting hidden states from layers {PROBE_LAYERS}...")
extractor = HiddenStateExtractor(model, tokenizer)

# One pass over the prompts collects every probed layer. Each forward pass
# already produces all hidden states, so calling extract() once per layer
# repeated the whole model run len(PROBE_LAYERS) times for no benefit.
stacked_hiddens = extractor.extract(
    texts=prompts,
    layers=PROBE_LAYERS,
    max_length=512,
    token_position="last",
    batch_size=8,  # Smaller batch for memory
)

# Axis 1 is expected to index the requested layers in the order given
# (the single-layer call was sliced with [:, 0, :]); fail loudly if not.
assert stacked_hiddens.shape[1] == len(PROBE_LAYERS), (
    f"expected {len(PROBE_LAYERS)} layers on axis 1, got shape {stacked_hiddens.shape}"
)

all_hiddens = {}
for layer_pos, layer in enumerate(PROBE_LAYERS):
    all_hiddens[layer] = stacked_hiddens[:, layer_pos, :]
    print(f"Layer {layer}: {all_hiddens[layer].shape}")

In [None]:
# Split sample indices (not the features) so the same partition can be reused
# for every layer's hidden states. Stratifying keeps the correct/wrong ratio
# the same in both parts.
print("\nSplitting data (70/30)...")
indices = np.arange(len(correctness))
split_options = dict(test_size=0.3, random_state=42, stratify=correctness)
train_idx, val_idx = train_test_split(indices, **split_options)
print(f"Train: {len(train_idx)}, Val: {len(val_idx)}")
print(f"Train correct: {correctness[train_idx].sum()}, Val correct: {correctness[val_idx].sum()}")

In [None]:
# Train calibrated probes
from src.probes import CalibratedProbe
from src.evaluation import compute_ece, compute_auroc

print(f"\nTraining calibrated probes...")
results = {}

for layer in tqdm(PROBE_LAYERS, desc="Training probes"):
    hiddens = all_hiddens[layer]

    X_train, X_val = hiddens[train_idx], hiddens[val_idx]
    y_train, y_val = correctness[train_idx], correctness[val_idx]

    # Train with MLP probe and Brier loss
    probe = CalibratedProbe(
        input_dim=hidden_dim,
        hidden_dim=256,  # MLP with hidden layer
        dropout=0.1,
        lr=1e-3,
    )

    history = probe.fit(
        X_train, y_train,
        X_val, y_val,
        num_epochs=100,
        patience=15,
        batch_size=32,
        verbose=False,
    )

    # Evaluate. Flattening is a no-op for 1-D output and prevents silent
    # (N, 1) vs (N,) broadcasting in the metrics below.
    val_conf = np.ravel(probe.predict(X_val))
    val_preds = (val_conf > 0.5).astype(int)

    # Metrics. The probe's accuracy gets its own name: the global `accuracy`
    # holds the *model's* answer accuracy and must not be overwritten here.
    probe_accuracy = (val_preds == y_val).mean()
    brier = ((val_conf - y_val) ** 2).mean()
    ece, _ = compute_ece(val_conf, val_preds, y_val)
    auroc = compute_auroc(val_conf, y_val)

    results[layer] = {
        "accuracy": probe_accuracy,
        "brier": brier,
        "ece": ece,
        "auroc": auroc,
        "best_epoch": history["best_epoch"],
        "probe": probe,
        "val_conf": val_conf,
    }

    print(f"Layer {layer:2d}: Acc={probe_accuracy:.3f}, Brier={brier:.3f}, ECE={ece:.3f}, AUROC={auroc:.3f}")

In [None]:
# Visualize results
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

layers = sorted(results.keys())
tick_positions = range(len(layers))
tick_labels = [f"L{l}" for l in layers]

# One row per panel:
# (key in results[layer], y-axis label, title, bar colour, chance-level line or None)
panels = [
    ("accuracy", "Accuracy", "Probe Accuracy by Layer", "steelblue", 0.5),
    ("brier", "Brier Score (lower=better)", "Brier Score by Layer", "coral", None),
    ("ece", "ECE (lower=better)", "Expected Calibration Error by Layer", "forestgreen", None),
    ("auroc", "AUROC", "AUROC by Layer", "purple", 0.5),
]

# axes.flat walks the grid row by row, matching the order of `panels`.
for ax, (metric, ylabel, title, color, baseline) in zip(axes.flat, panels):
    ax.bar(tick_positions, [results[l][metric] for l in layers], color=color)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if baseline is not None:
        ax.axhline(y=baseline, color='red', linestyle='--', alpha=0.5, label='Random')
        # Without a legend the 'Random' label on the reference line is never shown.
        ax.legend()

plt.tight_layout()
plt.savefig("calibrated_probe_results.png", dpi=150)
plt.show()

In [None]:
# Reliability diagram for best layer (the layer with the lowest validation Brier score)
best_layer = min(results.keys(), key=lambda l: results[l]["brier"])
best_conf = results[best_layer]["val_conf"]
y_val = correctness[val_idx]

# Compute calibration curve: observed accuracy within equal-width confidence bins
n_bins = 10
bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_centers = []
bin_accuracies = []
bin_counts = []

for i in range(n_bins):
    lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
    if i == n_bins - 1:
        # Close the last bin on the right so a confidence of exactly 1.0
        # is counted instead of silently dropped.
        mask = (best_conf >= lower) & (best_conf <= upper)
    else:
        mask = (best_conf >= lower) & (best_conf < upper)
    # Empty bins are skipped: they have no accuracy to plot.
    if mask.sum() > 0:
        bin_centers.append((lower + upper) / 2)
        bin_accuracies.append(y_val[mask].mean())
        bin_counts.append(mask.sum())

fig, ax = plt.subplots(figsize=(8, 6))

# Perfect calibration line
ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')

# Actual calibration; bars take 80% of a bin's width so neighbours stay separated
ax.bar(bin_centers, bin_accuracies, width=0.8 / n_bins, alpha=0.7, label='Probe calibration')

ax.set_xlabel('Predicted Confidence')
ax.set_ylabel('Actual Accuracy')
ax.set_title(f'Reliability Diagram - Layer {best_layer}\n(Brier={results[best_layer]["brier"]:.3f}, ECE={results[best_layer]["ece"]:.3f})')
ax.legend()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig("reliability_diagram.png", dpi=150)
plt.show()

In [None]:
# Summary
print("\n" + "="*60)
print("CALIBRATED PROBE SUMMARY")
print("="*60)

# Best layer per metric (lower is better for Brier and ECE, higher for AUROC).
best_brier_layer = min(results.keys(), key=lambda l: results[l]["brier"])
best_auroc_layer = max(results.keys(), key=lambda l: results[l]["auroc"])
best_ece_layer = min(results.keys(), key=lambda l: results[l]["ece"])

# Recompute the model's accuracy from the labels. The global `accuracy` may
# have been overwritten by the probe-training loop and would then report the
# last probe's accuracy instead of the model's.
model_accuracy = correctness.mean()
print(f"\nModel base accuracy: {model_accuracy:.1%}")
print(f"\nBest Brier Score:  Layer {best_brier_layer} ({results[best_brier_layer]['brier']:.3f})")
print(f"Best AUROC:        Layer {best_auroc_layer} ({results[best_auroc_layer]['auroc']:.3f})")
print(f"Best ECE:          Layer {best_ece_layer} ({results[best_ece_layer]['ece']:.3f})")

print("\n" + "-"*60)
print("Layer-by-Layer Results:")
print("-"*60)
for layer in sorted(results.keys()):
    r = results[layer]
    print(f"Layer {layer:2d}: Acc={r['accuracy']:.3f} | Brier={r['brier']:.3f} | ECE={r['ece']:.3f} | AUROC={r['auroc']:.3f}")

In [None]:
# Analyze confidence distribution for the layer with the best Brier score
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

best_conf = results[best_brier_layer]["val_conf"]
y_val = correctness[val_idx]
conf_when_correct = best_conf[y_val == 1]
conf_when_wrong = best_conf[y_val == 0]

# Confidence distribution by correctness. Both histograms share one set of
# bin edges over [0, 1]; with bins=20 each group would get its own edges and
# the overlaid bar heights would not be comparable.
# (Assumes confidences are probabilities in [0, 1].)
shared_bins = np.linspace(0, 1, 21)
axes[0].hist(conf_when_correct, bins=shared_bins, alpha=0.7, label='Correct', color='green')
axes[0].hist(conf_when_wrong, bins=shared_bins, alpha=0.7, label='Wrong', color='red')
axes[0].set_xlabel('Predicted Confidence')
axes[0].set_ylabel('Count')
axes[0].set_title(f'Confidence Distribution (Layer {best_brier_layer})')
axes[0].legend()

# Confidence vs actual: points sit at y=1 (correct) or y=0 (wrong) with a
# little vertical jitter so overlapping points remain visible.
jitter_correct = np.random.randn(len(conf_when_correct)) * 0.05
jitter_wrong = np.random.randn(len(conf_when_wrong)) * 0.05
axes[1].scatter(conf_when_correct, 1 + jitter_correct,
                alpha=0.5, label='Correct', color='green', s=20)
axes[1].scatter(conf_when_wrong, 0 + jitter_wrong,
                alpha=0.5, label='Wrong', color='red', s=20)
axes[1].set_xlabel('Predicted Confidence')
axes[1].set_ylabel('Actual Correctness')
axes[1].set_title('Confidence vs Correctness')
axes[1].legend()

plt.tight_layout()
plt.savefig("confidence_analysis.png", dpi=150)
plt.show()

In [None]:
# Download results
from google.colab import files
# Trigger a browser download for each figure written by the plotting cells.
for figure_path in (
    "calibrated_probe_results.png",
    "reliability_diagram.png",
    "confidence_analysis.png",
):
    files.download(figure_path)