# Hierarchical Multi-Scale Probe Analysis

This notebook analyzes the results from the hierarchical probing experiment and compares
them against baseline methods (linear and MLP probes).

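## Contents
1. Load experiment results and probes
2. Load the MMLU dataset and inspect its category distribution
3. Calibration analysis and reliability diagrams
4. Metrics comparison against the baselines (linear and MLP probes)
5. Hierarchical level analysis (token, span, semantic, global confidence distributions)
6. Performance by subject/category
7. ECE improvement over the linear baseline
8. Selective prediction (coverage vs. accuracy)
9. Model complexity vs. performance
10. Conclusions and next steps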

In [None]:
import sys
from pathlib import Path

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from src.utils import load_config
from src.models import ModelLoader, HiddenStateExtractor
from src.data import MMLUDataset
from src.probes import HierarchicalProbe, LinearProbe, MLPProbe
from src.evaluation import CalibrationMetrics
from src.evaluation.calibration import plot_reliability_diagram, plot_roc_curve

# Configure plotting
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 10

%matplotlib inline
%load_ext autoreload
%autoreload 2

## 1. Load Experiment Results

In [None]:
# Root directory where the experiment run wrote its probes and artifacts.
# NOTE(review): the run-directory name presumably matches config.experiment.name;
# confirm and derive it from the config so the two cannot drift apart.
experiment_dir = project_root.joinpath(
    "outputs", "hierarchical", "llama-3.1-8b-hierarchical-probe"
)

# Same YAML configuration the experiment was launched with
config = load_config(project_root.joinpath("configs", "hierarchical_probe.yaml"))

print(
    f"Experiment: {config.experiment.name}\n"
    f"Model: {config.model.name}\n"
    f"Dataset: {config.data.dataset}"
)

In [None]:
# Each probe type stores its weights at <experiment_dir>/<probe type>/probe.pt
hierarchical_probe_path, linear_probe_path, mlp_probe_path = (
    experiment_dir / probe_kind / "probe.pt"
    for probe_kind in ("hierarchical", "linear", "mlp")
)

# Report which artifacts are on disk before anything tries to load them
for display_name, probe_file in (
    ("Hierarchical", hierarchical_probe_path),
    ("Linear", linear_probe_path),
    ("MLP", mlp_probe_path),
):
    print(f"{display_name} probe exists: {probe_file.exists()}")

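## 2. Load Dataset

Load the MMLU test split used by the experiment and inspect its category distribution. Test-set predictions are not generated in this notebook; the experiment run must save them.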

In [None]:
# MMLU test split, optionally restricted to the configured subjects and capped
# at the configured number of samples (both fall back to None when absent).
dataset = MMLUDataset(
    split="test",
    subjects=config.data.get("subjects", None),
    max_samples=config.data.get("num_samples", None),
)
print(f"Loaded {len(dataset)} examples")

# Summary statistics reported by the dataset wrapper
stats = dataset.get_statistics()
print("\nCategories: {}".format(list(stats["category_counts"].keys())))
print("Number of subjects: {}".format(stats["num_subjects"]))

In [None]:
# Bar chart of how many examples fall into each MMLU category
category_counts = stats["category_counts"]
categories = list(category_counts.keys())
counts = list(category_counts.values())

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(categories, counts)
ax.set(
    xlabel="Category",
    ylabel="Number of Examples",
    title="MMLU Dataset Distribution by Category",
)
# Slant the category names so long labels do not overlap
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
plt.show()

## 3. Calibration Analysis

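Visualize calibration quality using reliability diagrams. A reliability diagram groups predictions into confidence bins and plots each bin's observed accuracy against its mean confidence. A perfectly calibrated probe lies on the diagonal.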

In [None]:
# Load the per-example test predictions saved by the experiment run.
# NOTE(review): the experiment must write these files; confirm the names match
# what the training script saves. When any file is missing, this cell only
# reports which ones, so Restart & Run All still succeeds.
prediction_paths = {
    "hierarchical": experiment_dir / "hierarchical" / "test_confidences.npy",
    "linear": experiment_dir / "linear" / "test_confidences.npy",
    "mlp": experiment_dir / "mlp" / "test_confidences.npy",
    "labels": experiment_dir / "test_labels.npy",
}
missing_prediction_files = [
    str(path) for path in prediction_paths.values() if not path.exists()
]

if not missing_prediction_files:
    hierarchical_confidences = np.load(prediction_paths["hierarchical"])
    linear_confidences = np.load(prediction_paths["linear"])
    mlp_confidences = np.load(prediction_paths["mlp"])
    test_labels = np.load(prediction_paths["labels"])
    # Every probe must score exactly the examples the label array describes.
    assert (
        len(hierarchical_confidences)
        == len(linear_confidences)
        == len(mlp_confidences)
        == len(test_labels)
    ), "probe predictions and labels must cover the same test examples"
    print(f"Loaded test predictions for {len(test_labels)} examples")
else:
    print("NOTE: Load actual prediction data from the experiment directory")
    print("Missing files:", *missing_prediction_files, sep="\n  ")

In [None]:
# Reliability diagrams for each probe, side by side.
# Runs only when the test predictions are loaded; otherwise it prints a note, so
# the cell is safe under Restart & Run All.
_prediction_names = (
    "hierarchical_confidences",
    "linear_confidences",
    "mlp_confidences",
    "test_labels",
)
if all(name in globals() for name in _prediction_names):
    probe_confidences = {
        "Hierarchical Probe": hierarchical_confidences,
        "Linear Probe": linear_confidences,
        "MLP Probe": mlp_confidences,
    }
    fig, axes = plt.subplots(1, len(probe_confidences), figsize=(18, 5))
    for ax, (title, confidences) in zip(axes, probe_confidences.items()):
        # NOTE(review): argument order (predictions, confidences, labels) is taken
        # from the previously commented-out call; confirm against
        # src.evaluation.calibration.plot_reliability_diagram.
        # Binary prediction = confidence above 0.5.
        plot_reliability_diagram(
            (confidences > 0.5).astype(int),
            confidences,
            test_labels,
            num_bins=10,
            ax=ax,
        )
        ax.set_title(title)
    plt.tight_layout()
    plt.show()
else:
    print("NOTE: reliability diagrams skipped - test predictions are not loaded")

## 4. Metrics Comparison

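Compare ECE, Brier score, AUROC, and accuracy across methods. Lower is better for ECE and Brier; higher is better for AUROC and accuracy. **Note:** the values below are illustrative placeholders, not experiment output.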

In [None]:
# Illustrative placeholder metrics, NOT produced by the experiment.
# Replace each row with the metrics computed from the saved test predictions.
metric_columns = ["Method", "ECE", "Brier", "AUROC", "Accuracy", "Params"]
metric_rows = [
    ("Hierarchical", 0.0450, 0.1200, 0.8650, 0.8420, 450_000),
    ("Linear",       0.0820, 0.1580, 0.8120, 0.8200, 4_100),
    ("MLP",          0.0680, 0.1420, 0.8350, 0.8310, 2_100_000),
]
# Column-oriented view: column name -> list of per-method values
results = {
    column: list(values)
    for column, values in zip(metric_columns, zip(*metric_rows))
}

df = pd.DataFrame(results)
print(f"\nPerformance Comparison:\n{df.to_string(index=False)}")

In [None]:
# One bar chart per metric, methods side by side.
# Lower is better for ECE and Brier; higher is better for AUROC and Accuracy.
metrics = ["ECE", "Brier", "AUROC", "Accuracy"]
colors = ['red', 'blue', 'green']

fig, axes = plt.subplots(1, len(metrics), figsize=(18, 4))

for ax, metric in zip(axes, metrics):
    bars = ax.bar(df["Method"], df[metric], color=colors)
    ax.set_ylabel(metric)
    ax.set_title(f"{metric} Comparison")
    ax.tick_params(axis='x', rotation=45)
    # Value labels: bar_label offsets in points, so each label sits just above
    # its bar whatever the metric's scale. (A fixed +0.01 data-unit offset pushed
    # labels for small metrics such as ECE outside the axes.)
    ax.bar_label(bars, fmt="%.3f")
    # Headroom above the tallest bar so its label stays inside the axes.
    ax.margins(y=0.15)

plt.tight_layout()
plt.show()

## 5. Hierarchical Level Analysis

Analyze confidence scores at different levels of the hierarchy.

In [None]:
# Per-level confidences from the hierarchical probe, if the experiment saved them.
# NOTE(review): requires the inference code to dump intermediate predictions;
# confirm these file names. When any file is missing, this cell only reports
# which ones.
level_paths = {
    "token": experiment_dir / "hierarchical" / "token_confidences.npy",
    "span": experiment_dir / "hierarchical" / "span_confidences.npy",
    "semantic": experiment_dir / "hierarchical" / "semantic_confidences.npy",
    "global": experiment_dir / "hierarchical" / "global_confidences.npy",
}
missing_level_files = [str(path) for path in level_paths.values() if not path.exists()]

if not missing_level_files:
    token_confidences = np.load(level_paths["token"])
    span_confidences = np.load(level_paths["span"])
    semantic_confidences = np.load(level_paths["semantic"])
    global_confidences = np.load(level_paths["global"])
    print("Loaded intermediate predictions for all hierarchical levels")
else:
    print("NOTE: Analyze intermediate predictions from hierarchical levels")
    print("Missing files:", *missing_level_files, sep="\n  ")

In [None]:
# Histogram of confidences at each level of the hierarchy (2x2 grid).
# Runs only when the per-level arrays are loaded; otherwise it prints a note.
_level_names = (
    "token_confidences",
    "span_confidences",
    "semantic_confidences",
    "global_confidences",
)
if all(name in globals() for name in _level_names):
    levels = [
        # Token-level scores are averaged per example.
        # NOTE(review): assumes token_confidences is (n_examples, n_tokens) and the
        # other levels are 1-D per-example arrays; confirm against the saved files.
        ("Token", token_confidences.mean(axis=1)),
        ("Span", span_confidences),
        ("Semantic", semantic_confidences),
        ("Global", global_confidences),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    # axes.flat walks the grid row by row: top-left, top-right, bottom-left, bottom-right
    for ax, (level_name, confidences) in zip(axes.flat, levels):
        mean_confidence = confidences.mean()
        ax.hist(confidences, bins=50, alpha=0.7, edgecolor='black')
        ax.axvline(mean_confidence, color='red', linestyle='--',
                   label=f'Mean: {mean_confidence:.3f}')
        ax.set_xlabel("Confidence")
        ax.set_ylabel("Frequency")
        ax.set_title(f"{level_name}-Level Confidence Distribution")
        ax.legend()
    plt.tight_layout()
    plt.show()
else:
    print("NOTE: level distributions skipped - intermediate predictions are not loaded")

## 6. Performance by Subject/Category

Analyze how well the hierarchical probe performs across MMLU subjects and categories.

In [None]:
# TODO: per-category performance breakdown.
# Plan: group the test examples by MMLU category (the categories actually present
# are the keys of stats['category_counts']), then compute ECE / Brier / AUROC /
# accuracy separately for each group. A helper such as
# dataset.get_by_category(category) could supply the grouping.
# NOTE(review): confirm that helper exists on MMLUDataset.
print("NOTE: Analyze performance breakdown by MMLU categories")

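## 7. ECE Improvement Analysis

Relative improvement is $(\mathrm{ECE}_{\text{linear}} - \mathrm{ECE}_{\text{hier}}) / \mathrm{ECE}_{\text{linear}} \times 100\%$. It is reported alongside the absolute reduction.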

In [None]:
# ECE improvement of the hierarchical probe over the linear baseline.
# Values come from the metrics table `df` so they cannot drift from it
# (previously they were hard-coded a second time here).
ece_by_method = df.set_index("Method")["ECE"]
linear_ece = float(ece_by_method["Linear"])
hierarchical_ece = float(ece_by_method["Hierarchical"])

absolute_reduction = linear_ece - hierarchical_ece
# A perfectly calibrated baseline (ECE == 0) has no meaningful relative change.
improvement = absolute_reduction / linear_ece * 100 if linear_ece else float("nan")

print(f"ECE Improvement over Linear Baseline: {improvement:.1f}%")
print(f"\nLinear ECE: {linear_ece:.4f}")
print(f"Hierarchical ECE: {hierarchical_ece:.4f}")
print(f"Absolute reduction: {absolute_reduction:.4f}")

## 8. Selective Prediction Analysis

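Analyze the coverage-accuracy tradeoff for selective prediction. Rank test examples by probe confidence and keep only the most confident fraction (the coverage). Then measure accuracy on the kept examples. If the probe's confidence ranking is informative, accuracy should rise as coverage shrinks.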

In [None]:
def selective_prediction_curve(confidences, labels):
    """Accuracy of the retained examples at every coverage level.

    Examples are sorted by descending confidence; the k most confident are kept
    for k = 1..n. ``labels`` are treated as per-example correctness (1 = correct,
    0 = incorrect).
    NOTE(review): assumes test_labels encode whether the model's answer was
    correct; confirm against the experiment's label definition.

    Args:
        confidences: array-like of n confidence scores.
        labels: array-like of n binary correctness labels.

    Returns:
        (coverage, accuracy): two float arrays of length n, where coverage[k-1] = k/n
        and accuracy[k-1] is the mean label of the k most confident examples.
        Both are empty when the inputs are empty.

    Raises:
        ValueError: if confidences and labels have different lengths.
    """
    scores = np.asarray(confidences, dtype=float).ravel()
    correct = np.asarray(labels, dtype=float).ravel()
    if scores.shape != correct.shape:
        raise ValueError(
            f"confidences ({scores.size}) and labels ({correct.size}) differ in length"
        )
    if scores.size == 0:
        return np.empty(0), np.empty(0)
    # Stable sort keeps the original order among tied confidences (deterministic).
    order = np.argsort(-scores, kind="stable")
    kept = np.arange(1, scores.size + 1)
    coverage = kept / scores.size
    accuracy = np.cumsum(correct[order]) / kept
    return coverage, accuracy


# Plot the curves when the test predictions are loaded; otherwise print a note.
_selective_inputs = (
    "hierarchical_confidences",
    "linear_confidences",
    "mlp_confidences",
    "test_labels",
)
if all(name in globals() for name in _selective_inputs):
    fig, ax = plt.subplots(figsize=(10, 6))
    for method_name, method_confidences in (
        ("Hierarchical", hierarchical_confidences),
        ("Linear", linear_confidences),
        ("MLP", mlp_confidences),
    ):
        coverage, accuracy = selective_prediction_curve(method_confidences, test_labels)
        ax.plot(coverage, accuracy, label=method_name)
    ax.set_xlabel("Coverage (fraction of examples kept)")
    ax.set_ylabel("Accuracy on kept examples")
    ax.set_title("Selective Prediction: Accuracy vs. Coverage")
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.tight_layout()
    plt.show()
else:
    print("NOTE: selective prediction curves skipped - test predictions are not loaded")

## 9. Model Complexity vs Performance

Visualize the tradeoff between model complexity and performance.

In [None]:
# ECE against probe size on a log-scaled x-axis: does a bigger probe buy better calibration?
methods = df["Method"]
params = df["Params"]
ece_values = df["ECE"]

fig, ax = plt.subplots(figsize=(10, 6))
scatter = ax.scatter(params, ece_values, s=200, alpha=0.6)

# Name each point, nudged up and to the right of its marker
for method, n_params, ece in zip(methods, params, ece_values):
    ax.annotate(
        method,
        (n_params, ece),
        xytext=(10, 10),
        textcoords='offset points',
        fontsize=12,
        fontweight='bold',
    )

ax.set(
    xlabel="Number of Parameters",
    ylabel="Expected Calibration Error (ECE)",
    title="Model Complexity vs Calibration Performance",
    xscale='log',
)
ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

## 10. Conclusions

Summary of findings:

1. **Calibration Quality**: The hierarchical probe achieves a **TODO: X%** relative reduction in ECE over the linear baseline (see Section 7; fill in once real predictions replace the placeholder metrics).
2. **Multi-Scale Benefits**: Different levels of the hierarchy capture complementary uncertainty signals.
3. **Complexity Tradeoff**: The hierarchical probe offers a good balance between performance and model size.
4. **Domain Generalization**: Performance analysis across MMLU categories shows... **TODO** (pending the per-category breakdown in Section 6).

### Next Steps
- Test on other datasets (TriviaQA, GSM8K)
- Experiment with different model architectures
- Analyze failure cases
- Compare with CCPS and semantic entropy methods