# NB03: Component Analysis

**Question:** Which RAG knobs matter most, what are their optimal values, and how do they interact?

This notebook analyzes RAG component effects:
- Variance decomposition (which factors explain the most performance variance)
- Marginal effects of each component
- Prompt and top-K deep dives
- Interaction effects between components
- Optimal configurations

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from analysis_utils import (
    load_all_results, setup_plotting, identify_bottlenecks,
    compute_marginal_means, plot_component_effects,
    plot_interaction_heatmap, find_synergistic_combinations,
    weighted_mean_with_ci,
    PRIMARY_METRIC, BROKEN_MODELS,
)

setup_plotting()
STUDY_PATH = Path("../outputs/smart_retrieval_slm")

# Load every run, then exclude the models flagged as broken in analysis_utils.
df_all = load_all_results(STUDY_PATH)
df = df_all.loc[~df_all['model_short'].isin(BROKEN_MODELS)].copy()

# Restrict to retrieval-augmented runs (exp_type == 'rag').
rag = df.loc[df['exp_type'].eq('rag')].copy()
print(
    f"RAG experiments: {len(rag)} "
    f"(from {len(df)} total, {len(df_all)} before broken-model filter)"
)

## 1. Variance Decomposition

The single most important thesis figure: which factors explain the most F1 variance?

In [None]:
# Decompose variance over RAG runs only. Non-RAG baselines presumably carry
# placeholder/NaN values in the RAG factor columns, so including them would let the
# RAG-vs-no-RAG gap show up as a "retriever"/"prompt" effect.
bottlenecks = identify_bottlenecks(rag, PRIMARY_METRIC)

if bottlenecks:
    # NOTE(review): assumes a {factor: % variance explained} mapping, ordered by
    # importance - confirm in analysis_utils.identify_bottlenecks.
    print("Variance Explained by Factor (%)")
    print("=" * 50)
    for factor, pct in bottlenecks.items():
        bar = '#' * int(pct / 2)  # one '#' per 2 percentage points
        print(f"  {factor:<20s}: {pct:5.1f}%  {bar}")

    # Horizontal bar chart; reversed so the first (largest) factor sits on top.
    fig, ax = plt.subplots(figsize=(10, 5))
    factors = list(bottlenecks.keys())
    values = list(bottlenecks.values())
    colors = plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, len(factors)))
    ax.barh(factors[::-1], values[::-1], color=colors[::-1], edgecolor='black', linewidth=0.5)
    ax.set_xlabel('Variance Explained (%)')
    ax.set_title('RAG Component Variance Decomposition')
    ax.grid(axis='x', alpha=0.3)
    # Categorical y positions are 0..n-1 in plotting order, matching values[::-1].
    for i, v in enumerate(values[::-1]):
        ax.text(v + 0.3, i, f'{v:.1f}%', va='center', fontsize=10)
    plt.tight_layout()
    plt.show()
else:
    print("No variance decomposition available (identify_bottlenecks returned nothing).")

## 2. Marginal Effects

Marginal mean of each factor level, controlling for model and dataset.

In [None]:
factors = ['retriever_type', 'embedding_model', 'reranker', 'prompt',
           'query_transform', 'top_k']
# Keep only factors that exist and actually vary across the RAG runs.
factors = [f for f in factors if f in rag.columns and rag[f].nunique() > 1]

n_factors = len(factors)
if n_factors == 0:
    # plt.subplots(0, 0) would raise, so report instead of crashing the notebook.
    print("No RAG factor has more than one level; skipping marginal-effects plot.")
else:
    ncols = min(3, n_factors)
    nrows = (n_factors + ncols - 1) // ncols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, whatever the grid shape.
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows), squeeze=False)
    axes_flat = axes.ravel()

    # Plot from `rag`, the same frame the factor levels were filtered on.
    for ax, factor in zip(axes_flat, factors):
        plot_component_effects(rag, factor, PRIMARY_METRIC, ax=ax)

    # Hide unused axes in the last row.
    for ax in axes_flat[n_factors:]:
        ax.set_visible(False)

    plt.suptitle('Marginal Effects of RAG Components on F1', y=1.01, fontsize=14)
    plt.tight_layout()
    plt.show()

## 3. Prompt Deep Dive

In [None]:
# Prompt marginal means with CI
if 'prompt' not in rag.columns:
    # Mirror the column-presence guards used by the other sections.
    print("No 'prompt' column in RAG results; skipping prompt deep dive.")
else:
    prompt_stats = weighted_mean_with_ci(rag, 'prompt', PRIMARY_METRIC)
    print("Prompt Performance (mean F1 with 95% CI):")
    display(prompt_stats.round(4))

    # Prompt x Dataset heatmap
    prompt_ds = rag.groupby(['prompt', 'dataset'])[PRIMARY_METRIC].mean().unstack()
    if not prompt_ds.empty:
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Bar chart. Error-bar lengths are clipped at 0 and NaN bounds (presumably
        # levels with too few runs for a CI - confirm in weighted_mean_with_ci)
        # become 0, so matplotlib gets a finite, non-negative yerr for every bar.
        x = np.arange(len(prompt_stats))
        yerr_low = (prompt_stats['mean'] - prompt_stats['ci_low']).clip(lower=0).fillna(0)
        yerr_high = (prompt_stats['ci_high'] - prompt_stats['mean']).clip(lower=0).fillna(0)
        axes[0].bar(x, prompt_stats['mean'], yerr=[yerr_low, yerr_high],
                    capsize=4, alpha=0.8, color='steelblue')
        axes[0].set_xticks(x)
        axes[0].set_xticklabels(prompt_stats['prompt'], rotation=45, ha='right')
        axes[0].set_ylabel('Mean F1')
        axes[0].set_title('Prompt Performance')
        axes[0].grid(axis='y', alpha=0.3)

        # Heatmap
        sns.heatmap(prompt_ds, annot=True, fmt='.3f', cmap='RdYlGn',
                    ax=axes[1])
        axes[1].set_title('Prompt x Dataset (Mean F1)')

        plt.tight_layout()
        plt.show()

## 4. Top-K Sensitivity

In [None]:
if 'top_k' in rag.columns and rag['top_k'].nunique() > 1:
    topk_data = rag.dropna(subset=['top_k', PRIMARY_METRIC])

    # One curve per retriever type. groupby sorts keys and drops NaN keys, matching
    # sorted(dropna().unique()). If there is no retriever_type column at all, plot a
    # single pooled curve instead of raising KeyError.
    if 'retriever_type' in topk_data.columns:
        curves = list(topk_data.groupby('retriever_type'))
    else:
        curves = [('all', topk_data)]

    fig, ax = plt.subplots(figsize=(10, 5))

    for rt, sub in curves:
        per_k = sub.groupby('top_k')[PRIMARY_METRIC].agg(['mean', 'std', 'count']).sort_index()
        # Normal-approximation 95% CI; std is NaN when a top-K value has a single
        # run, so draw a zero-width band there rather than breaking the fill.
        ci = (1.96 * per_k['std'] / np.sqrt(per_k['count'])).fillna(0)
        ax.plot(per_k.index, per_k['mean'], marker='o', label=rt)
        ax.fill_between(per_k.index, per_k['mean'] - ci, per_k['mean'] + ci, alpha=0.15)

    ax.set_xlabel('Top-K')
    ax.set_ylabel('Mean F1')
    ax.set_title('F1 vs Top-K by Retriever Type')
    ax.legend(title='Retriever')
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("Top-K has <= 1 unique value; skipping sensitivity plot.")

## 5. Interaction Effects

In [None]:
interaction_pairs = [
    ('retriever_type', 'reranker'),
    ('prompt', 'query_transform'),
    ('retriever_type', 'embedding_model'),
    ('reranker', 'query_transform'),
]
# Filter to pairs where both factors exist and have > 1 level among RAG runs
interaction_pairs = [(f1, f2) for f1, f2 in interaction_pairs
                     if f1 in rag.columns and f2 in rag.columns
                     and rag[f1].nunique() > 1 and rag[f2].nunique() > 1]

n_pairs = len(interaction_pairs)
if n_pairs > 0:
    ncols = min(2, n_pairs)
    nrows = (n_pairs + ncols - 1) // ncols  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so no shape special-casing.
    fig, axes = plt.subplots(nrows, ncols, figsize=(7 * ncols, 5 * nrows), squeeze=False)
    axes_flat = axes.ravel()

    # Use `rag`, the same frame the pairs were filtered on.
    for ax, (f1, f2) in zip(axes_flat, interaction_pairs):
        plot_interaction_heatmap(rag, f1, f2, PRIMARY_METRIC, ax=ax)

    for ax in axes_flat[n_pairs:]:
        ax.set_visible(False)

    plt.suptitle('Component Interaction Heatmaps', y=1.01, fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print("Not enough factor pairs with multiple levels for interaction analysis.")

In [None]:
# Synergistic / redundant combos, computed on RAG runs (the frame the pairs were
# filtered on). For each category, list the three strongest effects.
for f1, f2 in interaction_pairs:
    combos = find_synergistic_combinations(rag, f1, f2, PRIMARY_METRIC)
    # pd.DataFrame accepts a list of dicts, a DataFrame or None. This avoids the
    # ambiguous-truth-value error `if combos:` would raise on a DataFrame.
    combo_df = pd.DataFrame(combos)
    if combo_df.empty:
        continue

    # Sort so head(3) shows the largest positive / most negative interactions.
    syn = (combo_df[combo_df['synergy'] == 'Synergistic']
           .sort_values('interaction_effect', ascending=False))
    red = (combo_df[combo_df['synergy'] == 'Redundant']
           .sort_values('interaction_effect', ascending=True))

    print(f"\n{f1} x {f2}:")
    for label, subset in (("Synergistic", syn), ("Redundant", red)):
        if len(subset) > 0:
            print(f"  {label} ({len(subset)}):")
            for _, r in subset.head(3).iterrows():
                # ':+' prints the sign explicitly, so a value is never shown as "+-x".
                print(f"    {r[f1]} + {r[f2]}: interaction = {r['interaction_effect']:+.4f}")

## 6. Optimal Configurations

In [None]:
config_cols = ['retriever_type', 'embedding_model', 'reranker', 'prompt',
               'query_transform', 'top_k', 'agent_type']
config_cols = [c for c in config_cols if c in rag.columns]


def print_best_config_by(frame, group_col, header, metric, cols):
    """Print the highest-scoring configuration for each level of ``group_col``.

    Args:
        frame: Experiment results, one row per run.
        group_col: Column to group by (e.g. ``'dataset'`` or ``'model_short'``).
        header: Title line printed above the ``=`` rule.
        metric: Score column used to pick the best run (NaN scores are ignored).
        cols: Configuration columns to report for each best run.
    """
    print(header)
    print("=" * 60)
    # dropna(): a NaN key never matches `==`, and it makes sorted() fail on mixed str/float.
    for level in sorted(frame[group_col].dropna().unique()):
        scored = frame.loc[frame[group_col] == level].dropna(subset=[metric])
        if scored.empty:
            continue
        # nlargest(1, keep='first') picks the same row as idxmax, but always yields
        # exactly one row even if the index contains duplicate labels.
        best = scored.nlargest(1, metric).iloc[0]
        print(f"\n  {level} (F1={best[metric]:.4f}):")
        for c in cols:
            print(f"    {c:<20s}: {best.get(c, 'n/a')}")


print_best_config_by(rag, 'dataset', "Best Configuration per Dataset:",
                     PRIMARY_METRIC, config_cols)
print_best_config_by(rag, 'model_short', "\n\nBest Configuration per Model:",
                     PRIMARY_METRIC, config_cols)

# "Universal recipe" - most common values among the top 10% of scored experiments.
# Size the slice from runs that have a score, so unscored rows don't inflate the
# fraction of valid runs treated as "top".
scored_rag = rag.dropna(subset=[PRIMARY_METRIC])
top_pct = scored_rag.nlargest(max(1, len(scored_rag) // 10), PRIMARY_METRIC)
print("\n\nUniversal Recipe (mode of top-10% experiments):")
print("=" * 60)
for c in config_cols:
    mode = top_pct[c].mode()
    print(f"  {c:<20s}: {mode.iloc[0] if len(mode) > 0 else 'n/a'}")

## 7. Summary

Key takeaways (**TODO:** fill in from the results above):
- Which factor explains the most variance
- Best prompt strategy
- Optimal top-K range
- Synergistic and redundant combinations
- Universal vs dataset-specific configurations