In [None]:
# ============================================================================
# Step 6: Configuration & Execution
# ============================================================================

# %% [markdown]
# ## Step 6.1: Define Configurations to Test

# %%
# Define which embedding models, chunk sizes, and re-rankers to evaluate
configurations = [
    {
        'provider': 'voyage',
        'model': 'voyage-3-large',
        'chunk_sizes': [512, 1024],
        'k_retrieve': 40,       # Baseline: retrieve 40 documents
        'k_rerank': 20,         # Re-ranking: keep top 20 after re-ranking
        'reranker_models': [
            # 'cross-encoder/ms-marco-MiniLM-L-12-v2',
            # 'BAAI/bge-reranker-large',
            'voyage-rerank-2.5'
        ]
    }
    # {
    #     'provider': 'openai',
    #     'model': 'text-embedding-3-large',
    #     'chunk_sizes': [512, 1024],
    #     'k_retrieve': 40,
    #     'k_rerank': 10,
    #     'reranker_models': [
    #         'cross-encoder/ms-marco-MiniLM-L-12-v2',
    #         'BAAI/bge-reranker-large'
    #     ]
    # },
    # {
    #     'provider': 'ollama',
    #     'model': 'nomic-embed-text',
    #     'chunk_sizes': [1024],
    #     'k_retrieve': 40,
    #     'k_rerank': 20,
    #     'reranker_models': [
    #         'voyage-rerank-2.5'
    #     ]
    # }
]

print("✓ Configurations defined")
print(f"  Total configurations: {len(configurations)}")

# %% [markdown]
# ## Step 6.2: Define Evaluation Parameters

# %%
# Retrieval modes evaluated for every configuration.
modes = ['global', 'singledoc']

# Forwarded to evaluate_single_configuration() as `use_page_tolerance`.
USE_PAGE_TOLERANCE = True

_page_tolerance_status = 'ENABLED' if USE_PAGE_TOLERANCE else 'DISABLED'
for _message in (
    "✓ Evaluation parameters set",
    f"  Modes: {modes}",
    f"  Page Tolerance: {_page_tolerance_status}",
):
    print(_message)

# %% [markdown]
# ## Step 6.3: Display Evaluation Plan

# %%
def display_evaluation_plan(configurations, modes):
    """
    Print an overview of the evaluation runs that will be executed.

    For each configuration, one run is scheduled per (chunk size, mode) pair
    for the baseline, plus one per re-ranker model.

    Args:
        configurations: list of config dicts with 'provider', 'model',
            'chunk_sizes', 'k_retrieve', 'k_rerank' and optionally
            'reranker_models'.
        modes: list of retrieval modes to evaluate.

    Returns:
        int: total number of runs across all configurations.
    """
    divider = '=' * 60
    print(f"\n{divider}")
    print("EVALUATION PLAN")
    print(divider)

    grand_total = 0

    for cfg in configurations:
        # Read every field up front so a malformed config fails before printing.
        provider, model = cfg['provider'], cfg['model']
        chunk_sizes = cfg['chunk_sizes']
        k_retrieve, k_rerank = cfg['k_retrieve'], cfg['k_rerank']
        rerankers = cfg.get('reranker_models', [])
        n_rerankers = len(rerankers)

        print(f"\n{provider}/{model}")
        print(f"  Chunk sizes: {chunk_sizes}")
        print(f"  Baseline: retrieve k={k_retrieve}")
        print(f"  Re-ranking: retrieve k={k_retrieve}, keep top k={k_rerank}")
        print(f"  Re-rankers: {n_rerankers}")
        for reranker_name in rerankers:
            print(f"    • {reranker_name}")

        # Baseline + every re-ranker, for each chunk size and each mode.
        config_runs = len(chunk_sizes) * len(modes) * (n_rerankers + 1)
        grand_total += config_runs
        print(f"  Total runs: {config_runs} (baseline + {n_rerankers} re-rankers)")

    print(f"\n{divider}")
    print(f"TOTAL EVALUATION RUNS: {grand_total}")
    print(divider)

    return grand_total

# %%
# Print the evaluation plan; the returned total run count is kept for reference.
total_runs = display_evaluation_plan(configurations, modes)

# %% [markdown]
# ## Step 6.4: Show Output Files That Will Be Created

# %%
def display_output_files(configurations, modes, output_dir):
    """
    Print every result file the evaluation will produce and whether it exists.

    For each configuration, chunk size and mode, lists the baseline file and
    then one file per re-ranker. Names come from get_output_filename() and
    status from check_if_results_exist().

    Args:
        configurations: list of config dicts (see Step 6.1).
        modes: list of retrieval modes.
        output_dir: directory where result files are stored.

    Returns:
        list of (filename, exists) tuples, in the order they were printed.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print("OUTPUT FILES")
    print(rule)
    print(f"Output directory: {output_dir}\n")

    planned_files = []

    def _record(filename, exists):
        # Print one status line and remember the file.
        marker = "✓ EXISTS" if exists else "○ TO CREATE"
        print(f"  {marker} {filename}")
        planned_files.append((filename, exists))

    for cfg in configurations:
        provider = cfg['provider']
        model = cfg['model']
        chunk_sizes = cfg['chunk_sizes']
        k_retrieve = cfg['k_retrieve']
        k_rerank = cfg['k_rerank']
        rerankers = cfg.get('reranker_models', [])

        print(f"\n{provider}/{model}:")

        for chunk_size in chunk_sizes:
            for mode in modes:
                base_args = (provider, model, chunk_size, k_retrieve, mode)
                _record(
                    get_output_filename(*base_args),
                    check_if_results_exist(*base_args, output_dir),
                )
                for reranker in rerankers:
                    _record(
                        get_output_filename(*base_args, reranker, k_rerank),
                        check_if_results_exist(*base_args, output_dir, reranker, k_rerank),
                    )

    existing_count = sum(1 for _, exists in planned_files if exists)

    print(f"\n{rule}")
    print(f"Total files: {len(planned_files)}")
    print(f"  Existing: {existing_count}")
    print(f"  To create: {len(planned_files) - existing_count}")
    print(rule)

    return planned_files

# %%
# Show which result files already exist in OUTPUT_DIR and which will be created.
output_files = display_output_files(configurations, modes, OUTPUT_DIR)

# %% [markdown]
# ## Step 6.5: Execute Batch Evaluation

# %%
print(f"\n{'='*60}")
print("STARTING BATCH EVALUATION")
print(f"{'='*60}")
print(f"This will process {len(configurations)} configuration(s)")
print("Existing results will be skipped automatically\n")

# Optional confirmation before starting. This replaces a commented-out
# `input()` / `else:` snippet that could not be re-enabled without also
# re-indenting the loop below. Keep False so Restart & Run All never blocks
# waiting for keyboard input.
REQUIRE_CONFIRMATION = False

proceed = True
if REQUIRE_CONFIRMATION:
    proceed = input("Proceed with evaluation? (y/n): ").lower().strip() == 'y'
    if not proceed:
        print("Evaluation cancelled by user")

# Results from every configuration, concatenated in configuration order.
# Always defined (even when cancelled) so the summary cells below still run.
all_results = []

if proceed:
    for i, config in enumerate(configurations, 1):
        print(f"\n{'='*60}")
        print(f"PROCESSING CONFIGURATION {i}/{len(configurations)}")
        print(f"{'='*60}")

        config_results = evaluate_single_configuration(
            dataset=dataset,
            config=config,
            modes=modes,
            use_page_tolerance=USE_PAGE_TOLERANCE,
            output_dir=OUTPUT_DIR
        )

        all_results.extend(config_results)

    print(f"\n{'='*60}")
    print("BATCH EVALUATION COMPLETE")
    print(f"{'='*60}")

# %% [markdown]
# ## Step 6.6: Display Results Summary

# %%
def display_results_summary(all_results):
    """
    Print MRR results grouped by (provider, model, chunk size, mode, k_retrieve).

    Only entries whose ``result['status']`` is ``'completed'`` are included.
    Within each group, the baseline is printed first, then the re-rankers
    sorted by average MRR (highest first).

    Args:
        all_results: list of result dicts as collected by the batch evaluation.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print("EVALUATION RESULTS SUMMARY")
    print(rule)

    groups = defaultdict(list)
    completed = (res for res in all_results if res['result']['status'] == 'completed')
    for res in completed:
        group_key = (res['provider'], res['model'], res['chunk_size'], res['mode'], res.get('k_retrieve', 0))
        groups[group_key].append(res)

    for group_key in sorted(groups):
        entries = groups[group_key]
        provider, model, chunk_size, mode, k_retrieve = group_key

        print(f"\n{provider}/{model}, chunk={chunk_size}, mode={mode}")
        print("-" * 60)

        baseline = next((e for e in entries if e['type'] == 'baseline'), None)
        if baseline is not None and 'average_mrr' in baseline['result']:
            print(f"  Baseline (k={k_retrieve}):  MRR = {baseline['result']['average_mrr']:.4f}")
        else:
            print("  Baseline:  (skipped)")

        reranked = sorted(
            (e for e in entries if e['type'] == 'rerank'),
            key=lambda e: e['result'].get('average_mrr', 0),
            reverse=True,
        )
        for entry in reranked:
            outcome = entry['result']
            if 'average_mrr' not in outcome:
                print(f"  {get_reranker_short_name(entry['reranker']):25s} (skipped)")
                continue
            mrr = outcome['average_mrr']
            k_rerank = entry.get('k_rerank', 0)
            delta = outcome.get('improvement', 0)
            delta_pct = outcome.get('improvement_percentage', 0)
            short_name = get_reranker_short_name(entry['reranker'])
            print(f"  {short_name:25s} (k={k_rerank})  MRR = {mrr:.4f}  ({delta:+.4f}, {delta_pct:+.2f}%)")

    print(f"\n{rule}")

# %%
# Summarize the results collected by the batch evaluation above.
display_results_summary(all_results)

# %% [markdown]
# ## Step 6.7: List All Generated Files

# %%
def list_generated_files(output_dir):
    """
    List the JSON result files in ``output_dir``, split into baseline and re-ranking files.

    A file is classified as a re-ranking result when its name contains
    ``_rerank_`` (the naming quick_comparison() globs for,
    ``..._{mode}_rerank_k*``) or ``rerank-`` (the marker checked previously).
    Checking only ``rerank-`` misclassified re-ranking files whose re-ranker
    name lacks that substring (e.g. cross-encoder models) as baseline files.

    Args:
        output_dir: directory scanned (non-recursively) for ``*.json`` files.
    """
    rerank_markers = ('_rerank_', 'rerank-')

    print(f"\n{'='*60}")
    print("GENERATED FILES")
    print(f"{'='*60}")
    print(f"Directory: {output_dir}\n")

    json_files = sorted(Path(output_dir).glob("*.json"))

    if not json_files:
        print("No JSON files found.")
        return

    print(f"Total files: {len(json_files)}\n")

    baseline_files = []
    rerank_files = []
    total_bytes = 0

    for filepath in json_files:
        size_bytes = filepath.stat().st_size  # stat once per file
        total_bytes += size_bytes
        entry = (filepath.name, size_bytes / 1024)  # KB
        if any(marker in filepath.name for marker in rerank_markers):
            rerank_files.append(entry)
        else:
            baseline_files.append(entry)

    if baseline_files:
        print(f"Baseline files ({len(baseline_files)}):")
        for name, size in baseline_files:
            print(f"  {name} ({size:.1f} KB)")

    if rerank_files:
        print(f"\nRe-ranking files ({len(rerank_files)}):")
        for name, size in rerank_files:
            print(f"  {name} ({size:.1f} KB)")

    print(f"\nTotal size: {total_bytes / (1024 * 1024):.2f} MB")
    print(f"{'='*60}")

# %%
# Inventory the JSON result files currently present in OUTPUT_DIR.
list_generated_files(OUTPUT_DIR)

# %% [markdown]
# ## Step 6.8: Quick Comparison Function

# %%
def quick_comparison(provider, model, chunk_size, mode, k_retrieve, k_rerank=None, output_dir=OUTPUT_DIR):
    """
    Quick comparison of baseline vs all re-rankers for a specific configuration.

    Reads the baseline result file and every matching re-ranking result JSON
    in ``output_dir`` and prints their MRR scores, best first.

    Args:
        provider: Embedding provider
        model: Embedding model name
        chunk_size: Chunk size
        mode: "global" or "singledoc"
        k_retrieve: Number retrieved for baseline
        k_rerank: Number kept after re-ranking. If given, only re-ranking
            results whose summary reports this ``k_rerank`` are shown; if
            None, all re-ranking results for the configuration are shown.
        output_dir: Output directory

    Returns:
        None. Prints the comparison, or a message explaining why nothing
        could be compared.

    Usage:
        quick_comparison('voyage', 'voyage-3-large', 512, 'global', 40, 20)
    """
    print(f"\n{'='*60}")
    print("QUICK COMPARISON")
    print(f"{'='*60}")
    print(f"{provider}/{model}, chunk={chunk_size}, mode={mode}\n")

    # Load baseline
    baseline_file = get_output_filename(provider, model, chunk_size, k_retrieve, mode)
    baseline_path = os.path.join(output_dir, baseline_file)

    if not os.path.exists(baseline_path):
        print(f"✗ Baseline file not found: {baseline_file}")
        return

    with open(baseline_path, 'r') as f:
        baseline_data = json.load(f)

    baseline_summary = next((item['summary'] for item in baseline_data if 'summary' in item), None)
    if not baseline_summary:
        print("✗ No summary found in baseline file")
        return

    baseline_mrr = baseline_summary['average_mrr']
    baseline_k = baseline_summary.get('k_retrieve', k_retrieve)
    print(f"Baseline (k={baseline_k}): MRR = {baseline_mrr:.4f}\n")

    # Find all re-ranking result files for this config. Restricting the glob
    # to *.json keeps other files in the directory (e.g. CSV exports) away
    # from json.load.
    model_clean = model.replace('/', '_')
    pattern = f"{provider}_{model_clean}_chunk{chunk_size}_k{k_retrieve}_{mode}_rerank_k*.json"
    rerank_files = sorted(Path(output_dir).glob(pattern))

    if not rerank_files:
        print("No re-ranking files found")
        return

    results = []
    for rerank_path in rerank_files:
        with open(rerank_path, 'r') as f:
            rerank_data = json.load(f)

        rerank_summary = next((item['summary'] for item in rerank_data if 'summary' in item), None)
        if not rerank_summary:
            continue

        k_r = rerank_summary.get('k_rerank', 0)
        # Honour the k_rerank filter documented above.
        if k_rerank is not None and k_r != k_rerank:
            continue

        reranker = rerank_summary['retrieval_config']['reranker_model']
        results.append((
            get_reranker_short_name(reranker),
            k_r,
            rerank_summary['average_mrr'],
            rerank_summary['mrr_improvement'],
            rerank_summary['mrr_improvement_percentage'],
        ))

    if not results:
        if k_rerank is not None:
            print(f"No re-ranking results with k_rerank={k_rerank}")
        else:
            print("No re-ranking summaries found")
        return

    print("Re-ranking results:")
    print("-" * 60)

    # Sort by MRR descending
    for reranker_short, k_r, mrr, improvement, improvement_pct in sorted(results, key=lambda x: x[2], reverse=True):
        print(f"{reranker_short:25s} (k={k_r})  MRR = {mrr:.4f}  ({improvement:+.4f}, {improvement_pct:+.2f}%)")

    print(f"{'='*60}")

print("✓ Quick comparison function defined")

# %% [markdown]
# ## Step 6.9: Example Usage of Quick Comparison

# %%
# Example: compare all re-rankers for one configuration.
# Set RUN_QUICK_COMPARISON_EXAMPLE to True once the batch evaluation has
# produced result files. The arguments match the active configuration defined
# in Step 6.1 (voyage-3-large, k_retrieve=40, k_rerank=20).
RUN_QUICK_COMPARISON_EXAMPLE = False

if RUN_QUICK_COMPARISON_EXAMPLE:
    quick_comparison(
        provider='voyage',
        model='voyage-3-large',
        chunk_size=512,
        mode='global',
        k_retrieve=40,
        k_rerank=20
    )

# %%
_rule = "=" * 60
_config_format_lines = [
    "  {",
    "    'provider': 'voyage',",
    "    'model': 'voyage-finance-2',",
    "    'chunk_sizes': [512, 1024],",
    "    'k_retrieve': 40,  # Baseline retrieves 40",
    "    'k_rerank': 10,    # Re-ranking keeps top 10",
    "    'reranker_models': ['cross-encoder/ms-marco-MiniLM-L-12-v2']",
    "  }",
]
_next_step_lines = [
    "  1. Review the evaluation plan (6.3)",
    "  2. Check output files (6.4)",
    "  3. Run batch evaluation (6.5)",
    "  4. View results summary (6.6)",
    "  5. Check generated files (6.7)",
    "  6. Use quick_comparison() for specific configs (6.8)",
]

# Closing banner for Step 6: an example configuration and the suggested workflow.
print("\n" + _rule)
print("✓ STEP 6 COMPLETE - EVALUATION READY!")
print(_rule)
print("\nConfiguration format:")
print("\n".join(_config_format_lines))
print("\nNext steps:")
print("\n".join(_next_step_lines))
print("\nAll results are saved to:", OUTPUT_DIR)
print(_rule)

In [28]:
# ============================================================================
# Step 7: Query-Level Analysis
# ============================================================================

# %% [markdown]
# ## Step 7.1: Load Results Functions

# %%
import pandas as pd
from scipy import stats

def load_results_file(filepath):
    """
    Read a results JSON file and split it into per-query records and the summary.

    The file holds a list of dicts; entries with a 'summary' key carry the
    run summary, and all other entries are per-query results.

    Returns:
        (queries, summary): ``queries`` is every entry without a 'summary'
        key, in file order; ``summary`` is the 'summary' value of the first
        entry that has one, or None.
    """
    with open(filepath, 'r') as handle:
        records = json.load(handle)

    queries = list(filter(lambda rec: 'summary' not in rec, records))
    summaries = (rec['summary'] for rec in records if 'summary' in rec)
    return queries, next(summaries, None)


def load_baseline_and_rerank(provider, model, chunk_size, mode, k_retrieve, k_rerank, reranker_model, output_dir=OUTPUT_DIR):
    """
    Load the baseline result file and the matching re-ranking result file.

    File names come from get_output_filename(); both files are read with
    load_results_file().

    Raises:
        FileNotFoundError: if either file is missing (baseline checked first).

    Returns:
        baseline_queries, baseline_summary, rerank_queries, rerank_summary
    """
    baseline_name = get_output_filename(provider, model, chunk_size, k_retrieve, mode)
    rerank_name = get_output_filename(provider, model, chunk_size, k_retrieve, mode, reranker_model, k_rerank)

    resolved = []
    for label, name in (("Baseline", baseline_name), ("Re-ranking", rerank_name)):
        full_path = os.path.join(output_dir, name)
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"{label} file not found: {name}")
        resolved.append(full_path)

    baseline_path, rerank_path = resolved
    baseline_queries, baseline_summary = load_results_file(baseline_path)
    rerank_queries, rerank_summary = load_results_file(rerank_path)

    return baseline_queries, baseline_summary, rerank_queries, rerank_summary


print("✓ Load results functions defined")

# %% [markdown]
# ## Step 7.2: Overall Statistics with Statistical Tests

# %%
def _paired_mrr_scores(baseline_queries, rerank_queries):
    """
    Align re-ranking and baseline MRR scores query-by-query for a paired test.

    When every record has a 'query_id', scores are matched on that id, so the
    pairing does not depend on the order of queries in the two result files.
    Otherwise records are paired by position, which is only valid if both
    lists share the same query order. Records without an 'mrr_score', and
    re-rank queries with no baseline counterpart, are skipped.

    Returns:
        (rerank_scores, baseline_scores): two equal-length lists.
    """
    have_ids = (all('query_id' in q for q in baseline_queries)
                and all('query_id' in q for q in rerank_queries))
    if have_ids:
        baseline_by_id = {q['query_id']: q['mrr_score'] for q in baseline_queries if 'mrr_score' in q}
        pairs = [
            (q['mrr_score'], baseline_by_id[q['query_id']])
            for q in rerank_queries
            if 'mrr_score' in q and q['query_id'] in baseline_by_id
        ]
    else:
        pairs = [
            (r['mrr_score'], b['mrr_score'])
            for r, b in zip(rerank_queries, baseline_queries)
            if 'mrr_score' in r and 'mrr_score' in b
        ]
    return [r for r, _ in pairs], [b for _, b in pairs]


def analyze_overall_statistics(baseline_queries, rerank_queries, config_name="Configuration"):
    """
    Compute overall statistics and perform a paired significance test.

    Means are taken over each list's own queries that have an 'mrr_score'.
    The paired t-test uses scores aligned per query (see _paired_mrr_scores);
    with fewer than two pairs it reports t=0.0, p=1.0 (no evidence).

    Args:
        baseline_queries: List of baseline query results
        rerank_queries: List of re-ranking query results
        config_name: Name for display

    Returns:
        dict with avg_baseline, avg_rerank, improvement, improvement_pct,
        t_statistic, p_value, improved, degraded, unchanged and n_pairs.
    """
    from scipy import stats as scipy_stats  # Import with alias to avoid conflicts

    print(f"\n{'='*60}")
    print(f"OVERALL STATISTICS: {config_name}")
    print(f"{'='*60}")

    # Extract MRR scores
    baseline_mrrs = [q['mrr_score'] for q in baseline_queries if 'mrr_score' in q]
    rerank_mrrs = [q['mrr_score'] for q in rerank_queries if 'mrr_score' in q]

    # Basic statistics (guarded against empty input)
    avg_baseline = sum(baseline_mrrs) / len(baseline_mrrs) if baseline_mrrs else 0.0
    avg_rerank = sum(rerank_mrrs) / len(rerank_mrrs) if rerank_mrrs else 0.0
    improvement = avg_rerank - avg_baseline
    improvement_pct = (improvement / avg_baseline * 100) if avg_baseline > 0 else 0

    print(f"\nMRR Scores:")
    print(f"  Baseline:    {avg_baseline:.4f}")
    print(f"  Re-ranking:  {avg_rerank:.4f}")
    print(f"  Improvement: {improvement:+.4f} ({improvement_pct:+.2f}%)")

    # Statistical significance test (Paired t-test) on per-query aligned scores.
    paired_rerank, paired_baseline = _paired_mrr_scores(baseline_queries, rerank_queries)
    n_pairs = len(paired_rerank)
    if n_pairs >= 2:
        t_statistic, p_value = scipy_stats.ttest_rel(paired_rerank, paired_baseline)
    else:
        # Too few pairs for a test; mirror analyze_by_question_type (p = 1.0).
        t_statistic, p_value = 0.0, 1.0

    print(f"\nStatistical Significance (Paired t-test):")
    print(f"  Pairs:       {n_pairs}")
    print(f"  t-statistic: {t_statistic:.4f}")
    print(f"  p-value:     {p_value:.6f}")

    if p_value < 0.001:
        significance = "*** (p < 0.001) - Highly significant"
    elif p_value < 0.01:
        significance = "** (p < 0.01) - Very significant"
    elif p_value < 0.05:
        significance = "* (p < 0.05) - Significant"
    else:
        significance = "(p >= 0.05) - Not significant"

    print(f"  Result:      {significance}")

    # Query-level improvements (missing 'rank_improvement' counts as unchanged)
    changes = [q.get('rank_improvement', 0) for q in rerank_queries]
    improvements_count = sum(1 for c in changes if c > 0)
    degradations_count = sum(1 for c in changes if c < 0)
    unchanged_count = sum(1 for c in changes if c == 0)
    total = len(changes)

    def _pct(count):
        # Share of all re-ranked queries, 0 when there are none.
        return count / total * 100 if total else 0.0

    print(f"\nQuery-Level Changes:")
    print(f"  Improved:   {improvements_count} ({_pct(improvements_count):.1f}%)")
    print(f"  Degraded:   {degradations_count} ({_pct(degradations_count):.1f}%)")
    print(f"  Unchanged:  {unchanged_count} ({_pct(unchanged_count):.1f}%)")

    return {
        'avg_baseline': avg_baseline,
        'avg_rerank': avg_rerank,
        'improvement': improvement,
        'improvement_pct': improvement_pct,
        't_statistic': t_statistic,
        'p_value': p_value,
        'improved': improvements_count,
        'degraded': degradations_count,
        'unchanged': unchanged_count,
        'n_pairs': n_pairs
    }

print("✓ Overall statistics function defined")

# %% [markdown]
# ## Step 7.3: Top Improved Queries

# %%
def show_top_improved_queries(rerank_queries, dataset, top_n=10):
    """
    Print the queries whose rank improved the most after re-ranking.

    A query counts as improved when its 'rank_improvement' is positive.
    Company and question type are looked up in ``dataset`` by 'financebench_id'.

    Args:
        rerank_queries: List of re-ranking query results
        dataset: FinanceBench dataset (to get question details)
        top_n: Number of top queries to show

    Returns:
        list: up to ``top_n`` improved query results, largest improvement
        first ([] when nothing improved).
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"TOP {top_n} IMPROVED QUERIES")
    print(rule)

    ranked = sorted(
        (q for q in rerank_queries if q.get('rank_improvement', 0) > 0),
        key=lambda q: q.get('rank_improvement', 0),
        reverse=True,
    )
    if not ranked:
        print("No improved queries found.")
        return []

    shown = ranked[:top_n]
    for position, entry in enumerate(shown, start=1):
        qid = entry['query_id']
        text = entry['query']
        before = entry.get('rank_baseline', -1)
        after = entry.get('rank', -1)
        gain = entry.get('rank_improvement', 0)

        record = next((r for r in dataset if r['financebench_id'] == qid), None)
        if record:
            company = record['company']
            question_type = record.get('question_type', 'N/A')
        else:
            company, question_type = "Unknown", 'N/A'

        print(f"\n[{position}] Improvement: +{gain} positions")
        print(f"    Company: {company}")
        print(f"    Question Type: {question_type}")
        print(f"    Question: {text[:100]}...")
        print(f"    Baseline rank: {before} → Re-rank: {after}")

    return shown

print("✓ Top improved queries function defined")

# %% [markdown]
# ## Step 7.4: Degraded Queries Analysis

# %%
def show_degraded_queries(rerank_queries, dataset, top_n=5):
    """
    Print the queries whose rank got worse after re-ranking.

    A query counts as degraded when its 'rank_improvement' is negative.
    Company and question type are looked up in ``dataset`` by 'financebench_id'.

    Args:
        rerank_queries: List of re-ranking query results
        dataset: FinanceBench dataset
        top_n: Number of queries to show

    Returns:
        list: up to ``top_n`` degraded query results, most negative change
        first ([] when nothing degraded).
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"DEGRADED QUERIES (Top {top_n} worst)")
    print(rule)

    worst_first = sorted(
        (q for q in rerank_queries if q.get('rank_improvement', 0) < 0),
        key=lambda q: q.get('rank_improvement', 0),
    )
    if not worst_first:
        print("No degraded queries found - excellent!")
        return []

    shown = worst_first[:top_n]
    for position, entry in enumerate(shown, start=1):
        qid = entry['query_id']
        text = entry['query']
        before = entry.get('rank_baseline', -1)
        after = entry.get('rank', -1)
        change = entry.get('rank_improvement', 0)

        record = next((r for r in dataset if r['financebench_id'] == qid), None)
        if record:
            company = record['company']
            question_type = record.get('question_type', 'N/A')
        else:
            company, question_type = "Unknown", 'N/A'

        print(f"\n[{position}] Degradation: {change} positions")
        print(f"    Company: {company}")
        print(f"    Question Type: {question_type}")
        print(f"    Question: {text[:100]}...")
        print(f"    Baseline rank: {before} → Re-rank: {after}")

    return shown

print("✓ Degraded queries function defined")

# %% [markdown]
# ## Step 7.5: Analysis by Question Type

# %%
def analyze_by_question_type(rerank_queries, dataset):
    """
    Break re-ranking MRR changes down by question type.

    Each query is matched to its dataset record via 'financebench_id'; queries
    without a matching record are left out. For each question type, the mean
    re-ranked MRR ('mrr_score') is compared with the mean baseline MRR
    ('mrr_baseline', missing counts as 0) using a paired t-test. A type with
    a single query gets p = 1.0.

    Args:
        rerank_queries: List of re-ranking query results
        dataset: FinanceBench dataset

    Returns:
        list of per-type dicts, sorted by question type, with count,
        avg_rerank, avg_baseline, improvement, improvement_pct and p_value.
    """
    from scipy import stats as scipy_stats  # Import with alias to avoid conflicts

    rule = '=' * 60
    print(f"\n{rule}")
    print("ANALYSIS BY QUESTION TYPE")
    print(rule)

    queries_by_type = defaultdict(list)
    for entry in rerank_queries:
        qid = entry['query_id']
        record = next((r for r in dataset if r['financebench_id'] == qid), None)
        if not record:
            continue
        queries_by_type[record.get('question_type', 'Unknown')].append(entry)

    print(f"\n{'Question Type':<25} {'Count':>7} {'Avg MRR':>10} {'Baseline':>10} {'Improved':>9} {'p-value':>10}")
    print("-" * 85)

    # Checked in order; the first threshold the p-value falls under wins.
    significance_levels = ((0.001, "***"), (0.01, "**"), (0.05, "*"))
    type_stats = []

    for q_type in sorted(queries_by_type):
        members = queries_by_type[q_type]
        n = len(members)
        rerank_scores = [m['mrr_score'] for m in members]
        baseline_scores = [m.get('mrr_baseline', 0) for m in members]

        mean_rerank = sum(rerank_scores) / n
        mean_baseline = sum(baseline_scores) / n
        delta = mean_rerank - mean_baseline
        delta_pct = (delta / mean_baseline * 100) if mean_baseline > 0 else 0

        if n > 1:
            _t, p_value = scipy_stats.ttest_rel(rerank_scores, baseline_scores)
        else:
            p_value = 1.0
        sig = next((mark for cutoff, mark in significance_levels if p_value < cutoff), "")

        print(f"{q_type:<25} {n:>7} {mean_rerank:>10.4f} {mean_baseline:>10.4f} "
              f"{delta_pct:>+8.1f}% {p_value:>9.6f} {sig}")

        type_stats.append({
            'question_type': q_type,
            'count': n,
            'avg_rerank': mean_rerank,
            'avg_baseline': mean_baseline,
            'improvement': delta,
            'improvement_pct': delta_pct,
            'p_value': p_value
        })

    print("\nSignificance levels: *** p<0.001, ** p<0.01, * p<0.05")

    return type_stats

print("✓ Question type analysis function defined")

# %% [markdown]
# ## Step 7.6: Rank Movement Analysis

# %%
def analyze_rank_movements(rerank_queries):
    """
    Analyze how much ranks changed (histogram of rank improvements).

    Reads each query's 'rank_improvement' (missing counts as 0; positive
    means the rank improved) and prints summary statistics plus a seven-bucket
    histogram. The buckets are contiguous, so every value lands in exactly
    one of them.

    Args:
        rerank_queries: List of re-ranking query results

    Returns:
        list: the per-query rank improvements, in input order.
    """
    print(f"\n{'='*60}")
    print("RANK MOVEMENT ANALYSIS")
    print(f"{'='*60}")

    # Extract rank improvements
    improvements = [q.get('rank_improvement', 0) for q in rerank_queries]

    if not improvements:
        # Avoid ZeroDivisionError / max() of an empty sequence.
        print("\nNo queries to analyze.")
        return improvements

    # Statistics
    avg_improvement = sum(improvements) / len(improvements)
    max_improvement = max(improvements)
    max_degradation = min(improvements)

    print(f"\nRank Change Statistics:")
    print(f"  Average change:     {avg_improvement:+.2f} positions")
    # Explicit sign formatting: a literal "+" prefix printed "+-3" when
    # every query degraded.
    print(f"  Best improvement:   {max_improvement:+} positions")
    print(f"  Worst degradation:  {max_degradation} positions")

    # Distribution
    print(f"\nRank Change Distribution:")

    # Contiguous bins: -5 now counts as a medium degradation (the previous
    # `-9 <= x < -5` range excluded it, so it fell into no bin at all).
    bins = {
        'Large improvement (+10 or more)': sum(1 for x in improvements if x >= 10),
        'Medium improvement (+5 to +9)': sum(1 for x in improvements if 5 <= x < 10),
        'Small improvement (+1 to +4)': sum(1 for x in improvements if 0 < x < 5),
        'No change (0)': sum(1 for x in improvements if x == 0),
        'Small degradation (-1 to -4)': sum(1 for x in improvements if -5 < x < 0),
        'Medium degradation (-5 to -9)': sum(1 for x in improvements if -10 < x <= -5),
        'Large degradation (-10 or less)': sum(1 for x in improvements if x <= -10)
    }

    for category, count in bins.items():
        pct = count / len(improvements) * 100
        bar = '█' * int(pct / 2)  # Simple bar chart
        print(f"  {category:<35} {count:>4} ({pct:>5.1f}%) {bar}")

    return improvements

print("✓ Rank movement analysis function defined")

# %% [markdown]
# ## Step 7.7: Compare Multiple Configurations

# %%
def compare_configurations(configs_to_compare, dataset, output_dir=OUTPUT_DIR):
    """
    Compare multiple baseline-vs-re-ranking configurations side by side.

    Args:
        configs_to_compare: List of dicts with provider, model, chunk_size, mode,
            k_retrieve, k_rerank, reranker_model
        dataset: FinanceBench dataset (not used by this function; kept for
            interface compatibility)
        output_dir: Output directory

    Returns:
        list of dicts as returned by analyze_overall_statistics(), each
        extended with 'config_name' and 'config'. Configurations whose result
        files are missing are reported and skipped.
    """
    print(f"\n{'='*60}")
    print("MULTI-CONFIGURATION COMPARISON")
    print(f"{'='*60}")

    all_stats = []

    for config in configs_to_compare:
        provider = config['provider']
        model = config['model']
        chunk_size = config['chunk_size']
        mode = config['mode']
        k_retrieve = config['k_retrieve']
        k_rerank = config['k_rerank']
        reranker_model = config['reranker_model']

        # Include the re-ranker and k_rerank in the label: configurations that
        # differ only in these would otherwise get identical summary rows.
        config_name = (f"{provider}/{model}, chunk={chunk_size}, mode={mode}, "
                       f"{get_reranker_short_name(reranker_model)}@k={k_rerank}")

        try:
            baseline_queries, _baseline_summary, rerank_queries, _rerank_summary = load_baseline_and_rerank(
                provider, model, chunk_size, mode, k_retrieve, k_rerank, reranker_model, output_dir
            )
        except FileNotFoundError as e:
            print(f"\n✗ Skipping {config_name}: {e}")
            continue

        # Named `config_stats` (not `stats`) so it does not shadow the
        # module-level scipy `stats` import.
        config_stats = analyze_overall_statistics(baseline_queries, rerank_queries, config_name)
        config_stats['config_name'] = config_name
        config_stats['config'] = config
        all_stats.append(config_stats)

    # Summary table
    if all_stats:
        # Widen the label column to fit the longest configuration name.
        name_width = max(40, max(len(s['config_name']) for s in all_stats))

        print(f"\n{'='*60}")
        print("SUMMARY TABLE")
        print(f"{'='*60}\n")

        print(f"{'Configuration':<{name_width}} {'Baseline':>10} {'Re-rank':>10} {'Improve':>9} {'p-value':>10}")
        print("-" * (name_width + 45))

        for stat in all_stats:
            sig = ""
            if stat['p_value'] < 0.001:
                sig = "***"
            elif stat['p_value'] < 0.01:
                sig = "**"
            elif stat['p_value'] < 0.05:
                sig = "*"

            print(f"{stat['config_name']:<{name_width}} {stat['avg_baseline']:>10.4f} {stat['avg_rerank']:>10.4f} "
                  f"{stat['improvement_pct']:>+8.1f}% {stat['p_value']:>9.6f} {sig}")

        print("\nSignificance: *** p<0.001, ** p<0.01, * p<0.05")

    return all_stats

print("✓ Multi-config comparison function defined")

# %% [markdown]
# ## Step 7.8: Export to CSV

# %%
def export_analysis_to_csv(rerank_queries, dataset, output_file, output_dir=None):
    """
    Export detailed query-level analysis to CSV.

    Each re-ranking query result becomes one row combining its baseline and
    re-ranked rank/MRR with the company and question type of the matching
    FinanceBench record (matched on 'financebench_id').

    Args:
        rerank_queries: List of re-ranking query result dicts. Each needs
            'query_id' and 'query'; 'rank_baseline', 'mrr_baseline', 'rank',
            'mrr_score' and 'rank_improvement' are optional (defaults -1 / 0).
        dataset: FinanceBench dataset (iterable of records with
            'financebench_id', 'company' and optionally 'question_type').
        output_file: Output CSV filename, joined onto ``output_dir``.
        output_dir: Output directory. Defaults to the notebook-level
            ``OUTPUT_DIR``, looked up at call time so that changing the config
            cell takes effect without re-defining this function. Created if it
            does not exist.

    Returns:
        pandas.DataFrame with one row per query. The column set is fixed, so
        an empty ``rerank_queries`` still produces a CSV with a header row.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR

    columns = [
        'query_id', 'question', 'company', 'question_type',
        'baseline_rank', 'baseline_mrr', 'rerank_rank', 'rerank_mrr',
        'rank_improvement', 'mrr_improvement',
    ]

    # Index the dataset once instead of scanning it for every query
    # (O(n + m) instead of O(n * m)). setdefault keeps the first record per
    # id, matching what a first-match linear search would return.
    records_by_id = {}
    for record in dataset:
        records_by_id.setdefault(record['financebench_id'], record)

    rows = []
    for query_result in rerank_queries:
        query_id = query_result['query_id']
        dataset_record = records_by_id.get(query_id)
        baseline_mrr = query_result.get('mrr_baseline', 0)
        rerank_mrr = query_result.get('mrr_score', 0)

        rows.append({
            'query_id': query_id,
            'question': query_result['query'],
            'company': dataset_record['company'] if dataset_record else 'Unknown',
            'question_type': dataset_record.get('question_type', 'N/A') if dataset_record else 'N/A',
            'baseline_rank': query_result.get('rank_baseline', -1),
            'baseline_mrr': baseline_mrr,
            'rerank_rank': query_result.get('rank', -1),
            'rerank_mrr': rerank_mrr,
            'rank_improvement': query_result.get('rank_improvement', 0),
            'mrr_improvement': rerank_mrr - baseline_mrr,
        })

    # Create DataFrame and save
    df = pd.DataFrame(rows, columns=columns)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_file)
    df.to_csv(output_path, index=False)

    print(f"✓ Exported {len(rows)} queries to: {output_file}")

    return df

print("✓ Export to CSV function defined")

# %%
# Recap of the analysis helpers defined in Step 7, emitted as one block of text.
print("\n".join([
    "\n✓ Step 7 complete!",
    "  Functions available:",
    "    • load_baseline_and_rerank() - Load results",
    "    • analyze_overall_statistics() - Overall stats with t-test",
    "    • show_top_improved_queries() - Best improvements",
    "    • show_degraded_queries() - Worst degradations",
    "    • analyze_by_question_type() - Group by question type",
    "    • analyze_rank_movements() - Rank change distribution",
    "    • compare_configurations() - Multi-config comparison",
    "    • export_analysis_to_csv() - Export to CSV",
]))

# %% [markdown]
# ## Step 7.9: Example Usage

# %%
# Example 1: Analyze a single configuration
# voyage-3-large embeddings, 1024-token chunks, single-document mode,
# re-ranked with voyage-rerank-2.5 (top 40 retrieved, top 20 kept).

provider, model = 'voyage', 'voyage-3-large'
chunk_size, mode = 1024, 'singledoc'
k_retrieve, k_rerank = 40, 20
reranker_model = 'voyage-rerank-2.5'

# Paired baseline / re-ranking results for this configuration
(baseline_queries, baseline_summary,
 rerank_queries, rerank_summary) = load_baseline_and_rerank(
    provider, model, chunk_size, mode, k_retrieve, k_rerank, reranker_model
)

# Aggregate MRR comparison with t-test
stats = analyze_overall_statistics(
    baseline_queries,
    rerank_queries,
    f"{provider}/{model}, chunk={chunk_size}",
)

# Per-query drill-downs: largest gains, then largest losses
show_top_improved_queries(rerank_queries, dataset, top_n=10)
show_degraded_queries(rerank_queries, dataset, top_n=5)

# Breakdowns by question type and by rank movement
type_stats = analyze_by_question_type(rerank_queries, dataset)
analyze_rank_movements(rerank_queries)

# Persist the per-query table
export_analysis_to_csv(
    rerank_queries,
    dataset,
    f"analysis_{provider}_{model}_chunk{chunk_size}_{mode}.csv",
)


# %%
# Example 2: Compare multiple configurations
# Every chunk size × retrieval mode combination for voyage-3-large re-ranked
# with voyage-rerank-2.5. The list is generated from a grid, so the shared
# settings cannot drift between entries.
# Order: 512-global, 512-singledoc, 1024-global, 1024-singledoc.

configs = [
    {
        'provider': 'voyage',
        'model': 'voyage-3-large',
        'chunk_size': chunk,
        'mode': retrieval_mode,
        'k_retrieve': 40,
        'k_rerank': 20,
        'reranker_model': 'voyage-rerank-2.5'
    }
    for chunk in (512, 1024)
    for retrieval_mode in ('global', 'singledoc')
]

all_stats = compare_configurations(configs, dataset)


print("\n" + "="*60)
print("STEP 7 READY FOR USE!")
print("="*60)
print("\nEdit the parameters in the examples above and re-run them to analyze other results.")

✓ Load results functions defined
✓ Overall statistics function defined
✓ Top improved queries function defined
✓ Degraded queries function defined
✓ Question type analysis function defined
✓ Rank movement analysis function defined
✓ Multi-config comparison function defined
✓ Export to CSV function defined

✓ Step 7 complete!
  Functions available:
    • load_baseline_and_rerank() - Load results
    • analyze_overall_statistics() - Overall stats with t-test
    • show_top_improved_queries() - Best improvements
    • show_degraded_queries() - Worst degradations
    • analyze_by_question_type() - Group by question type
    • analyze_rank_movements() - Rank change distribution
    • compare_configurations() - Multi-config comparison
    • export_analysis_to_csv() - Export to CSV

OVERALL STATISTICS: voyage/voyage-3-large, chunk=1024

MRR Scores:
  Baseline:    0.6392
  Re-ranking:  0.7844
  Improvement: +0.1451 (+22.70%)

Statistical Significance (Paired t-test):
  t-statistic: 5.2812
  p-

In [None]:
# ============================================================================
# Step 8: Visualization & Comparison
# ============================================================================

# %% [markdown]
# ## Step 8.1: Import Visualization Libraries

# %%
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Publication-quality defaults: white grid background, 300 DPI both on screen
# and when saving, and a compact font hierarchy
# (titles 12 > axis labels 11 > base text 10 > ticks / legend 9).
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
})

print("✓ Visualization libraries imported")

# %% [markdown]
# ## Step 8.2: Baseline vs Re-ranking Comparison (Bar Chart)

# %%
def plot_baseline_vs_reranking(all_stats, output_file=None):
    """
    Create bar chart comparing baseline vs re-ranking MRR.
    
    Args:
        all_stats: List of statistics from compare_configurations()
        output_file: Optional filename to save plot
    """
    if not all_stats:
        print("No statistics to plot")
        return
    
    # Prepare data
    config_names = [s['config_name'] for s in all_stats]
    baseline_mrrs = [s['avg_baseline'] for s in all_stats]
    rerank_mrrs = [s['avg_rerank'] for s in all_stats]
    
    # Shorten config names for display
    short_names = []
    for name in config_names:
        # Extract key info: "chunk=512, mode=global" -> "512-global"
        parts = name.split(', ')
        chunk = parts[1].replace('chunk=', '')
        mode = parts[2].replace('mode=', '')
        short_names.append(f"{chunk}-{mode}")
    
    # Create plot
    fig, ax = plt.subplots(figsize=(10, 6))
    
    x = np.arange(len(short_names))
    width = 0.35
    
    bars1 = ax.bar(x - width/2, baseline_mrrs, width, label='Baseline', 
                   color='#3498db', alpha=0.8, edgecolor='black', linewidth=0.5)
    bars2 = ax.bar(x + width/2, rerank_mrrs, width, label='Re-ranking', 
                   color='#e74c3c', alpha=0.8, edgecolor='black', linewidth=0.5)
    
    # Add value labels on bars
    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                   f'{height:.3f}',
                   ha='center', va='bottom', fontsize=8)
    
    # Add improvement percentage labels
    for i, stat in enumerate(all_stats):
        improvement_pct = stat['improvement_pct']
        y_pos = max(baseline_mrrs[i], rerank_mrrs[i]) + 0.02
        
        # Add significance stars
        p_val = stat['p_value']
        if p_val < 0.001:
            sig = "***"
        elif p_val < 0.01:
            sig = "**"
        elif p_val < 0.05:
            sig = "*"
        else:
            sig = ""
        
        ax.text(x[i], y_pos, f'+{improvement_pct:.1f}%{sig}',
               ha='center', va='bottom', fontsize=8, fontweight='bold', color='green')
    
    # Formatting
    ax.set_xlabel('Configuration (chunk-mode)', fontweight='bold')
    ax.set_ylabel('Mean Reciprocal Rank (MRR)', fontweight='bold')
    ax.set_title('Baseline vs Re-ranking Performance Comparison', fontweight='bold', pad=20)
    ax.set_xticks(x)
    ax.set_xticklabels(short_names, rotation=0)
    ax.legend(loc='upper left')
    ax.set_ylim(0, max(max(baseline_mrrs), max(rerank_mrrs)) * 1.15)
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    
    if output_file:
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
        print(f"✓ Saved plot to: {output_file}")
    
    plt.show()

print("✓ Baseline vs re-ranking plot function defined")

# %% [markdown]
# ## Step 8.3: Improvement Heatmap (Chunk × Mode)

# %%
def plot_improvement_heatmap(all_stats, output_file=None):
    """
    Create heatmap showing improvement percentage by chunk size and mode.
    
    Args:
        all_stats: List of statistics from compare_configurations()
        output_file: Optional filename to save plot
    """
    if not all_stats:
        print("No statistics to plot")
        return
    
    # Extract chunk sizes and modes
    chunk_sizes = []
    modes = []
    improvements = {}
    
    for stat in all_stats:
        config = stat['config']
        chunk = config['chunk_size']
        mode = config['mode']
        improvement_pct = stat['improvement_pct']
        
        if chunk not in chunk_sizes:
            chunk_sizes.append(chunk)
        if mode not in modes:
            modes.append(mode)
        
        improvements[(chunk, mode)] = improvement_pct
    
    # Sort
    chunk_sizes = sorted(chunk_sizes)
    modes = sorted(modes)
    
    # Create matrix
    data = np.zeros((len(chunk_sizes), len(modes)))
    for i, chunk in enumerate(chunk_sizes):
        for j, mode in enumerate(modes):
            data[i, j] = improvements.get((chunk, mode), 0)
    
    # Create heatmap
    fig, ax = plt.subplots(figsize=(8, 6))
    
    im = ax.imshow(data, cmap='RdYlGn', aspect='auto', vmin=0, vmax=max(improvements.values()))
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('MRR Improvement (%)', rotation=270, labelpad=20, fontweight='bold')
    
    # Set ticks
    ax.set_xticks(np.arange(len(modes)))
    ax.set_yticks(np.arange(len(chunk_sizes)))
    ax.set_xticklabels(modes)
    ax.set_yticklabels(chunk_sizes)
    
    # Add text annotations
    for i in range(len(chunk_sizes)):
        for j in range(len(modes)):
            text = ax.text(j, i, f'{data[i, j]:.1f}%',
                          ha="center", va="center", color="black", fontweight='bold')
    
    # Labels
    ax.set_xlabel('Retrieval Mode', fontweight='bold')
    ax.set_ylabel('Chunk Size', fontweight='bold')
    ax.set_title('MRR Improvement Heatmap: Chunk Size × Mode', fontweight='bold', pad=20)
    
    plt.tight_layout()
    
    if output_file:
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
        print(f"✓ Saved plot to: {output_file}")
    
    plt.show()

print("✓ Improvement heatmap function defined")

# %% [markdown]
# ## Step 8.4: Question Type Performance (Bar Chart)

# %%
def plot_question_type_performance(type_stats, output_file=None):
    """
    Create bar chart showing performance by question type.

    Draws one pair of bars (baseline, re-ranking) per question type, each bar
    annotated with its MRR. Above each pair is the signed improvement
    percentage plus significance stars (green for gains, red otherwise).
    A key explaining the stars sits in the upper-left corner.

    Args:
        type_stats: Output from analyze_by_question_type(); each entry needs
            'question_type', 'avg_baseline', 'avg_rerank', 'improvement_pct'
            and 'p_value'.
        output_file: Optional filename to save plot
    """
    if not type_stats:
        print("No question type statistics to plot")
        return

    # Prepare data
    question_types = [s['question_type'] for s in type_stats]
    baseline_mrrs = [s['avg_baseline'] for s in type_stats]
    rerank_mrrs = [s['avg_rerank'] for s in type_stats]

    # Create plot
    fig, ax = plt.subplots(figsize=(10, 6))

    x = np.arange(len(question_types))
    width = 0.35

    bars1 = ax.bar(x - width/2, baseline_mrrs, width, label='Baseline',
                   color='#3498db', alpha=0.8, edgecolor='black', linewidth=0.5)
    bars2 = ax.bar(x + width/2, rerank_mrrs, width, label='Re-ranking',
                   color='#e74c3c', alpha=0.8, edgecolor='black', linewidth=0.5)

    # Add value labels
    for bars in (bars1, bars2):
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=8)

    # Add improvement and significance
    for i, stat in enumerate(type_stats):
        improvement_pct = stat['improvement_pct']
        p_val = stat['p_value']

        y_pos = max(baseline_mrrs[i], rerank_mrrs[i]) + 0.03

        # Significance marker (NaN p-values fail every test -> no marker)
        if p_val < 0.001:
            sig = "***"
        elif p_val < 0.01:
            sig = "**"
        elif p_val < 0.05:
            sig = "*"
        else:
            sig = ""

        color = 'green' if improvement_pct > 0 else 'red'
        ax.text(x[i], y_pos, f'{improvement_pct:+.1f}%{sig}',
                ha='center', va='bottom', fontsize=9, fontweight='bold', color=color)

    # Formatting
    ax.set_xlabel('Question Type', fontweight='bold')
    ax.set_ylabel('Mean Reciprocal Rank (MRR)', fontweight='bold')
    ax.set_title('Re-ranking Performance by Question Type', fontweight='bold', pad=20)
    ax.set_xticks(x)
    ax.set_xticklabels(question_types, rotation=15, ha='right')
    # Legend goes top-right: the significance key below occupies the top-left
    # corner, and putting both there made them overlap.
    ax.legend(loc='upper right')
    # 20% headroom, but at least room for the +0.03 label offset; also avoids
    # a degenerate (0, 0) range when every MRR is 0.
    top = max(max(baseline_mrrs), max(rerank_mrrs))
    ax.set_ylim(0, max(top * 1.2, top + 0.1))
    ax.grid(axis='y', alpha=0.3)

    # Add note
    ax.text(0.02, 0.98, 'Significance: *** p<0.001, ** p<0.01, * p<0.05',
            transform=ax.transAxes, fontsize=8, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
        print(f"✓ Saved plot to: {output_file}")

    plt.show()

print("✓ Question type performance plot function defined")

# %% [markdown]
# ## Step 8.5: Rank Improvement Distribution (Histogram)

# %%
def plot_rank_improvement_distribution(improvements, output_file=None):
    """
    Create histogram of rank improvements.
    
    Args:
        improvements: List of rank improvements from analyze_rank_movements()
        output_file: Optional filename to save plot
    """
    if not improvements:
        print("No improvements data to plot")
        return
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Create histogram
    bins = range(min(improvements)-1, max(improvements)+2, 1)
    n, bins, patches = ax.hist(improvements, bins=bins, edgecolor='black', 
                               linewidth=0.5, alpha=0.7, color='steelblue')
    
    # Color bars: green for improvements, red for degradations, gray for no change
    for i, patch in enumerate(patches):
        bin_center = (bins[i] + bins[i+1]) / 2
        if bin_center > 0:
            patch.set_facecolor('green')
            patch.set_alpha(0.7)
        elif bin_center < 0:
            patch.set_facecolor('red')
            patch.set_alpha(0.7)
        else:
            patch.set_facecolor('gray')
            patch.set_alpha(0.5)
    
    # Add vertical line at zero
    ax.axvline(x=0, color='black', linestyle='--', linewidth=1.5, alpha=0.7)
    
    # Statistics
    improved_pct = sum(1 for x in improvements if x > 0) / len(improvements) * 100
    degraded_pct = sum(1 for x in improvements if x < 0) / len(improvements) * 100
    unchanged_pct = sum(1 for x in improvements if x == 0) / len(improvements) * 100
    
    # Add text box with statistics
    stats_text = f'Improved: {improved_pct:.1f}%\nUnchanged: {unchanged_pct:.1f}%\nDegraded: {degraded_pct:.1f}%'
    ax.text(0.98, 0.97, stats_text, transform=ax.transAxes, fontsize=10,
           verticalalignment='top', horizontalalignment='right',
           bbox=dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor='black'))
    
    # Formatting
    ax.set_xlabel('Rank Change (Positive = Improvement)', fontweight='bold')
    ax.set_ylabel('Number of Queries', fontweight='bold')
    ax.set_title('Distribution of Rank Changes After Re-ranking', fontweight='bold', pad=20)
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    
    if output_file:
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
        print(f"✓ Saved plot to: {output_file}")
    
    plt.show()

print("✓ Rank improvement distribution plot function defined")

# %% [markdown]
# ## Step 8.6: Baseline vs Re-rank Scatter Plot

# %%
def plot_baseline_vs_rerank_scatter(baseline_queries, rerank_queries, output_file=None):
    """
    Create scatter plot showing baseline rank vs re-rank rank.

    Each point is one query at (rank before re-ranking, rank after). Points
    below the y = x diagonal improved, points above it degraded, and points
    on it kept their rank; the legend reports all three counts. Queries whose
    'rank_baseline' or 'rank' is missing or not positive are left out.

    Args:
        baseline_queries: Baseline query results. Not read: the baseline rank
            comes from each re-ranking result's 'rank_baseline'. Kept for call
            compatibility.
        rerank_queries: Re-ranking query results
        output_file: Optional filename to save plot
    """
    # Extract ranks (only for queries where both ranks are valid)
    baseline_ranks = []
    rerank_ranks = []

    for rq in rerank_queries:
        baseline_rank = rq.get('rank_baseline', -1)
        rerank_rank = rq.get('rank', -1)

        if baseline_rank > 0 and rerank_rank > 0:
            baseline_ranks.append(baseline_rank)
            rerank_ranks.append(rerank_rank)

    if not baseline_ranks:
        print("No valid rank data to plot")
        return

    # Count improvements/degradations (lower rank number = better)
    pairs = list(zip(baseline_ranks, rerank_ranks))
    improvements = sum(1 for b, r in pairs if r < b)
    degradations = sum(1 for b, r in pairs if r > b)
    unchanged = sum(1 for b, r in pairs if r == b)

    fig, ax = plt.subplots(figsize=(8, 8))

    # Scatter plot
    ax.scatter(baseline_ranks, rerank_ranks, alpha=0.5, s=50, color='steelblue', edgecolor='black', linewidth=0.5)

    # Add diagonal line (y=x) - represents no change
    max_rank = max(max(baseline_ranks), max(rerank_ranks))
    ax.plot([0, max_rank], [0, max_rank], 'r--', linewidth=2, alpha=0.7,
            label=f'No change (y=x) ({unchanged})')

    # Shaded regions: above the diagonal (re-rank rank > baseline) = degraded,
    # below it = improved
    ax.fill_between([0, max_rank], [0, max_rank], max_rank, alpha=0.1, color='red',
                    label=f'Degraded ({degradations})')
    ax.fill_between([0, max_rank], 0, [0, max_rank], alpha=0.1, color='green',
                    label=f'Improved ({improvements})')

    # Formatting
    ax.set_xlabel('Baseline Rank', fontweight='bold')
    ax.set_ylabel('Re-ranking Rank', fontweight='bold')
    ax.set_title('Rank Position Comparison: Baseline vs Re-ranking', fontweight='bold', pad=20)
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

    # Invert both axes so rank 1 (best) is at the top of the y-axis and at
    # the right end of the x-axis
    ax.invert_yaxis()
    ax.invert_xaxis()

    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
        print(f"✓ Saved plot to: {output_file}")

    plt.show()

print("✓ Baseline vs re-rank scatter plot function defined")

# %% [markdown]
# ## Step 8.7: Multi-Panel Summary Figure

# %%
def create_summary_figure(all_stats, type_stats, improvements, output_file=None):
    """
    Create a multi-panel figure with all key visualizations.

    Panels:
        (A) baseline vs re-ranking MRR per configuration,
        (B) MRR improvement per question type,
        (C) histogram of per-query rank changes,
        (D) table of MRRs, signed improvement and significance per
            configuration.

    Args:
        all_stats: Statistics from compare_configurations(); each entry's
            'config_name' is expected to contain 'chunk=...' and 'mode=...'
            fields separated by ', '.
        type_stats: Statistics from analyze_by_question_type()
        improvements: Rank improvements list (positive = improvement)
        output_file: Optional filename to save plot

    Nothing is drawn when all_stats is empty. Panels B and C show a
    'No data' note when their input is empty, instead of raising.
    """
    if not all_stats:
        print("No statistics to plot")
        return

    def short_label(config_name):
        # 'provider/model, chunk=512, mode=global' -> '512-global', matched by
        # key rather than position; falls back to the full name.
        fields = dict(part.split('=', 1) for part in config_name.split(', ') if '=' in part)
        if 'chunk' in fields and 'mode' in fields:
            return f"{fields['chunk']}-{fields['mode']}"
        return config_name

    def stars(p_value):
        if p_value < 0.001:
            return "***"
        if p_value < 0.01:
            return "**"
        if p_value < 0.05:
            return "*"
        return ""

    def mark_no_data(ax):
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)

    fig = plt.figure(figsize=(16, 10))
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

    short_names = [short_label(s['config_name']) for s in all_stats]

    # Panel 1: Baseline vs Re-ranking bars
    ax1 = fig.add_subplot(gs[0, 0])
    baseline_mrrs = [s['avg_baseline'] for s in all_stats]
    rerank_mrrs = [s['avg_rerank'] for s in all_stats]

    x = np.arange(len(short_names))
    width = 0.35
    ax1.bar(x - width/2, baseline_mrrs, width, label='Baseline', color='#3498db', alpha=0.8)
    ax1.bar(x + width/2, rerank_mrrs, width, label='Re-ranking', color='#e74c3c', alpha=0.8)
    ax1.set_xlabel('Configuration', fontweight='bold')
    ax1.set_ylabel('MRR', fontweight='bold')
    ax1.set_title('(A) Baseline vs Re-ranking Performance', fontweight='bold')
    ax1.set_xticks(x)
    ax1.set_xticklabels(short_names, rotation=15, ha='right')
    ax1.legend()
    ax1.grid(axis='y', alpha=0.3)

    # Panel 2: Question type performance
    ax2 = fig.add_subplot(gs[0, 1])
    if type_stats:
        q_types = [s['question_type'] for s in type_stats]
        improvements_pct = [s['improvement_pct'] for s in type_stats]
        colors = ['green' if pct > 0 else 'red' for pct in improvements_pct]
        ax2.barh(q_types, improvements_pct, color=colors, alpha=0.7, edgecolor='black')
        ax2.axvline(x=0, color='black', linestyle='-', linewidth=1)
    else:
        mark_no_data(ax2)
    ax2.set_xlabel('MRR Improvement (%)', fontweight='bold')
    ax2.set_title('(B) Improvement by Question Type', fontweight='bold')
    ax2.grid(axis='x', alpha=0.3)

    # Panel 3: Rank improvement histogram
    ax3 = fig.add_subplot(gs[1, 0])
    if improvements is not None and len(improvements) > 0:
        # Unit-width bins centred on each integer, so a bin's centre equals the
        # rank change it counts. With integer edges, 0 fell in [0, 1) and was
        # coloured green instead of gray.
        bin_edges = np.arange(min(improvements) - 0.5, max(improvements) + 1.5, 1)
        _, bin_edges, patches = ax3.hist(improvements, bins=bin_edges, edgecolor='black', alpha=0.7)
        for left, right, patch in zip(bin_edges[:-1], bin_edges[1:], patches):
            bin_center = (left + right) / 2
            if bin_center > 0:
                patch.set_facecolor('green')
            elif bin_center < 0:
                patch.set_facecolor('red')
            else:
                patch.set_facecolor('gray')
        ax3.axvline(x=0, color='black', linestyle='--', linewidth=1.5)
    else:
        mark_no_data(ax3)
    ax3.set_xlabel('Rank Change', fontweight='bold')
    ax3.set_ylabel('Frequency', fontweight='bold')
    ax3.set_title('(C) Distribution of Rank Changes', fontweight='bold')
    ax3.grid(axis='y', alpha=0.3)

    # Panel 4: Summary statistics table
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.axis('off')

    # Signed improvement so regressions read '-x%' rather than '+-x%'
    summary_data = [
        [
            name,
            f"{stat['avg_baseline']:.3f}",
            f"{stat['avg_rerank']:.3f}",
            f"{stat['improvement_pct']:+.1f}%",
            stars(stat['p_value']),
        ]
        for name, stat in zip(short_names, all_stats)
    ]

    col_labels = ['Config', 'Baseline', 'Re-rank', 'Improve', 'Sig.']
    table = ax4.table(cellText=summary_data,
                      colLabels=col_labels,
                      cellLoc='center',
                      loc='center',
                      bbox=[0, 0, 1, 1])
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 2)

    # Style header (row 0 holds the column labels)
    for j in range(len(col_labels)):
        table[(0, j)].set_facecolor('#3498db')
        table[(0, j)].set_text_props(weight='bold', color='white')

    # Alternate row colors
    for i in range(2, len(summary_data) + 1, 2):
        for j in range(len(col_labels)):
            table[(i, j)].set_facecolor('#ecf0f1')

    ax4.set_title('(D) Summary Statistics', fontweight='bold', pad=20)

    # Main title
    fig.suptitle('Re-ranking Evaluation Summary', fontsize=16, fontweight='bold', y=0.98)

    if output_file:
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
        print(f"✓ Saved summary figure to: {output_file}")

    plt.show()

print("✓ Multi-panel summary figure function defined")

# %%
# Recap of the Step 8 plotting helpers, emitted as one block of text.
print("\n".join([
    "\n✓ Step 8 complete!",
    "  Visualization functions available:",
    "    • plot_baseline_vs_reranking() - Bar chart comparison",
    "    • plot_improvement_heatmap() - Chunk × Mode heatmap",
    "    • plot_question_type_performance() - Question type bars",
    "    • plot_rank_improvement_distribution() - Histogram",
    "    • plot_baseline_vs_rerank_scatter() - Scatter plot",
    "    • create_summary_figure() - Multi-panel summary",
]))

# %% [markdown]
# ## Step 8.8: Example Usage

# %%
# Example: Create all visualizations for a configuration
#
# The Step 7 functions produce the inputs: compare_configurations() feeds the
# cross-configuration plots, and one configuration loaded in detail feeds the
# question-type, rank-movement and scatter plots.

# os.path.join also accepts a pathlib.Path OUTPUT_DIR. The trailing "" keeps
# the trailing separator, so PLOT_OUTPUT_DIR + filename still works.
PLOT_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "plots", "")
os.makedirs(PLOT_OUTPUT_DIR, exist_ok=True)

# Every chunk size × retrieval mode combination, generated rather than
# hand-copied. Order: 512-global, 512-singledoc, 1024-global, 1024-singledoc.
configs = [
    {
        'provider': 'voyage',
        'model': 'voyage-3-large',
        'chunk_size': chunk,
        'mode': retrieval_mode,
        'k_retrieve': 40,
        'k_rerank': 20,
        'reranker_model': 'voyage-rerank-2.5'
    }
    for chunk in (512, 1024)
    for retrieval_mode in ('global', 'singledoc')
]

# Get statistics
all_stats = compare_configurations(configs, dataset)

# Load one config for detailed analysis
baseline_queries, baseline_summary, rerank_queries, rerank_summary = load_baseline_and_rerank(
    'voyage', 'voyage-3-large', 1024, 'singledoc', 40, 20, 'voyage-rerank-2.5'
)

# Get question type stats and rank improvements
type_stats = analyze_by_question_type(rerank_queries, dataset)
improvements = analyze_rank_movements(rerank_queries)

# Create visualizations
plot_baseline_vs_reranking(all_stats, os.path.join(PLOT_OUTPUT_DIR, 'fig_baseline_vs_rerank.png'))
plot_improvement_heatmap(all_stats, os.path.join(PLOT_OUTPUT_DIR, 'fig_improvement_heatmap.png'))
plot_question_type_performance(type_stats, os.path.join(PLOT_OUTPUT_DIR, 'fig_question_types.png'))
plot_rank_improvement_distribution(improvements, os.path.join(PLOT_OUTPUT_DIR, 'fig_rank_distribution.png'))
plot_baseline_vs_rerank_scatter(baseline_queries, rerank_queries, os.path.join(PLOT_OUTPUT_DIR, 'fig_scatter.png'))
create_summary_figure(all_stats, type_stats, improvements, os.path.join(PLOT_OUTPUT_DIR, 'fig_summary.png'))


print("\n" + "="*60)
print("STEP 8 READY FOR USE!")
print("="*60)
print("\nEdit the configurations above and re-run this cell to regenerate the visualizations.")
print("All plots are publication-quality (300 DPI) and can be saved as PNG files.")