# Basketball Shot Analysis - Model Comparison & Evaluation

This notebook compares different models, prompts, and configurations for basketball shot analysis.

## Evaluation Metrics

- **Accuracy**: Make/miss and shot type classification accuracy
- **Latency**: Average inference time per sample
- **Cost**: Estimated API cost per sample (total estimated cost ÷ samples evaluated)
- **Reliability**: Parse success rate and error analysis

## Models Tested

- Gemini 2.5 Flash (production)
- Gemini 2.5 Flash Lite (fast)
- Gemini 2.5 Pro (high quality; not in the default run — add `gemini_pro` to `MODELS_TO_TEST` to include it)

## Usage

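1. Ensure data is synced (run `1_data_sync.ipynb` first)
2. Set `GEMINI_API_KEY` in the `.env` file loaded from `../../.env` (relative to this notebook), or export it in your environment
3. Configure evaluation parameters below (`TEST_MODE`, `MODELS_TO_TEST`, `PROMPTS_TO_TEST`, `VIDEO_CONFIG`)
4. Run evaluation and analysis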


In [None]:
import sys
import os
import json
import yaml
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from dotenv import load_dotenv
from tqdm import tqdm

# Add src directory to path
sys.path.append(str(Path('../src').resolve()))

from data_manager import EvaluationDataManager
from evaluator import ModelEvaluator
from metrics import EvaluationMetrics

# --- Environment & run configuration -----------------------------------------
# Read variables from the .env file two directories up into the process environment.
load_dotenv('../../.env')

GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')

# Flip TEST_MODE to False to run the full evaluation.
TEST_MODE = True
# In test mode only a handful of samples are evaluated. None is forwarded to
# evaluate_model as max_samples (presumably meaning "no limit" — verify in evaluator).
MAX_SAMPLES = None
if TEST_MODE:
    MAX_SAMPLES = 5

if GEMINI_API_KEY:
    print("✅ API key loaded")
else:
    print("⚠️  Missing GEMINI_API_KEY!")
    print("Please set your Gemini API key in the .env file")

print(f"🔧 Test mode: {TEST_MODE}")
if TEST_MODE:
    print(f"📊 Will evaluate on {MAX_SAMPLES} samples only")


In [None]:
# Load the model and prompt catalogues, list what is available, then validate the
# selection below against them so a typo fails here — before any API calls —
# rather than partway through the evaluation loop.
with open('../configs/models.yaml', 'r', encoding='utf-8') as f:
    model_config = yaml.safe_load(f)

with open('../configs/prompts.yaml', 'r', encoding='utf-8') as f:
    prompt_config = yaml.safe_load(f)

print("📋 Available Models:")
for model_key, model_info in model_config['models'].items():
    print(f"  {model_key}: {model_info['name']} - {model_info['description']}")

print("\n📝 Available Prompts:")
for prompt_key, prompt_info in prompt_config['prompts'].items():
    print(f"  {prompt_key}: {prompt_info['name']} - {prompt_info['description']}")

# Select models and prompts to evaluate
MODELS_TO_TEST = ['gemini_flash', 'gemini_flash_lite']  # Add 'gemini_pro' for full eval
PROMPTS_TO_TEST = ['current_production', 'simplified']  # Add more for comprehensive testing
VIDEO_CONFIG = 'standard'  # standard, high_quality, or fast

# Fail fast on unknown keys. The evaluation cell indexes these sections by key,
# and a KeyError there could abort the sweep after earlier combinations had
# already been sent to the API.
_unknown_keys = {
    'models': [k for k in MODELS_TO_TEST if k not in model_config['models']],
    'prompts': [k for k in PROMPTS_TO_TEST if k not in prompt_config['prompts']],
    'video_configs': [] if VIDEO_CONFIG in model_config.get('video_configs', {}) else [VIDEO_CONFIG],
}
_unknown_keys = {section: keys for section, keys in _unknown_keys.items() if keys}
if _unknown_keys:
    raise ValueError(f"Unknown configuration keys (check the YAML files): {_unknown_keys}")

print(f"\n🎯 Evaluation Plan:")
print(f"Models: {MODELS_TO_TEST}")
print(f"Prompts: {PROMPTS_TO_TEST}")
print(f"Video config: {VIDEO_CONFIG}")
print(f"Total combinations: {len(MODELS_TO_TEST) * len(PROMPTS_TO_TEST)}")


In [None]:
# Load the synced ground-truth labels and print the class balance.
# The Supabase credentials are placeholders: they are not needed for loading.
data_manager = EvaluationDataManager(
    supabase_url="dummy",
    supabase_key="dummy",
    data_dir='../data'
)

try:
    ground_truth = data_manager.load_ground_truth()
except FileNotFoundError:
    # Leave ground_truth as None so the evaluation cell skips its loop.
    print("❌ Ground truth data not found!")
    print("Please run the data synchronization notebook first (1_data_sync.ipynb)")
    ground_truth = None
else:
    print(f"✅ Loaded {len(ground_truth)} ground truth samples")

    # Label distribution across all loaded samples.
    labels = [entry['ground_truth'] for entry in ground_truth]
    shot_type_counts = pd.Series([label['shot_type'] for label in labels]).value_counts().to_dict()
    result_counts = pd.Series([label['result'] for label in labels]).value_counts().to_dict()

    print(f"\n📊 Dataset Overview:")
    print(f"Shot types: {shot_type_counts}")
    print(f"Results: {result_counts}")


In [None]:
# Initialize evaluator
evaluator = ModelEvaluator(
    gemini_api_key=GEMINI_API_KEY,
    output_dir='../data/model_outputs'
)

# Run every (model, prompt) combination against the ground truth. Each
# combination is isolated: a missing config entry or an evaluation error is
# reported and skipped, so one bad combination does not abort the sweep after
# earlier combinations have already made their API calls.
evaluation_results = []
failed_combinations = []  # (model_key, prompt_key, error message) for the final summary

if ground_truth:
    # Looked up once, before any API calls, so a bad VIDEO_CONFIG fails immediately.
    video_settings = model_config['video_configs'][VIDEO_CONFIG]

    for model_key in MODELS_TO_TEST:
        try:
            model_name = model_config['models'][model_key]['name']
        except KeyError as e:
            print(f"\n❌ No usable model config for {model_key} (missing key {e}); skipping its prompts")
            failed_combinations.extend(
                (model_key, prompt_key, f"missing model config key {e}") for prompt_key in PROMPTS_TO_TEST
            )
            continue

        for prompt_key in PROMPTS_TO_TEST:
            print(f"\n🚀 Evaluating {model_key} with {prompt_key} prompt...")
            print(f"Model: {model_name}")
            print(f"Video: {video_settings['fps']}fps, {video_settings['resolution']} resolution")

            try:
                # Inside the try so an unknown prompt key only skips this combination.
                prompt_text = prompt_config['prompts'][prompt_key]['content']
                result = evaluator.evaluate_model(
                    model_name=model_name,
                    prompt=prompt_text,
                    ground_truth_data=ground_truth,
                    fps=video_settings['fps'],
                    media_resolution=video_settings['resolution'],
                    max_samples=MAX_SAMPLES
                )

                # Record which configuration produced this result (read by the analysis cells).
                result['model_key'] = model_key
                result['prompt_key'] = prompt_key
                result['video_config'] = VIDEO_CONFIG

                evaluation_results.append(result)

                print(f"✅ Completed {model_key} + {prompt_key}")
                print(f"   Accuracy: {result['metrics']['both_correct_accuracy']:.3f}")
                print(f"   Avg time: {result['average_inference_time_s']:.2f}s")
                print(f"   Total cost: ${result['total_cost_usd']:.4f}")

            except Exception as e:
                # Deliberately broad: any failure (API, parsing, missing key) should
                # skip this combination, not the whole sweep.
                print(f"❌ Error evaluating {model_key} + {prompt_key}: {e}")
                failed_combinations.append((model_key, prompt_key, str(e)))
else:
    print("⏭️  Skipping evaluation: no ground truth samples loaded")

print(f"\n🏁 Evaluation complete! {len(evaluation_results)} combinations tested.")
if failed_combinations:
    print(f"⚠️  {len(failed_combinations)} combination(s) failed:")
    for failed_model, failed_prompt, message in failed_combinations:
        print(f"   {failed_model} + {failed_prompt}: {message}")


In [None]:
# Flatten each evaluation result into one row of a comparison table
# (comparison_df), or set comparison_df to None when nothing was evaluated.
if not evaluation_results:
    print("❌ No evaluation results to display")
    comparison_df = None
else:
    comparison_df = pd.DataFrame([
        {
            'model': res['model_key'],
            'prompt': res['prompt_key'],
            'model_name': res['model_name'],
            'total_samples': res['total_samples'],
            'shot_type_accuracy': res['metrics']['shot_type_accuracy'],
            'result_accuracy': res['metrics']['result_accuracy'],
            'both_correct_accuracy': res['metrics']['both_correct_accuracy'],
            'parse_success_rate': res['metrics']['parse_success_rate'],
            'average_confidence': res['metrics']['average_confidence'],
            'avg_inference_time_s': res['average_inference_time_s'],
            'total_cost_usd': res['total_cost_usd'],
            # 0 rather than a division error when a run produced no samples.
            'cost_per_sample': res['total_cost_usd'] / res['total_samples'] if res['total_samples'] > 0 else 0,
        }
        for res in evaluation_results
    ])

    summary_columns = ['model', 'prompt', 'both_correct_accuracy', 'avg_inference_time_s', 'cost_per_sample']
    print("📊 Evaluation Results Summary:")
    print(comparison_df[summary_columns].round(4))


In [None]:
# Detailed analysis and visualizations: a 2x2 grid comparing every evaluated
# (model, prompt) combination.
#   top-left:     shot-type vs make/miss accuracy per combination
#   top-right:    speed vs both-correct accuracy, coloured by cost per sample
#   bottom-left:  cost per sample, each bar labelled with both-correct accuracy
#   bottom-right: accuracy / parse-success / confidence scaled to the best combination
if comparison_df is not None and len(comparison_df) > 0:
    n_combos = len(comparison_df)
    combo_pos = np.arange(n_combos)
    # Shared "model\nprompt" tick labels for the two per-combination bar charts.
    combo_labels = [f"{model}\n{prompt}" for model, prompt in zip(comparison_df['model'], comparison_df['prompt'])]

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Accuracy comparison: paired bars per combination
    ax1 = axes[0, 0]
    width = 0.35
    ax1.bar(combo_pos - width/2, comparison_df['shot_type_accuracy'], width, label='Shot Type', alpha=0.8)
    ax1.bar(combo_pos + width/2, comparison_df['result_accuracy'], width, label='Make/Miss', alpha=0.8)
    ax1.set_xlabel('Model + Prompt')
    ax1.set_ylabel('Accuracy')
    ax1.set_title('Classification Accuracy Comparison')
    ax1.set_xticks(combo_pos)
    # ha='right' puts the end of each rotated label under its tick; the default
    # centre alignment shifts 45°-rotated labels sideways off their bars.
    ax1.set_xticklabels(combo_labels, rotation=45, ha='right')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Speed vs Accuracy
    ax2 = axes[0, 1]
    scatter = ax2.scatter(comparison_df['avg_inference_time_s'], comparison_df['both_correct_accuracy'],
                          s=100, alpha=0.7, c=comparison_df['cost_per_sample'], cmap='viridis')
    for _, row in comparison_df.iterrows():
        # Keys truncated to 8 characters to limit overlap between annotations.
        ax2.annotate(f"{row['model'][:8]}\n{row['prompt'][:8]}",
                     (row['avg_inference_time_s'], row['both_correct_accuracy']),
                     xytext=(5, 5), textcoords='offset points', fontsize=8)
    ax2.set_xlabel('Average Inference Time (s)')
    ax2.set_ylabel('Both Correct Accuracy')
    ax2.set_title('Speed vs Accuracy (Color = Cost)')
    plt.colorbar(scatter, ax=ax2, label='Cost per Sample ($)')
    ax2.grid(True, alpha=0.3)

    # Cost analysis
    ax3 = axes[1, 0]
    bars = ax3.bar(combo_pos, comparison_df['cost_per_sample'])
    ax3.set_xlabel('Model + Prompt')
    ax3.set_ylabel('Cost per Sample ($)')
    ax3.set_title('Cost per Sample Comparison')
    ax3.set_xticks(combo_pos)
    ax3.set_xticklabels(combo_labels, rotation=45, ha='right')

    # Label each bar with its both-correct accuracy, lifted by 1% of the tallest
    # bar; the offset is loop-invariant, so compute it once.
    text_offset = comparison_df['cost_per_sample'].max() * 0.01
    for bar, acc in zip(bars, comparison_df['both_correct_accuracy']):
        ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + text_offset,
                 f'{acc:.3f}', ha='center', va='bottom', fontsize=9)
    ax3.grid(True, alpha=0.3)

    # Overall comparison as grouped bars (needs at least two combinations)
    ax4 = axes[1, 1]
    if n_combos > 1:
        metrics_to_plot = ['both_correct_accuracy', 'parse_success_rate', 'average_confidence']
        # Divide each metric by its best value so the leading combination scores 1.0;
        # a metric whose best value is not positive is left unscaled.
        normalized_data = comparison_df[metrics_to_plot].copy()
        for col in metrics_to_plot:
            max_val = normalized_data[col].max()
            if max_val > 0:
                normalized_data[col] = normalized_data[col] / max_val

        metric_pos = np.arange(len(metrics_to_plot))
        bar_width = 0.8 / n_combos
        combo_names = [f"{model} + {prompt}" for model, prompt in zip(comparison_df['model'], comparison_df['prompt'])]
        for i, name in enumerate(combo_names):
            # Row i holds the metrics in metrics_to_plot order (the frame's only columns).
            ax4.bar(metric_pos + i * bar_width, normalized_data.iloc[i].tolist(), bar_width,
                    label=name, alpha=0.7)

        ax4.set_xlabel('Metrics')
        ax4.set_ylabel('Normalized Score')
        ax4.set_title('Overall Performance Comparison')
        # Centre each tick under its group of n_combos bars.
        ax4.set_xticks(metric_pos + bar_width * (n_combos - 1) / 2)
        ax4.set_xticklabels(['Both Correct', 'Parse Success', 'Avg Confidence'])
        ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax4.grid(True, alpha=0.3)
    else:
        ax4.text(0.5, 0.5, 'Need multiple models\nfor comparison',
                 ha='center', va='center', transform=ax4.transAxes, fontsize=12)
        ax4.set_title('Overall Performance Comparison')

    plt.tight_layout()
    plt.show()

else:
    print("❌ No data available for visualization")


In [None]:
# Detailed error analysis: per-combination metric summary, non-zero confusion
# counts, and up to three example errors from EvaluationMetrics.generate_error_analysis.
def _format_confidence(value):
    """Return *value* formatted to 3 decimals, or 'n/a' if it is missing or non-numeric."""
    # A None confidence (presumably on parse failures — verify against
    # generate_error_analysis) would crash a plain ':.3f' format.
    try:
        return f"{float(value):.3f}"
    except (TypeError, ValueError):
        return "n/a"


if evaluation_results:
    metrics_calculator = EvaluationMetrics()

    print("🔍 Detailed Error Analysis:")

    for result in evaluation_results:
        model_key = result['model_key']
        prompt_key = result['prompt_key']
        metrics = result['metrics']
        total_samples = result['total_samples']
        # Same guard as the comparison table: no ZeroDivisionError for an empty run.
        cost_per_sample = result['total_cost_usd'] / total_samples if total_samples > 0 else 0

        print(f"\n{'='*50}")
        print(f"📊 {model_key.upper()} + {prompt_key.upper()}")
        print(f"{'='*50}")

        # Overall performance
        print(f"📈 Performance Summary:")
        print(f"  Total samples: {total_samples}")
        print(f"  Parse success: {metrics['parse_success_rate']:.3f}")
        print(f"  Shot type accuracy: {metrics['shot_type_accuracy']:.3f}")
        print(f"  Make/miss accuracy: {metrics['result_accuracy']:.3f}")
        print(f"  Both correct: {metrics['both_correct_accuracy']:.3f}")
        print(f"  Average confidence: {metrics['average_confidence']:.3f}")
        print(f"  Avg inference time: {result['average_inference_time_s']:.2f}s")
        print(f"  Cost per sample: ${cost_per_sample:.4f}")

        # Confusion matrices: only the non-zero cells, in sorted key order.
        for title, matrix_key in (("Shot Type", 'shot_type_confusion_matrix'),
                                  ("Make/Miss", 'result_confusion_matrix')):
            print(f"\n🎯 {title} Confusion:")
            for transition, count in sorted(metrics.get(matrix_key, {}).items()):
                if count > 0:
                    print(f"  {transition}: {count}")

        # Error analysis
        error_df = metrics_calculator.generate_error_analysis(result)
        if len(error_df) > 0:
            print(f"\n❌ Error Breakdown ({len(error_df)} errors):")
            print(f"  Shot type errors: {error_df['shot_type_error'].sum()}")
            print(f"  Make/miss errors: {error_df['result_error'].sum()}")
            print(f"  Parse errors: {error_df['has_parse_error'].sum()}")

            print(f"\n🔍 Sample Errors:")
            for _, row in error_df.head(3).iterrows():
                # str() so a non-string clip id can still be truncated for display.
                print(f"  Clip {str(row['clip_id'])[:8]}:")
                if row['shot_type_error']:
                    print(f"    Shot type: {row['gt_shot_type']} → {row['pred_shot_type']}")
                if row['result_error']:
                    print(f"    Result: {row['gt_result']} → {row['pred_result']}")
                print(f"    Confidence: {_format_confidence(row['confidence'])}")
        else:
            print(f"\n✅ No errors found!")


In [None]:
# Recommendations and next steps: best single-metric performers plus a
# weighted "balanced" pick for production.

# Weights for the balanced score; every component is on a 0-1 scale, higher is better.
SCORE_WEIGHTS = {'both_correct_accuracy': 0.6, 'speed_score': 0.2, 'cost_score': 0.2}


def _inverse_minmax_score(values):
    """Score a lower-is-better metric on [0, 1] (1 = best) via min-max of its reciprocal.

    The plain formula breaks in two ways. A zero value gives 1/0 = inf; the
    comparison table sets cost_per_sample to 0 for runs with no samples. All-equal
    values (e.g. a single combination) give 0/0 = NaN, and idxmax cannot rank NaN.
    So non-positive or missing values get 0.0 (no credit), and when every known
    value ties, each known entry scores 1.0.
    """
    inverse = 1.0 / values.where(values > 0)
    lo, hi = inverse.min(), inverse.max()
    if pd.isna(lo):
        return pd.Series(0.0, index=values.index)
    if hi == lo:
        return inverse.notna().astype(float)
    return ((inverse - lo) / (hi - lo)).fillna(0.0)


if comparison_df is not None and len(comparison_df) > 0:
    print("🎯 RECOMMENDATIONS")
    print("="*50)

    # Find best performers
    best_accuracy = comparison_df.loc[comparison_df['both_correct_accuracy'].idxmax()]
    fastest = comparison_df.loc[comparison_df['avg_inference_time_s'].idxmin()]
    cheapest = comparison_df.loc[comparison_df['cost_per_sample'].idxmin()]

    print(f"🏆 BEST ACCURACY: {best_accuracy['model']} + {best_accuracy['prompt']}")
    print(f"   Accuracy: {best_accuracy['both_correct_accuracy']:.3f}")
    print(f"   Speed: {best_accuracy['avg_inference_time_s']:.2f}s")
    print(f"   Cost: ${best_accuracy['cost_per_sample']:.4f}")

    print(f"\n⚡ FASTEST: {fastest['model']} + {fastest['prompt']}")
    print(f"   Speed: {fastest['avg_inference_time_s']:.2f}s")
    print(f"   Accuracy: {fastest['both_correct_accuracy']:.3f}")
    print(f"   Cost: ${fastest['cost_per_sample']:.4f}")

    print(f"\n💰 CHEAPEST: {cheapest['model']} + {cheapest['prompt']}")
    print(f"   Cost: ${cheapest['cost_per_sample']:.4f}")
    print(f"   Accuracy: {cheapest['both_correct_accuracy']:.3f}")
    print(f"   Speed: {cheapest['avg_inference_time_s']:.2f}s")

    # Production recommendation
    print(f"\n🚀 PRODUCTION RECOMMENDATION:")

    # Balanced score: weighted sum of accuracy and the normalised speed/cost scores.
    comparison_df_scored = comparison_df.copy()
    comparison_df_scored['speed_score'] = _inverse_minmax_score(comparison_df_scored['avg_inference_time_s'])
    comparison_df_scored['cost_score'] = _inverse_minmax_score(comparison_df_scored['cost_per_sample'])
    comparison_df_scored['combined_score'] = sum(
        comparison_df_scored[column] * weight for column, weight in SCORE_WEIGHTS.items()
    )

    best_overall = comparison_df_scored.loc[comparison_df_scored['combined_score'].idxmax()]

    print(f"   Best balanced option: {best_overall['model']} + {best_overall['prompt']}")
    print(f"   Combined score: {best_overall['combined_score']:.3f}")
    print(f"   Accuracy: {best_overall['both_correct_accuracy']:.3f}")
    print(f"   Speed: {best_overall['avg_inference_time_s']:.2f}s")
    print(f"   Cost: ${best_overall['cost_per_sample']:.4f}")

    print(f"\n📋 NEXT STEPS:")
    print(f"1. Run full evaluation (set TEST_MODE=False) on complete dataset")
    print(f"2. Test additional models: Gemini Pro, other providers")
    print(f"3. Experiment with prompt variations and video settings")
    print(f"4. Implement A/B testing in production")
    print(f"5. Set up continuous evaluation pipeline")

else:
    print("❌ No results available for recommendations")

print(f"\n✅ Model comparison complete!")
print(f"📁 Results saved to: ../data/model_outputs/")
print(f"📊 Run this notebook again with TEST_MODE=False for full evaluation")
