# NB11: Retrieval Quality Diagnostics

**Question:** Is performance bottlenecked by retrieval or generation?

This notebook diagnoses retrieval quality using saved predictions:
1. **Answer-in-context rates** by retriever, model, dataset
2. **Retrieval vs Generation bottleneck** quadrant analysis
3. **Reranking effectiveness** (rank changes and before/after positions of answer-bearing documents)
4. **Correlation** between retrieval quality and F1
5. **Per-retriever diagnostics** deep dive (answer-rank distribution, top-K sensitivity)

In [None]:
import json
import re
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats as scipy_stats

from analysis_utils import (
    load_all_results, setup_plotting, parse_experiment_name,
    _enrich_from_metadata, SKIP_DIRS,
    PRIMARY_METRIC, BROKEN_MODELS, MODEL_TIER,
)

setup_plotting()
STUDY_PATH = Path("../outputs/smart_retrieval_slm")

# Experiment-level results: drop models flagged as broken, then keep RAG runs only
df_all = load_all_results(STUDY_PATH)
usable = ~df_all['model_short'].isin(BROKEN_MODELS)
df = df_all.loc[usable].copy()
rag_df = df.loc[df['exp_type'].eq('rag')].copy()
print(f"RAG experiments to analyze: {len(rag_df)}")

In [None]:
# ---- Retrieval analysis functions (adapted from scripts/analyze_retrieval.py) ----

def normalize_text(text: str) -> str:
    """Normalize text for fuzzy matching.

    Lower-cases, replaces every non-word character with a space, collapses
    whitespace runs to a single space, and strips both ends.

    Stripping happens last on purpose: if it ran before the punctuation pass,
    leading/trailing punctuation would leave a stray space (``"Paris."`` ->
    ``"paris "``), and that space would make substring checks fail against
    text that simply ends in ``"paris"``.

    Args:
        text: Raw text.

    Returns:
        Normalized text with no leading or trailing whitespace.
    """
    text = text.lower()
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def answer_in_text(expected, text: str) -> bool:
    """Check if any expected answer appears in text as a whole-word phrase.

    Both sides are normalized with ``normalize_text`` and re-tokenized on
    whitespace. A match requires the answer's tokens to occur as a contiguous
    run of whole tokens in the text. Plain substring matching gave false
    positives such as ``"10"`` inside ``"2010"`` or ``"us"`` inside
    ``"focus"``, which inflates answer-in-context rates.

    Args:
        expected: A single answer, or a list/tuple of acceptable answers.
            Non-string items are converted with ``str``.
        text: Text to search (e.g. a prompt or a retrieved document).

    Returns:
        True if any answer whose normalized form has at least 2 characters
        occurs in ``text``. False for empty ``text`` or empty ``expected``.
    """
    if not text or not expected:
        return False
    # Pad with spaces so `in` only matches on token boundaries. split/join
    # also discards any leading/trailing whitespace left by normalization.
    haystack = f" {' '.join(normalize_text(text).split())} "
    answers = expected if isinstance(expected, (list, tuple)) else [expected]
    for ans in answers:
        ans_norm = " ".join(normalize_text(str(ans)).split())
        # Single-character answers would match almost anything; skip them.
        if len(ans_norm) >= 2 and f" {ans_norm} " in haystack:
            return True
    return False


def is_correct(prediction: str, expected: list, threshold: float = 0.5) -> bool:
    """Check if prediction matches any expected answer (fuzzy).

    After normalization, a prediction is correct for an expected answer when
    either:

    (a) one is contained in the other as a contiguous run of whole tokens; or
    (b) at least ``threshold`` of the expected answer's distinct tokens also
        appear in the prediction.

    Args:
        prediction: Model output text.
        expected: Acceptable answers (list or tuple). A bare string or other
            scalar is treated as one answer, not iterated character by
            character.
        threshold: Minimum fraction of expected-answer tokens that must appear
            in the prediction for an overlap match.

    Returns:
        True if any expected answer matches. False for an empty prediction,
        an empty ``expected``, or a prediction that normalizes to nothing
        (e.g. pure punctuation).
    """
    if not prediction or not expected:
        return False
    pred_tokens = normalize_text(prediction).split()
    if not pred_tokens:
        # An empty/whitespace string is "contained" in every answer; never count it.
        return False
    pred_padded = f" {' '.join(pred_tokens)} "
    pred_words = set(pred_tokens)
    answers = expected if isinstance(expected, (list, tuple)) else [expected]
    for exp in answers:
        exp_tokens = normalize_text(str(exp)).split()
        if not exp_tokens:
            continue
        exp_padded = f" {' '.join(exp_tokens)} "
        # Whole-token containment in either direction. Raw substring checks
        # would accept e.g. the prediction "a" for the answer "Paris".
        if exp_padded in pred_padded or pred_padded in exp_padded:
            return True
        exp_words = set(exp_tokens)
        overlap = len(pred_words & exp_words) / len(exp_words)
        if overlap >= threshold:
            return True
    return False


# Sanity marker that this cell ran and the helper functions above are defined.
print("Analysis functions defined.")

In [None]:
# ---- Load prediction-level retrieval data for all RAG experiments ----
#
# Produces two frames used by the later cells:
#   ret_df -- one row per (experiment, question): answer-in-context flags,
#             fuzzy correctness and F1.
#   doc_df -- one row per retrieved document (ranks, scores, has_answer),
#             used by the reranking analysis.

rag_experiment_names = set(rag_df['name'].values)  # NOTE(review): not used below; filtering uses the parsed config
retrieval_rows = []
doc_rows = []  # Per-document data for reranking analysis


def _load_rag_experiment(exp_dir):
    """Return ``(config, predictions)`` for a usable RAG experiment directory.

    Returns None for anything that should be skipped:
    - non-directories and ``SKIP_DIRS``;
    - hidden/private (``.``/``_``-prefixed) directories;
    - directories without ``predictions.json``;
    - ``direct`` experiments and ``BROKEN_MODELS``;
    - empty prediction lists.
    """
    if not exp_dir.is_dir() or exp_dir.name in SKIP_DIRS:
        return None
    if exp_dir.name.startswith(('.', '_')):
        return None
    pred_file = exp_dir / "predictions.json"
    if not pred_file.exists():
        return None

    # Config parsed from the directory name, enriched from metadata.json when present
    config = parse_experiment_name(exp_dir.name)
    meta_file = exp_dir / "metadata.json"
    if meta_file.exists():
        with open(meta_file, encoding="utf-8") as f:
            _enrich_from_metadata(config, json.load(f))

    if config.get('exp_type') == 'direct':
        return None
    if config.get('model_short') in BROKEN_MODELS:
        return None

    with open(pred_file, encoding="utf-8") as f:
        preds = json.load(f).get('predictions') or []
    return (config, preds) if preds else None


for exp_dir in STUDY_PATH.iterdir():
    loaded = _load_rag_experiment(exp_dir)
    if loaded is None:
        continue
    config, preds = loaded
    # `or 'none'` also covers keys present with a None value; groupby would
    # otherwise silently drop those rows from the per-reranker tables.
    reranker = config.get('reranker') or 'none'
    query_transform = config.get('query_transform') or 'none'

    for p in preds:
        expected = p.get('expected') or []
        if not expected:
            continue
        retrieved_docs = p.get('retrieved_docs') or []
        metrics = p.get('metrics') or {}

        # Answer presence per document, computed once and reused for both frames
        doc_hits = [answer_in_text(expected, doc.get('content', '')) for doc in retrieved_docs]
        # Best (smallest) rank among answer-bearing docs. Docs without a rank are
        # ignored rather than given a sentinel value that would skew rank plots.
        hit_ranks = [
            doc['rank'] for doc, hit in zip(retrieved_docs, doc_hits)
            if hit and doc.get('rank') is not None
        ]

        retrieval_rows.append({
            'experiment': exp_dir.name,
            'idx': p.get('idx'),
            'model_short': config.get('model_short'),
            'dataset': config.get('dataset'),
            'retriever': config.get('retriever'),
            'retriever_type': config.get('retriever_type'),
            'reranker': reranker,
            'agent_type': config.get('agent_type'),
            'top_k': config.get('top_k'),
            'query_transform': query_transform,
            # Answer in context, judged on the full prompt text
            'has_answer_in_prompt': answer_in_text(expected, p.get('prompt', '')),
            'answer_in_any_doc': any(doc_hits),
            'answer_doc_rank': min(hit_ranks) if hit_ranks else None,
            'got_correct': is_correct(p.get('prediction', ''), expected),
            'f1': metrics.get('f1', np.nan),
            'n_docs': len(retrieved_docs),
        })

        # Per-document data for reranking analysis
        for doc, hit in zip(retrieved_docs, doc_hits):
            doc_rows.append({
                'experiment': exp_dir.name,
                'dataset': config.get('dataset'),
                'idx': p.get('idx'),
                'reranker': reranker,
                'rank': doc.get('rank'),
                'score': doc.get('score'),
                'retrieval_score': doc.get('retrieval_score'),
                'retrieval_rank': doc.get('retrieval_rank'),
                'rerank_score': doc.get('rerank_score'),
                'has_answer': hit,
            })

ret_df = pd.DataFrame(retrieval_rows)
doc_df = pd.DataFrame(doc_rows) if doc_rows else pd.DataFrame()

# An empty DataFrame has no 'experiment' column, so guard the count
n_loaded_experiments = ret_df['experiment'].nunique() if not ret_df.empty else 0
print(f"Loaded {len(ret_df)} question-level retrieval records from {n_loaded_experiments} experiments")
if not doc_df.empty:
    print(f"Loaded {len(doc_df)} document-level records")
    print(f"Documents with rerank_score: {doc_df['rerank_score'].notna().sum()}")

## 1. Answer-in-Context Rates

What fraction of questions have the answer in the retrieved context?

In [None]:
def _recall_table(frame, key):
    """Answer-in-context rate and question count per value of ``key``, best first.

    When ``frame`` spans more than one dataset, the rate is the mean of the
    per-dataset rates (so large datasets don't dominate) and ``n`` is the
    summed count. Otherwise it is the plain per-group rate.
    """
    rate_and_count = {
        'recall': ('has_answer_in_prompt', 'mean'),
        'n': ('has_answer_in_prompt', 'count'),
    }
    if frame['dataset'].nunique() > 1:
        per_cell = frame.groupby([key, 'dataset']).agg(**rate_and_count).reset_index()
        table = per_cell.groupby(key).agg(recall=('recall', 'mean'), n=('n', 'sum'))
    else:
        table = frame.groupby(key).agg(**rate_and_count)
    return table.sort_values('recall', ascending=False)


if not ret_df.empty:
    # Headline rate: mean of per-dataset rates when several datasets are present
    if ret_df['dataset'].nunique() > 1:
        overall_recall = ret_df.groupby('dataset')['has_answer_in_prompt'].mean().mean()
        print(f"Overall answer-in-context rate (dataset-stratified): {overall_recall:.1%}")
    else:
        overall_recall = ret_df['has_answer_in_prompt'].mean()
        print(f"Overall answer-in-context rate: {overall_recall:.1%}")
    print()

    if 'retriever_type' in ret_df.columns:
        by_retriever = _recall_table(ret_df, 'retriever_type')
        print("By retriever type (dataset-stratified):")
        display(by_retriever.round(3))

    # Per-dataset rates need no stratification
    by_dataset = (
        ret_df.groupby('dataset')
        .agg(recall=('has_answer_in_prompt', 'mean'), n=('has_answer_in_prompt', 'count'))
        .sort_values('recall', ascending=False)
    )
    print("\nBy dataset:")
    display(by_dataset.round(3))

    by_reranker = _recall_table(ret_df, 'reranker')
    print("\nBy reranker (dataset-stratified):")
    display(by_reranker.round(3))

    # Retriever type x dataset grid; each cell is already a single dataset
    pivot = (
        ret_df.groupby(['retriever_type', 'dataset'])['has_answer_in_prompt']
        .mean()
        .unstack()
    )
    if not pivot.empty and pivot.shape[0] > 1:
        fig, ax = plt.subplots(figsize=(8, 5))
        sns.heatmap(pivot, annot=True, fmt='.1%', cmap='RdYlGn', vmin=0, vmax=1, ax=ax)
        ax.set_title('Answer-in-Context Rate by Retriever Type x Dataset')
        plt.tight_layout()
        plt.show()

## 2. Retrieval vs Generation Bottleneck

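Quadrant analysis ("Context" means an expected answer string appears in the prompt; "Correct" means the prediction fuzzily matches an expected answer):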
- **Correct + Context**: Retrieval and generation both succeeded
- **Wrong + Context**: Had the answer but generation failed
- **Correct + No Context**: Correct without the answer string in context (parametric knowledge, a lucky guess, or a paraphrase the string match missed)
- **Wrong + No Context**: Retrieval failure (answer not retrieved)

In [None]:
_QUADRANTS = ['Correct + Context', 'Wrong + Context', 'Correct + No Context', 'Wrong + No Context']


def _recall_and_gen_given_ctx(frame):
    """Return ``(retrieval_recall, gen_given_ctx)`` for the questions in ``frame``.

    - retrieval_recall: share of questions whose prompt contains an expected
      answer.
    - gen_given_ctx: accuracy restricted to those questions.

    With more than one dataset, each rate is computed per dataset and then
    averaged; datasets with no answer-in-context questions drop out of the
    gen_given_ctx average. gen_given_ctx is NaN (undefined, not 0) when no
    question has the answer in context.
    """
    ctx = frame['has_answer_in_prompt'].astype(bool)
    in_ctx = frame[ctx]
    if frame['dataset'].nunique() > 1:
        recall = frame.groupby('dataset')['has_answer_in_prompt'].mean().mean()
        gen = in_ctx.groupby('dataset')['got_correct'].mean().mean()
    else:
        recall = ctx.mean()
        gen = in_ctx['got_correct'].mean()
    return recall, gen


def _fmt_pct(x):
    """Format a rate as a percentage, or 'n/a' when it is undefined (NaN)."""
    return 'n/a' if pd.isna(x) else f"{x:.1%}"


def _bottleneck(recall, gen, threshold=0.5):
    """Label the limiting stage. Retrieval is checked before generation."""
    if recall < threshold:
        return 'RETRIEVAL'
    if gen < threshold:
        return 'GENERATION'
    return 'BALANCED'


if not ret_df.empty:
    # Classify each question into one of four outcome quadrants
    mask_ctx = ret_df['has_answer_in_prompt'].astype(bool)
    mask_correct = ret_df['got_correct'].astype(bool)
    ret_df['quadrant'] = 'unknown'
    quadrant_masks = [
        mask_ctx & mask_correct,
        mask_ctx & ~mask_correct,
        ~mask_ctx & mask_correct,
        ~mask_ctx & ~mask_correct,
    ]
    for cond, label in zip(quadrant_masks, _QUADRANTS):
        ret_df.loc[cond, 'quadrant'] = label

    if ret_df['dataset'].nunique() > 1:
        print("Overall Quadrant Distribution (dataset-stratified):")
        # fill_value=0: a quadrant absent from a dataset counts as 0% there.
        # Without it, the average skips that dataset and overstates the share.
        quad_share = (
            ret_df.groupby('dataset')['quadrant']
            .value_counts(normalize=True)
            .unstack(fill_value=0)
            .mean()
        )
    else:
        print("Overall Quadrant Distribution:")
        quad_share = ret_df['quadrant'].value_counts(normalize=True)
    for q in _QUADRANTS:
        if q in quad_share.index:
            print(f"  {q:<25s}: {quad_share[q]:.1%}")

    # Bottleneck identification (dataset-stratified when several datasets)
    retrieval_recall, gen_given_ctx = _recall_and_gen_given_ctx(ret_df)
    print(f"\nRetrieval Recall (dataset-stratified): {_fmt_pct(retrieval_recall)}")
    print(f"Generation|Context (dataset-stratified): {_fmt_pct(gen_given_ctx)}")
    print(f"Bottleneck: {_bottleneck(retrieval_recall, gen_given_ctx)}")

    # Per-model breakdown using the same stratification rules
    print("\nPer-model bottleneck (dataset-stratified):")
    for model in sorted(ret_df['model_short'].unique()):
        m_recall, m_gen = _recall_and_gen_given_ctx(ret_df[ret_df['model_short'] == model])
        bottleneck = _bottleneck(m_recall, m_gen)
        print(f"  {model:<16s}: recall={_fmt_pct(m_recall)}, gen|ctx={_fmt_pct(m_gen)}  -> {bottleneck}")

In [None]:
# Quadrant breakdown by dataset and retriever
if not ret_df.empty:
    # Fixed colour per quadrant. Slicing a positional colour list shifted
    # colours onto the wrong bars whenever a quadrant was missing.
    quad_colors = {
        'Correct + Context': '#2ecc71',
        'Correct + No Context': '#27ae60',
        'Wrong + Context': '#e67e22',
        'Wrong + No Context': '#e74c3c',
    }
    present = set(ret_df['quadrant'])
    quad_order = [q for q in quad_colors if q in present]
    colors = [quad_colors[q] for q in quad_order]

    def _quadrant_shares(key):
        """Share of each quadrant per value of ``key``, columns in ``quad_order``.

        reindex fills quadrants absent from a subset with 0 instead of raising.
        """
        return (
            ret_df.groupby(key)['quadrant']
            .value_counts(normalize=True)
            .unstack(fill_value=0)
            .reindex(columns=quad_order, fill_value=0)
        )

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = [('dataset', 'Outcome Quadrants by Dataset')]
    if 'retriever_type' in ret_df.columns and ret_df['retriever_type'].nunique() > 1:
        panels.append(('retriever_type', 'Outcome Quadrants by Retriever Type'))
    else:
        axes[1].set_visible(False)  # nothing to compare; hide the empty panel

    for ax, (key, title) in zip(axes, panels):
        _quadrant_shares(key).plot(kind='bar', stacked=True, ax=ax, color=colors)
        ax.set_title(title)
        ax.set_ylabel('Proportion')
        ax.legend(fontsize=8, loc='upper right')
        ax.tick_params(axis='x', rotation=0)

    plt.tight_layout()
    plt.show()

## 3. Reranking Effectiveness

How much does reranking improve the position of answer-bearing documents?

In [None]:
if not doc_df.empty and doc_df['retrieval_rank'].notna().any():
    # Keep documents that carry both the pre-rerank and the final rank
    reranked = doc_df[doc_df['retrieval_rank'].notna() & doc_df['rank'].notna()].copy()
    reranked['rank_change'] = reranked['retrieval_rank'] - reranked['rank']  # positive = improved

    # Attach dataset (looked up via ret_df) for stratification if not already present
    if 'dataset' not in reranked.columns:
        exp_to_ds = ret_df.groupby('experiment')['dataset'].first()
        reranked['dataset'] = reranked['experiment'].map(exp_to_ds)

    if not reranked.empty:
        answer_docs = reranked[reranked['has_answer']]
        non_answer_docs = reranked[~reranked['has_answer']]

        print(f"Documents with rank data: {len(reranked):,}")
        print(f"  Answer-bearing: {len(answer_docs):,}")
        print(f"  Non-answer: {len(non_answer_docs):,}")
        print()

        if not answer_docs.empty:
            rank_aggs = {
                'mean_rank_change': ('rank_change', 'mean'),
                'median_rank_change': ('rank_change', 'median'),
                'pct_improved': ('rank_change', lambda x: (x > 0).mean()),
                'mean_final_rank': ('rank', 'mean'),
                'n': ('rank_change', 'count'),
            }
            if 'dataset' in answer_docs.columns and answer_docs['dataset'].nunique() > 1:
                # Per-(reranker, dataset) stats first, then average across datasets
                per_ds = answer_docs.groupby(['reranker', 'dataset']).agg(**rank_aggs).reset_index()
                by_reranker = (
                    per_ds.groupby('reranker')
                    .agg(
                        mean_rank_change=('mean_rank_change', 'mean'),
                        median_rank_change=('median_rank_change', 'mean'),
                        pct_improved=('pct_improved', 'mean'),
                        mean_final_rank=('mean_final_rank', 'mean'),
                        n=('n', 'sum'),
                    )
                    .round(3)
                )
            else:
                by_reranker = answer_docs.groupby('reranker').agg(**rank_aggs).round(3)

            print("Reranking impact on ANSWER-BEARING documents (dataset-stratified):")
            display(by_reranker)

            fig, axes = plt.subplots(1, 2, figsize=(14, 5))

            # Left: distribution of rank changes per reranker
            for rr in sorted(answer_docs['reranker'].unique()):
                sub = answer_docs[answer_docs['reranker'] == rr]
                axes[0].hist(sub['rank_change'], bins=20, alpha=0.5, label=rr, edgecolor='black')
            axes[0].axvline(x=0, color='red', linestyle='--', linewidth=1)
            axes[0].set_xlabel('Rank Change (positive = improved)')
            axes[0].set_ylabel('Count')
            axes[0].set_title('Rank Change for Answer-Bearing Documents')
            axes[0].legend()

            # Right: original vs final rank on a fixed-seed sample (keeps the scatter readable)
            sample = answer_docs.sample(min(2000, len(answer_docs)), random_state=42)
            for rr in sorted(sample['reranker'].unique()):
                sub = sample[sample['reranker'] == rr]
                axes[1].scatter(sub['retrieval_rank'], sub['rank'],
                                s=10, alpha=0.3, label=rr)
            # The identity line spans the observed rank range rather than a
            # fixed 0-25, so it is not truncated when ranks exceed 25.
            lo = float(min(sample['retrieval_rank'].min(), sample['rank'].min()))
            hi = float(max(sample['retrieval_rank'].max(), sample['rank'].max()))
            axes[1].plot([lo, hi], [lo, hi], 'k--', alpha=0.5, label='No change')
            axes[1].set_xlabel('Original Rank (before reranking)')
            axes[1].set_ylabel('Final Rank (after reranking)')
            axes[1].set_title('Reranking Effect on Answer Documents')
            axes[1].legend()

            plt.tight_layout()
            plt.show()
else:
    print("No document-level rank data available for reranking analysis.")
    print("This requires retrieved_docs with retrieval_rank and rank fields in predictions.json.")

## 4. Retrieval Quality vs F1 Correlation

Does higher retrieval recall correlate with higher F1 at the experiment level?

In [None]:
if not ret_df.empty:
    # Per-experiment aggregates. `correct_if_ctx` is correctness (as 0/1) where
    # the answer was in the prompt and NaN elsewhere, so its NaN-skipping mean
    # is accuracy given answer-in-context (NaN when there are no such questions).
    exp_stats = (
        ret_df.assign(
            correct_if_ctx=ret_df['got_correct'].astype(float)
            .where(ret_df['has_answer_in_prompt'].astype(bool))
        )
        .groupby('experiment')
        .agg(
            retrieval_recall=('has_answer_in_prompt', 'mean'),
            gen_given_ctx=('correct_if_ctx', 'mean'),
            mean_f1=('f1', 'mean'),
            accuracy=('got_correct', 'mean'),
            n_questions=('idx', 'count'),
            model_short=('model_short', 'first'),
            dataset=('dataset', 'first'),
            retriever_type=('retriever_type', 'first'),
            reranker=('reranker', 'first'),
        )
        .dropna(subset=['mean_f1'])
    )

    datasets = sorted(exp_stats['dataset'].dropna().unique())
    # One fixed colour per dataset, so each regression line matches its points
    ds_colors = {ds: f"C{i % 10}" for i, ds in enumerate(datasets)}

    if len(exp_stats) < 5:
        print(f"Only {len(exp_stats)} experiments with F1 -- too few for a correlation analysis (need >= 5).")
    else:
        # Per-dataset correlation, then report both per-dataset and combined
        print("Retrieval Recall vs F1 Correlation (per-dataset):")
        per_ds_corrs = []
        for ds in datasets:
            sub = exp_stats[exp_stats['dataset'] == ds]
            if len(sub) >= 5:
                r_ds, _ = scipy_stats.pearsonr(sub['retrieval_recall'], sub['mean_f1'])
                rho_ds, _ = scipy_stats.spearmanr(sub['retrieval_recall'], sub['mean_f1'])
                per_ds_corrs.append({'dataset': ds, 'pearson_r': r_ds, 'spearman_rho': rho_ds, 'n': len(sub)})
                print(f"  {ds}: Pearson r={r_ds:.3f}, Spearman rho={rho_ds:.3f} (n={len(sub)})")

        # Pooled across datasets -- shown for reference only. Between-dataset
        # differences can dominate it, hence the per-dataset colours below.
        r, _ = scipy_stats.pearsonr(exp_stats['retrieval_recall'], exp_stats['mean_f1'])
        rho, _ = scipy_stats.spearmanr(exp_stats['retrieval_recall'], exp_stats['mean_f1'])
        print(f"\n  Pooled (all datasets): Pearson r={r:.3f}, Spearman rho={rho:.3f}")
        if per_ds_corrs:
            mean_r = np.mean([c['pearson_r'] for c in per_ds_corrs])
            mean_rho = np.mean([c['spearman_rho'] for c in per_ds_corrs])
            print(f"  Mean across datasets: Pearson r={mean_r:.3f}, Spearman rho={mean_rho:.3f}")

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Left: retrieval recall vs F1, points and regression line per dataset
        for ds in datasets:
            sub = exp_stats[exp_stats['dataset'] == ds]
            color = ds_colors[ds]
            axes[0].scatter(sub['retrieval_recall'], sub['mean_f1'],
                            s=30, alpha=0.5, label=ds, color=color)
            x = sub['retrieval_recall'].to_numpy(dtype=float)
            # A line needs enough points and some spread in x (polyfit is degenerate otherwise)
            if len(sub) >= 5 and np.ptp(x) > 0:
                slope, intercept = np.polyfit(x, sub['mean_f1'].to_numpy(dtype=float), 1)
                x_sorted = np.sort(x)
                axes[0].plot(x_sorted, slope * x_sorted + intercept, '--', alpha=0.4, color=color)

        axes[0].set_xlabel('Retrieval Recall (answer-in-context rate)')
        axes[0].set_ylabel('Mean F1')
        axes[0].set_title('Retrieval Quality vs Generation Quality\n(per-dataset regression lines)')
        axes[0].legend()
        axes[0].grid(alpha=0.3)

        # Right: retrieval recall vs generation|context
        gen_stats = exp_stats.dropna(subset=['gen_given_ctx'])
        if not gen_stats.empty:
            for ds in datasets:
                sub = gen_stats[gen_stats['dataset'] == ds]
                if sub.empty:
                    continue
                axes[1].scatter(sub['retrieval_recall'], sub['gen_given_ctx'],
                                s=30, alpha=0.5, label=ds, color=ds_colors[ds])
            axes[1].set_xlabel('Retrieval Recall')
            axes[1].set_ylabel('Generation|Context (accuracy given answer in context)')
            axes[1].set_title('Retrieval vs Generation Quality')
            axes[1].axhline(y=0.5, color='red', linestyle='--', alpha=0.3)
            axes[1].axvline(x=0.5, color='red', linestyle='--', alpha=0.3)
            axes[1].legend()
            axes[1].grid(alpha=0.3)

        plt.tight_layout()
        plt.show()

## 5. Per-Retriever Deep Dive

Detailed retrieval diagnostics for each retriever type, including where the first answer-bearing document ranks and how recall changes with top-K.

In [None]:
if not ret_df.empty:
    # Dataset-stratified per-retriever stats (mean of per-dataset values)
    if ret_df['dataset'].nunique() > 1:
        per_ds_stats = (
            ret_df.groupby(['retriever_type', 'dataset'])
            .agg(
                retrieval_recall=('has_answer_in_prompt', 'mean'),
                accuracy=('got_correct', 'mean'),
                mean_f1=('f1', 'mean'),  # groupby mean skips NaN F1 values
                n_questions=('idx', 'count'),
                n_experiments=('experiment', 'nunique'),
            )
            .reset_index()
        )
        retriever_stats = (
            per_ds_stats.groupby('retriever_type')
            .agg(
                retrieval_recall=('retrieval_recall', 'mean'),
                accuracy=('accuracy', 'mean'),
                mean_f1=('mean_f1', 'mean'),
                n_questions=('n_questions', 'sum'),
                n_experiments=('n_experiments', 'sum'),
            )
            .round(4)
        )
    else:
        retriever_stats = ret_df.groupby('retriever_type').agg(
            retrieval_recall=('has_answer_in_prompt', 'mean'),
            accuracy=('got_correct', 'mean'),
            mean_f1=('f1', 'mean'),
            n_questions=('idx', 'count'),
            n_experiments=('experiment', 'nunique'),
        ).round(4)

    def _gen_given_ctx(frame):
        """Accuracy among questions whose prompt contains an expected answer.

        Averaged over per-dataset rates when ``frame`` spans several datasets.
        NaN when no question has the answer in context.
        """
        in_ctx = frame[frame['has_answer_in_prompt'].astype(bool)]
        if frame['dataset'].nunique() > 1:
            return in_ctx.groupby('dataset')['got_correct'].mean().mean()
        return in_ctx['got_correct'].mean()

    retriever_stats['gen_given_ctx'] = [
        _gen_given_ctx(ret_df[ret_df['retriever_type'] == rt]) for rt in retriever_stats.index
    ]

    print("Per-Retriever Diagnostics (dataset-stratified):")
    display(retriever_stats.round(3))

    # Answer rank distribution (where in the top-K is the answer?)
    answer_rank_data = ret_df[ret_df['answer_doc_rank'].notna()].copy()
    if not answer_rank_data.empty:
        all_ranks = answer_rank_data['answer_doc_rank'].astype(float)
        # Shared, integer-centred bins over the observed range. Starting at 1
        # would silently drop rank 0 if ranks are 0-based, and shared edges
        # keep the overlaid histograms aligned.
        bins = np.arange(int(all_ranks.min()), int(all_ranks.max()) + 2) - 0.5

        fig, ax = plt.subplots(figsize=(10, 5))
        for rt in sorted(answer_rank_data['retriever_type'].dropna().unique()):
            ranks = all_ranks[answer_rank_data['retriever_type'] == rt]
            ax.hist(ranks, bins=bins, alpha=0.5,
                    label=f"{rt} (n={len(ranks)}, median={ranks.median():.0f})",
                    edgecolor='black')

        ax.set_xlabel('Rank of First Answer-Bearing Document')
        ax.set_ylabel('Count')
        ax.set_title('Where Does the Answer Appear in Retrieved Documents?')
        ax.legend()
        ax.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

In [None]:
# Top-K sensitivity for retrieval recall (dataset-stratified)
if not ret_df.empty and 'top_k' in ret_df.columns and ret_df['top_k'].nunique() > 1:
    group_keys = ['top_k', 'retriever_type']
    question_stats = {
        'recall': ('has_answer_in_prompt', 'mean'),
        'mean_f1': ('f1', lambda x: x.dropna().mean()),
        'n': ('idx', 'count'),
    }
    if ret_df['dataset'].nunique() > 1:
        # Stats per (top_k, retriever_type, dataset), then averaged over datasets
        per_ds = ret_df.groupby(group_keys + ['dataset']).agg(**question_stats).reset_index()
        topk_recall = (
            per_ds.groupby(group_keys)
            .agg(recall=('recall', 'mean'), mean_f1=('mean_f1', 'mean'), n=('n', 'sum'))
            .reset_index()
        )
    else:
        topk_recall = ret_df.groupby(group_keys).agg(**question_stats).reset_index()

    fig, (ax_recall, ax_f1) = plt.subplots(1, 2, figsize=(14, 5))

    # groupby iterates retriever types in sorted order
    for rt, curve in topk_recall.groupby('retriever_type'):
        curve = curve.sort_values('top_k')
        ax_recall.plot(curve['top_k'], curve['recall'], 'o-', label=rt)
        ax_f1.plot(curve['top_k'], curve['mean_f1'], 'o-', label=rt)

    panel_text = [
        (ax_recall, 'Retrieval Recall (dataset-stratified)', 'Retrieval Recall vs Top-K'),
        (ax_f1, 'Mean F1 (dataset-stratified)', 'F1 vs Top-K (diminishing returns?)'),
    ]
    for ax, y_label, title in panel_text:
        ax.set_xlabel('Top-K')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

## 6. Summary

Key retrieval diagnostics:
- **Answer-in-context rate**: What fraction of questions have the answer in retrieved docs
- **Bottleneck analysis**: Is performance limited by retrieval or generation
- **Reranking value**: How much do rerankers improve answer document positioning
- **Retrieval-F1 correlation**: How tightly does retrieval quality predict final F1
- **Per-retriever profiles**: Which retrievers work best for which datasets