# AIC Video Retrieval System - Training & Reranking

This notebook trains reranking models that reorder retrieval results to improve search quality.
It includes both cross-encoder and gradient boosting approaches for reranking.

## Features
- 📊 Training data preparation and validation
- 🤖 Cross-encoder reranker training
- 🌲 Gradient boosting reranker training
- 📈 Training progress monitoring
- 🧪 Model evaluation and comparison
- 💾 Model checkpoint management

In [None]:
# Import and setup
import os
import sys
import json
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import time
import pickle

# Set up paths (assuming setup notebook was run)
REPO_NAME = "AIC_FTML_dev"
_start_dir = Path.cwd()
if Path(f"/content/{REPO_NAME}").exists():
    # Checkout under /content (e.g. a Colab session)
    REPO_DIR = Path(f"/content/{REPO_NAME}")
else:
    # Walk up from the current directory to the first ancestor named REPO_NAME.
    # If there is none, stay in the current directory instead of falling
    # through to the filesystem root and chdir-ing there.
    REPO_DIR = next(
        (candidate for candidate in (_start_dir, *_start_dir.parents) if candidate.name == REPO_NAME),
        _start_dir,
    )
    if REPO_DIR.name != REPO_NAME:
        print(f"⚠️ '{REPO_NAME}' not found above {_start_dir}; using the current directory")

os.chdir(REPO_DIR)
# Make both the repo root and its src/ importable (config, src.models, ...)
sys.path.insert(0, str(REPO_DIR))
sys.path.insert(0, str(REPO_DIR / "src"))

print(f"Working from: {REPO_DIR}")

# Import ML libraries
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset
    from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer
    from sentence_transformers import SentenceTransformer, CrossEncoder
    from sentence_transformers.cross_encoder import CrossEncoder
    from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import ndcg_score, mean_squared_error
    print("✅ ML libraries loaded successfully")
except ImportError as e:
    print(f"⚠️ Some ML libraries missing: {e}")
    print("Installing missing dependencies...")
    !pip install sentence-transformers scikit-learn xgboost

# Import project modules
import config
from src.models.clip_encoder import CLIPEncoder
from src.pipeline.query_pipeline import QueryProcessingPipeline

## Step 1: Load and Prepare Training Data

In [None]:
# Load training data
print("=== Training Data Preparation ===")

# Candidate locations for the training JSONL, checked in order
training_files = [
    Path("data/train.jsonl"),
    Path("data/training_data.jsonl"),
    Path("train.jsonl")
]

# First candidate that exists on disk, or None
training_file = next((path for path in training_files if path.exists()), None)

if not training_file:
    print("❌ No training data found. Creating sample training data...")

    # Try to create training data using the utility
    try:
        from utils.create_training_data import create_training_data_from_metadata

        success = create_training_data_from_metadata(
            dataset_root="./data",
            output_file="data/train.jsonl",
            num_examples=100
        )

        if success:
            training_file = Path("data/train.jsonl")
            print("✅ Training data created")
        else:
            print("❌ Could not create training data")
    except Exception as e:
        print(f"❌ Error creating training data: {e}")
else:
    print(f"✅ Found training data: {training_file}")

# Load and analyze training data.
# Each line must be a JSON object with a non-empty 'query' string; 'positives'
# (a list of {'video_id', 'frame_idx'} dicts) is optional. These are the fields
# read below and by the pair-creation step.
training_data = []
skipped_lines = 0
if training_file and training_file.exists():
    with open(training_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines are not errors
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                skipped_lines += 1
                continue
            # Later steps index item['query'] directly; drop records that would crash them
            query = data.get('query') if isinstance(data, dict) else None
            if not isinstance(query, str) or not query.strip():
                skipped_lines += 1
                continue
            training_data.append(data)

    print(f"Loaded {len(training_data)} training examples")
    if skipped_lines:
        print(f"⚠️ Skipped {skipped_lines} malformed line(s) (invalid JSON or missing 'query')")

    if training_data:
        # Analyze training data
        total_positives = sum(len(item.get('positives', [])) for item in training_data)
        unique_queries = len({item['query'] for item in training_data})

        print(f"Training data statistics:")
        print(f"  Total queries: {len(training_data)}")
        print(f"  Unique queries: {unique_queries}")
        print(f"  Total positive examples: {total_positives}")
        print(f"  Avg positives per query: {total_positives/len(training_data):.1f}")

        # Show sample
        print("\n📋 Sample training examples:")
        for i, item in enumerate(training_data[:3]):
            print(f"{i+1}. Query: '{item['query']}'")
            print(f"   Positives: {len(item.get('positives', []))} frames")
            if item.get('positives'):
                sample_pos = item['positives'][0]
                print(f"   Sample: {sample_pos['video_id']} frame {sample_pos['frame_idx']}")
else:
    print("❌ No training data available. Please run data processing notebook first.")

## Step 2: Create Training Pairs for Reranking

In [None]:
# Create training pairs for reranking
print("=== Creating Training Pairs ===")

def create_reranking_pairs(training_data, negative_sampling_ratio=2, seed=None, all_frames=None):
    """Build (query, doc_id) pairs with binary relevance labels for reranker training.

    Every positive frame of a query yields one pair labelled 1. Frames from the
    candidate pool that are not positives for that query are sampled without
    replacement as pairs labelled 0, up to ``negative_sampling_ratio`` per
    positive. Documents are identified by ``"<video_id>_frame_<frame_idx>"``.

    Args:
        training_data: list of dicts with a 'query' string and an optional
            'positives' list of {'video_id', 'frame_idx'} dicts. Items without a
            query or without positives are skipped.
        negative_sampling_ratio: number of negatives sampled per positive.
        seed: optional seed for negative sampling. None uses NumPy's global
            random state, as before.
        all_frames: optional iterable of (video_id, frame_idx) candidates. When
            None they are read from ``<config.ARTIFACT_DIR>/index_metadata.parquet``.

    Returns:
        (pairs, labels): parallel lists. Both are empty when there is no
        training data or the index metadata file is missing.
    """
    if not training_data:
        return [], []

    if all_frames is None:
        # Load index metadata to get negative samples
        metadata_file = Path(config.ARTIFACT_DIR) / "index_metadata.parquet"
        if not metadata_file.exists():
            print("❌ Index metadata not found")
            return [], []
        metadata_df = pd.read_parquet(metadata_file, columns=['video_id', 'frame_idx'])
        # Column-wise zip avoids building one Series per row like iterrows() does
        all_frames = list(zip(metadata_df['video_id'], metadata_df['frame_idx']))
    else:
        all_frames = list(all_frames)

    rng = np.random if seed is None else np.random.default_rng(seed)

    pairs = []   # (query, doc_id)
    labels = []  # 1 for positive, 0 for negative

    for item in tqdm(training_data, desc="Creating pairs"):
        query = item.get('query')
        positives = item.get('positives', [])
        if not query or not positives:
            continue

        # Positive pairs
        for pos in positives:
            pairs.append((query, f"{pos['video_id']}_frame_{pos['frame_idx']}"))
            labels.append(1)

        # Negative pairs: any indexed frame that is not a positive for this query
        positive_frames = {(pos['video_id'], pos['frame_idx']) for pos in positives}
        negative_candidates = [frame for frame in all_frames if frame not in positive_frames]
        if not negative_candidates:
            continue

        num_negatives = min(len(positives) * negative_sampling_ratio, len(negative_candidates))
        for neg_idx in rng.choice(len(negative_candidates), size=num_negatives, replace=False):
            video_id, frame_idx = negative_candidates[neg_idx]
            pairs.append((query, f"{video_id}_frame_{frame_idx}"))
            labels.append(0)

    return pairs, labels

# Create training pairs and split them into train/validation sets.
# NOTE(review): the split is over individual (query, doc) pairs, so the same
# query can appear in both splits; a split grouped by query would give a
# stricter validation estimate.

# Separate list objects (chained `a = b = []` would alias them)
train_pairs_train, train_pairs_val = [], []
train_labels_train, train_labels_val = [], []

if training_data:
    print("Creating query-document pairs...")
    train_pairs, train_labels = create_reranking_pairs(training_data)

    if train_pairs:
        num_positive = sum(train_labels)
        print(f"✅ Created {len(train_pairs)} training pairs")
        print(f"   Positive pairs: {num_positive}")
        print(f"   Negative pairs: {len(train_labels) - num_positive}")
        print(f"   Positive ratio: {num_positive/len(train_labels):.2%}")

        if len(train_pairs) < 2:
            # train_test_split cannot produce two non-empty parts from a single pair
            print("⚠️ Too few pairs to split; using all pairs for training")
            train_pairs_train, train_labels_train = list(train_pairs), list(train_labels)
        else:
            split_kwargs = dict(test_size=0.2, random_state=42)
            try:
                # Stratify so both splits keep the positive/negative ratio
                split = train_test_split(train_pairs, train_labels, stratify=train_labels, **split_kwargs)
            except ValueError as err:
                # Raised for tiny datasets (a class with < 2 members, or a test
                # split smaller than the number of classes)
                print(f"⚠️ Stratified split not possible ({err}); using a random split")
                split = train_test_split(train_pairs, train_labels, **split_kwargs)
            train_pairs_train, train_pairs_val, train_labels_train, train_labels_val = split

        print(f"\nSplit into:")
        print(f"  Training: {len(train_pairs_train)} pairs")
        print(f"  Validation: {len(train_pairs_val)} pairs")
    else:
        print("❌ No training pairs created")
else:
    print("⚠️ No training data available for pair creation")

## Step 3: Cross-Encoder Reranker Training

In [None]:
# Train Cross-Encoder reranker
print("=== Cross-Encoder Reranker Training ===")

TRAIN_CROSS_ENCODER = True  # Set to False to skip

if TRAIN_CROSS_ENCODER and train_pairs_train:
    # Later cells treat the existence of this directory as "a trained model is available"
    output_path = Path("./artifacts/cross_encoder_reranker")
    try:
        # Local import: InputExample is not imported at the top of the notebook
        from sentence_transformers import InputExample

        # Initialize cross-encoder on the selected device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        # Choose base model
        cross_encoder_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # Lightweight model

        print(f"Initializing cross-encoder: {cross_encoder_model}")
        cross_encoder = CrossEncoder(cross_encoder_model, num_labels=1, device=str(device))

        # NOTE(review): the document side of every pair is a frame id string
        # ("<video_id>_frame_<frame_idx>"), not frame content, so the model can
        # only learn from id text. Captions/OCR text would make training meaningful.

        # CrossEncoder.fit consumes a DataLoader of InputExample objects (a plain
        # list of lists has no collate_fn); the single-logit head takes float labels.
        train_samples = [
            InputExample(texts=[query, doc_id], label=float(label))
            for (query, doc_id), label in zip(train_pairs_train, train_labels_train)
        ]

        # CERerankingEvaluator expects one dict per query:
        # {'query': str, 'positive': [doc, ...], 'negative': [doc, ...]}
        val_groups = {}
        for (query, doc_id), label in zip(train_pairs_val, train_labels_val):
            group = val_groups.setdefault(query, {'query': query, 'positive': [], 'negative': []})
            group['positive' if label else 'negative'].append(doc_id)
        # Ranking metrics need at least one positive and one negative per query
        val_samples = [g for g in val_groups.values() if g['positive'] and g['negative']]

        print(f"Prepared {len(train_samples)} training samples")
        print(f"Prepared {len(val_samples)} validation queries")

        # Training configuration
        train_batch_size = 16
        num_epochs = 3
        train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
        # Warm up for ~10% of all optimizer steps (at most 100) so small datasets
        # are not trained entirely inside the warmup phase.
        total_steps = num_epochs * len(train_dataloader)
        warmup_steps = min(100, max(1, total_steps // 10))

        print(f"\nTraining configuration:")
        print(f"  Batch size: {train_batch_size}")
        print(f"  Epochs: {num_epochs}")
        print(f"  Warmup steps: {warmup_steps}")

        # Setup evaluator (skipped when no validation query is usable)
        evaluator = CERerankingEvaluator(val_samples, name='validation') if val_samples else None

        # Start training
        print("\n🏋️ Starting cross-encoder training...")

        output_path.mkdir(parents=True, exist_ok=True)

        cross_encoder.fit(
            train_dataloader=train_dataloader,
            evaluator=evaluator,
            epochs=num_epochs,
            evaluation_steps=500,
            warmup_steps=warmup_steps,
            output_path=str(output_path),
            save_best_model=True,
            optimizer_params={'lr': 2e-5},
        )

        # Guarantee the directory holds a loadable model even if fit did not save one
        if not (output_path / "config.json").exists():
            cross_encoder.save(str(output_path))

        print("✅ Cross-encoder training completed!")
        print(f"Model saved to: {output_path}")

        # Test the trained model
        test_query = "news anchor speaking"
        test_docs = ["video1_frame_100", "video2_frame_200", "video3_frame_300"]

        test_pairs = [[test_query, doc] for doc in test_docs]
        scores = cross_encoder.predict(test_pairs)

        print(f"\n🧪 Test prediction:")
        print(f"Query: '{test_query}'")
        for doc, score in zip(test_docs, scores):
            print(f"  {doc}: {score:.4f}")

    except Exception as e:
        print(f"❌ Cross-encoder training failed: {e}")
        print("This might be due to insufficient data or memory constraints")
        # Don't leave an empty directory that later cells would mistake for a model
        if output_path.is_dir() and not any(output_path.iterdir()):
            output_path.rmdir()
else:
    if not TRAIN_CROSS_ENCODER:
        print("⚠️ Cross-encoder training skipped")
    else:
        print("⚠️ Cross-encoder training not possible - no training data")

## Step 4: Gradient Boosting Reranker Training

In [None]:
# Train Gradient Boosting reranker
print("=== Gradient Boosting Reranker Training ===")

TRAIN_GBM = True  # Set to False to skip

# Number of columns produced by extract_features_for_gbm; keep in sync with the
# feature_names list saved with the trained GBM model.
GBM_NUM_FEATURES = 11


def extract_features_for_gbm(query_doc_pairs, clip_encoder=None):
    """Build the handcrafted feature matrix used by the GBM reranker.

    Args:
        query_doc_pairs: iterable of (query, doc_id) string pairs, where doc_id
            looks like "<video_id>_frame_<frame_idx>".
        clip_encoder: currently unused; accepted so existing callers keep working.

    Returns:
        ndarray of shape (n_pairs, GBM_NUM_FEATURES), one row per pair. Columns
        in order: query length, query word count, doc id length, 'frame' in doc
        id, space count in query, word-overlap ratio, word-overlap count,
        news / speaking / studio keyword flags, non-ASCII flag. Empty input
        yields shape (0, GBM_NUM_FEATURES).
    """
    news_terms = ('news', 'anchor', 'reporter')
    speaking_terms = ('speaking', 'talking', 'presenting')
    studio_terms = ('studio', 'broadcast', 'television')

    features = []
    for query, doc_id in tqdm(query_doc_pairs, desc="Extracting features"):
        query_lower = query.lower()
        query_words = set(query_lower.split())
        doc_words = set(doc_id.lower().replace('_', ' ').split())
        overlap_count = len(query_words & doc_words)

        features.append([
            len(query),                                              # query_length
            len(query.split()),                                      # query_words
            len(doc_id),                                             # doc_length
            1 if 'frame' in doc_id else 0,                           # has_frame
            query.count(' '),                                        # space_count
            overlap_count / len(query_words) if query_words else 0,  # overlap_ratio
            overlap_count,                                           # overlap_count
            # Substring matches, so e.g. 'newsroom' also sets the news flag
            1 if any(term in query_lower for term in news_terms) else 0,      # is_news
            1 if any(term in query_lower for term in speaking_terms) else 0,  # is_speaking
            1 if any(term in query_lower for term in studio_terms) else 0,    # is_studio
            # Any non-ASCII character (intended to detect Vietnamese queries)
            1 if any(ord(ch) > 127 for ch in query) else 0,                   # is_vietnamese
        ])

    # np.array([]) is 1-D with shape (0,), which breaks .shape[1] and
    # model.predict; keep the 2-D (n_samples, n_features) shape for empty input.
    if not features:
        return np.empty((0, GBM_NUM_FEATURES))
    return np.array(features)

if TRAIN_GBM and train_pairs_train:
    try:
        print("Extracting features for gradient boosting...")

        # Feature matrices and label vectors for both splits
        X_train = extract_features_for_gbm(train_pairs_train)
        y_train = np.array(train_labels_train)
        X_val = extract_features_for_gbm(train_pairs_val)
        y_val = np.array(train_labels_val)

        print("✅ Extracted features:")
        for split_name, matrix in (("Training", X_train), ("Validation", X_val)):
            print(f"  {split_name}: {matrix.shape}")
        print(f"  Feature dimension: {X_train.shape[1]}")

        # Pointwise regression on the 0/1 relevance labels
        print("\n🌲 Training gradient boosting reranker...")
        gbm_params = dict(
            n_estimators=100,
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            random_state=42,
            verbose=1,
        )
        gbm_model = GradientBoostingRegressor(**gbm_params)
        gbm_model.fit(X_train, y_train)

        # Evaluate on both splits
        train_pred, val_pred = gbm_model.predict(X_train), gbm_model.predict(X_val)
        train_mse = mean_squared_error(y_train, train_pred)
        val_mse = mean_squared_error(y_val, val_pred)

        print("\n📊 GBM Training Results:")
        print(f"  Training MSE: {train_mse:.4f}")
        print(f"  Validation MSE: {val_mse:.4f}")

        # Column order produced by extract_features_for_gbm
        feature_names = [
            'query_length', 'query_words', 'doc_length', 'has_frame', 'space_count',
            'overlap_ratio', 'overlap_count', 'is_news', 'is_speaking', 'is_studio', 'is_vietnamese'
        ]
        importance_df = (
            pd.DataFrame({'feature': feature_names, 'importance': gbm_model.feature_importances_})
            .sort_values('importance', ascending=False)
        )

        print("\n🎯 Feature Importance:")
        display(importance_df)

        # Persist the model together with its feature names and training stats
        model_path = Path("./artifacts/gbm_reranker.pkl")
        model_path.parent.mkdir(parents=True, exist_ok=True)
        artifact = {
            'model': gbm_model,
            'feature_names': feature_names,
            'training_stats': {
                'train_mse': train_mse,
                'val_mse': val_mse,
                'n_samples': len(X_train),
            },
        }
        with model_path.open('wb') as fh:
            pickle.dump(artifact, fh)

        print(f"✅ GBM model saved to: {model_path}")

        # Smoke-test a couple of hand-written pairs
        test_pairs = [
            ("news anchor speaking", "video1_frame_100"),
            ("person talking", "video2_frame_200"),
        ]
        test_scores = gbm_model.predict(extract_features_for_gbm(test_pairs))

        print("\n🧪 GBM Test predictions:")
        for (query, doc), score in zip(test_pairs, test_scores):
            print(f"  '{query}' + '{doc}': {score:.4f}")

        # Visualize feature importance
        plt.figure(figsize=(10, 6))
        sns.barplot(data=importance_df, x='importance', y='feature')
        plt.title('Feature Importance in GBM Reranker')
        plt.xlabel('Importance')
        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"❌ GBM training failed: {e}")
elif not TRAIN_GBM:
    print("⚠️ GBM training skipped")
else:
    print("⚠️ GBM training not possible - no training data")

## Step 5: Model Evaluation and Comparison

In [None]:
# Evaluate and compare reranking models
print("=== Model Evaluation and Comparison ===")


def safe_correlation(labels, scores):
    """Return the Pearson correlation of labels and scores, or 0.0 when undefined.

    np.corrcoef yields NaN (with a RuntimeWarning) when there are fewer than two
    values or either input has zero variance, e.g. a model with constant output.
    """
    labels = np.asarray(labels, dtype=float)
    scores = np.asarray(scores, dtype=float)
    if labels.size < 2 or np.std(labels) == 0 or np.std(scores) == 0:
        return 0.0
    return float(np.corrcoef(labels, scores)[0, 1])


def mean_reciprocal_rank(pairs, labels, scores):
    """Return the mean over queries of 1 / rank of the best-scored positive document.

    Pairs are grouped by query. Within a query, documents are ordered by
    descending score; ties keep their input order. Queries without any positive
    label are ignored. Returns 0.0 when no query qualifies.
    """
    by_query = {}
    for (query, _doc), label, score in zip(pairs, labels, scores):
        by_query.setdefault(query, []).append((score, label))

    reciprocal_ranks = []
    for entries in by_query.values():
        if not any(label for _, label in entries):
            continue
        ranked = sorted(entries, key=lambda entry: entry[0], reverse=True)
        first_hit = next(rank for rank, (_, label) in enumerate(ranked, start=1) if label)
        reciprocal_ranks.append(1.0 / first_hit)

    return float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0


def evaluate_reranker_performance():
    """Score the validation pairs with each saved reranker and compare them.

    Reads the notebook globals train_pairs_val / train_labels_val and the model
    artifacts under ./artifacts. A seeded random-score baseline is always
    included. Models are compared by pointwise MSE (the sort key), per-query
    MRR, and Pearson correlation with the labels.

    Returns:
        DataFrame with one row per model sorted by MSE, or None when there is no
        validation data or no trained model to compare with the baseline.
    """
    if not train_pairs_val:
        print("❌ No validation data available")
        return None

    results = {}

    # Load models if available
    cross_encoder_path = Path("./artifacts/cross_encoder_reranker")
    gbm_model_path = Path("./artifacts/gbm_reranker.pkl")

    print(f"Evaluating on {len(train_pairs_val)} validation pairs...")

    # Baseline: random scores, seeded so the comparison is reproducible
    baseline_scores = np.random.default_rng(42).random(len(train_pairs_val))
    results['Random Baseline'] = {
        'mse': mean_squared_error(train_labels_val, baseline_scores),
        'scores': baseline_scores,
    }

    # Cross-encoder evaluation
    if cross_encoder_path.exists():
        try:
            print("Evaluating cross-encoder...")
            cross_encoder = CrossEncoder(str(cross_encoder_path))
            ce_scores = cross_encoder.predict([[query, doc] for query, doc in train_pairs_val])
            ce_mse = mean_squared_error(train_labels_val, ce_scores)
            results['Cross-Encoder'] = {'mse': ce_mse, 'scores': ce_scores}
            print(f"  Cross-encoder MSE: {ce_mse:.4f}")
        except Exception as e:
            print(f"  Cross-encoder evaluation failed: {e}")

    # GBM evaluation
    if gbm_model_path.exists():
        try:
            print("Evaluating GBM reranker...")
            # pickle is only acceptable because this artifact is written by this notebook
            with open(gbm_model_path, 'rb') as f:
                gbm_model = pickle.load(f)['model']
            gbm_scores = gbm_model.predict(extract_features_for_gbm(train_pairs_val))
            gbm_mse = mean_squared_error(train_labels_val, gbm_scores)
            results['GBM Reranker'] = {'mse': gbm_mse, 'scores': gbm_scores}
            print(f"  GBM MSE: {gbm_mse:.4f}")
        except Exception as e:
            print(f"  GBM evaluation failed: {e}")

    if len(results) <= 1:
        print("⚠️ Not enough models to compare")
        return None

    # Compare results
    print("\n📊 Model Comparison:")
    comparison_df = pd.DataFrame([
        {
            'Model': model_name,
            'MSE': model_results['mse'],
            'MRR': mean_reciprocal_rank(train_pairs_val, train_labels_val, model_results['scores']),
            'Correlation': safe_correlation(train_labels_val, model_results['scores']),
            'Score Mean': np.mean(model_results['scores']),
            'Score Std': np.std(model_results['scores']),
        }
        for model_name, model_results in results.items()
    ]).sort_values('MSE')
    display(comparison_df.round(4))

    # Visualize score distributions
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    for model_name, model_results in results.items():
        plt.hist(model_results['scores'], bins=20, alpha=0.7, label=model_name)
    plt.xlabel('Predicted Scores')
    plt.ylabel('Frequency')
    plt.title('Score Distributions')
    plt.legend()

    # MSE comparison (same order as the sorted table)
    plt.subplot(1, 2, 2)
    plt.bar(comparison_df['Model'], comparison_df['MSE'])
    plt.ylabel('MSE (lower is better)')
    plt.title('Model Performance Comparison')
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

    # Find best model (lowest MSE)
    best = comparison_df.iloc[0]
    print(f"\n🏆 Best performing model: {best['Model']}")
    print(f"   MSE: {best['MSE']:.4f}")
    print(f"   MRR: {best['MRR']:.4f}")
    print(f"   Correlation: {best['Correlation']:.4f}")

    return comparison_df

# Run evaluation (evaluation_results stays None when there is nothing to evaluate,
# so later cells can reference it safely)
evaluation_results = None
if train_pairs_val:
    evaluation_results = evaluate_reranker_performance()
else:
    print("⚠️ No validation data available for evaluation")

## Step 6: Test Rerankers with Live Search

In [None]:
# Test rerankers with actual search queries
print("=== Live Search Testing with Rerankers ===")


def rank_by_scores(results, scores, k):
    """Return the top-k (result, score) pairs, highest score first.

    sorted() is stable, so results with equal scores keep their input order.
    """
    return sorted(zip(results, scores), key=lambda pair: pair[1], reverse=True)[:k]


def count_rank_changes(baseline_ids, reranked_ids):
    """Count positions where reranked_ids differs from baseline_ids.

    Positions beyond the end of baseline_ids count as changed.
    """
    return sum(
        1
        for position, item in enumerate(reranked_ids)
        if position >= len(baseline_ids) or item != baseline_ids[position]
    )


def test_rerankers_on_search(query, k=20):
    """Compare baseline retrieval order with cross-encoder and GBM reranked orders.

    Retrieves 2*k candidates with reranking disabled, rescores them with each
    reranker artifact found under ./artifacts, and displays the top 10 of each
    ordering side by side with the score that produced it. Errors are printed,
    not raised.
    """
    def frame_doc_id(result):
        # Same "<video_id>_frame_<frame_idx>" format the rerankers were trained on
        return f"{result.video_id}_frame_{result.frame_idx}"

    def frame_label(result):
        # Compact label for the comparison table
        return f"{result.video_id}_f{result.frame_idx}"

    try:
        # Initialize basic search pipeline
        ARTIFACT_DIR = Path(config.ARTIFACT_DIR)
        query_pipeline = QueryProcessingPipeline(
            artifact_dir=ARTIFACT_DIR,
            enable_reranking=False  # Start without reranking
        )

        # Get baseline results (over-fetch so rerankers have candidates to promote)
        print(f"🔍 Searching for: '{query}'")
        baseline_results = query_pipeline.search(query, k=k*2)

        if not baseline_results:
            print("❌ No search results found")
            return

        print(f"Got {len(baseline_results)} baseline results")
        doc_pairs = [(query, frame_doc_id(result)) for result in baseline_results]

        # Cross-encoder reranking -> list of (result, cross-encoder score), or None
        ce_reranked = None
        cross_encoder_path = Path("./artifacts/cross_encoder_reranker")
        if cross_encoder_path.exists():
            try:
                print("\n🧠 Testing cross-encoder reranking...")
                cross_encoder = CrossEncoder(str(cross_encoder_path))
                rerank_scores = cross_encoder.predict([list(pair) for pair in doc_pairs])
                ce_reranked = rank_by_scores(baseline_results, rerank_scores, k)
                print(f"✅ Cross-encoder reranking completed")
            except Exception as e:
                print(f"❌ Cross-encoder reranking failed: {e}")
                ce_reranked = None
        else:
            print("⚠️ Cross-encoder model not found")

        # GBM reranking -> list of (result, GBM score), or None
        gbm_reranked = None
        gbm_model_path = Path("./artifacts/gbm_reranker.pkl")
        if gbm_model_path.exists():
            try:
                print("\n🌲 Testing GBM reranking...")
                # pickle is only acceptable because this artifact is written by this notebook
                with open(gbm_model_path, 'rb') as f:
                    gbm_model = pickle.load(f)['model']
                gbm_scores = gbm_model.predict(extract_features_for_gbm(doc_pairs))
                gbm_reranked = rank_by_scores(baseline_results, gbm_scores, k)
                print(f"✅ GBM reranking completed")
            except Exception as e:
                print(f"❌ GBM reranking failed: {e}")
                gbm_reranked = None
        else:
            print("⚠️ GBM model not found")

        # Compare results
        print(f"\n📊 Results Comparison (Top 10):")
        print("-" * 80)

        # Each reranked column shows the reranker's own score, not the retrieval score
        reranked_columns = [
            ('Cross-Encoder', 'CE_Score', ce_reranked),
            ('GBM', 'GBM_Score', gbm_reranked),
        ]
        comparison_data = []
        for i, baseline_result in enumerate(baseline_results[:10]):
            row = {
                'Rank': i + 1,
                'Baseline': frame_label(baseline_result),
                'Baseline_Score': f"{baseline_result.score:.3f}",
            }
            for column, score_column, reranked in reranked_columns:
                if reranked and i < len(reranked):
                    reranked_result, reranked_score = reranked[i]
                    row[column] = frame_label(reranked_result)
                    row[score_column] = f"{reranked_score:.3f}"
            comparison_data.append(row)

        display(pd.DataFrame(comparison_data))

        # How many of the top-10 positions each reranker changed
        baseline_order = [frame_label(r) for r in baseline_results[:10]]
        for model_name, reranked in (("Cross-encoder", ce_reranked), ("GBM", gbm_reranked)):
            if reranked:
                reranked_order = [frame_label(r) for r, _ in reranked[:10]]
                changes = count_rank_changes(baseline_order, reranked_order)
                print(f"{model_name} ranking changes: {changes}/{len(reranked_order)}")

    except Exception as e:
        print(f"❌ Live search testing failed: {e}")

# Interactive testing widget.
# continuous_update=False: test_rerankers_on_search rebuilds the search pipeline
# and reloads both rerankers on every call, so only run it when the query is
# committed (Enter / focus loss) or the slider is released, not on every
# keystroke or slider tick.
test_query_widget = widgets.Text(
    value='news anchor speaking',
    placeholder='Enter test query...',
    description='Test Query:',
    continuous_update=False,
    style={'description_width': 'initial'}
)

test_k_widget = widgets.IntSlider(
    value=10,
    min=5,
    max=50,
    step=5,
    description='Results:',
    continuous_update=False,
    style={'description_width': 'initial'}
)

def run_reranker_test(query, k):
    """Run test_rerankers_on_search for a non-blank query; prompt otherwise."""
    if not query.strip():
        print("Please enter a test query")
        return

    test_rerankers_on_search(query, k)

# Create interactive widget
reranker_test_widget = interactive(
    run_reranker_test,
    query=test_query_widget,
    k=test_k_widget
)

print("🎛️ Interactive Reranker Testing:")
display(reranker_test_widget)

## Step 7: Model Deployment Configuration

In [None]:
# Create deployment configuration for trained models
print("=== Model Deployment Configuration ===")

def create_deployment_config():
    """Write ./artifacts/reranking_config.json describing the saved rerankers.

    A model entry is added for each artifact found on disk
    (./artifacts/cross_encoder_reranker and ./artifacts/gbm_reranker.pkl).
    default_model is the cross-encoder when present, otherwise the GBM,
    otherwise None.

    Returns:
        The configuration dict that was written.
    """
    # training_data / train_pairs_train are notebook-level globals. A
    # `'name' in locals()` check inside a function is always False, so read
    # them from globals() and treat a skipped cell as "no data".
    notebook_globals = globals()

    config_data = {
        'reranking': {
            'enabled': True,
            'models': {},
            'default_model': None,
            'rerank_top_k': 1000,
            'return_top_k': 100
        },
        'training_info': {
            'training_date': time.strftime('%Y-%m-%d %H:%M:%S'),
            'training_examples': len(notebook_globals.get('training_data') or []),
            'training_pairs': len(notebook_globals.get('train_pairs_train') or [])
        }
    }

    # Check available models
    cross_encoder_path = Path("./artifacts/cross_encoder_reranker")
    gbm_model_path = Path("./artifacts/gbm_reranker.pkl")

    if cross_encoder_path.exists():
        config_data['reranking']['models']['cross_encoder'] = {
            'type': 'cross_encoder',
            'path': str(cross_encoder_path),
            'description': 'Neural cross-encoder reranker',
            'batch_size': 32,
            'requires_gpu': False
        }

        if not config_data['reranking']['default_model']:
            config_data['reranking']['default_model'] = 'cross_encoder'

    if gbm_model_path.exists():
        config_data['reranking']['models']['gbm'] = {
            'type': 'gradient_boosting',
            'path': str(gbm_model_path),
            'description': 'Gradient boosting reranker with handcrafted features',
            'requires_gpu': False,
            'fast_inference': True
        }

        if not config_data['reranking']['default_model']:
            config_data['reranking']['default_model'] = 'gbm'

    # Save configuration
    config_file = Path("./artifacts/reranking_config.json")
    config_file.parent.mkdir(parents=True, exist_ok=True)

    with open(config_file, 'w') as f:
        json.dump(config_data, f, indent=2)

    print(f"✅ Deployment configuration saved to: {config_file}")
    print(f"\nConfiguration summary:")
    print(f"  Available models: {len(config_data['reranking']['models'])}")
    print(f"  Default model: {config_data['reranking']['default_model']}")

    for model_name, model_config in config_data['reranking']['models'].items():
        print(f"  - {model_name}: {model_config['description']}")

    return config_data

# Create training summary
def create_training_summary():
    """Write ./artifacts/training_summary.json with data counts, trained models and recommendations.

    Trained models are detected by the presence of their artifacts under
    ./artifacts. The artifacts directory is created if missing.

    Returns:
        The summary dict that was written.
    """
    # Notebook-level globals: `'name' in locals()` inside a function is always
    # False, so read them from globals() and treat a skipped cell as "no data".
    notebook_globals = globals()

    summary = {
        'training_completed': time.strftime('%Y-%m-%d %H:%M:%S'),
        'data_statistics': {
            'training_examples': len(notebook_globals.get('training_data') or []),
            'training_pairs': len(notebook_globals.get('train_pairs_train') or []),
            'validation_pairs': len(notebook_globals.get('train_pairs_val') or [])
        },
        'models_trained': [],
        'recommendations': []
    }

    # Check what was trained
    if Path("./artifacts/cross_encoder_reranker").exists():
        summary['models_trained'].append({
            'name': 'cross_encoder',
            'type': 'Neural reranker',
            'path': './artifacts/cross_encoder_reranker'
        })

    if Path("./artifacts/gbm_reranker.pkl").exists():
        summary['models_trained'].append({
            'name': 'gbm',
            'type': 'Gradient boosting reranker',
            'path': './artifacts/gbm_reranker.pkl'
        })

    # Add recommendations
    if len(summary['models_trained']) == 0:
        summary['recommendations'].append("No models were successfully trained. Check training data availability and system resources.")
    elif len(summary['models_trained']) == 1:
        summary['recommendations'].append(f"Only {summary['models_trained'][0]['name']} was trained. Consider training additional models for comparison.")
    else:
        summary['recommendations'].append("Multiple rerankers trained successfully. Test both in production and choose the best performing one.")

    if summary['data_statistics']['training_examples'] < 50:
        summary['recommendations'].append("Training data is limited. Consider generating more training examples for better model performance.")

    summary['recommendations'].extend([
        "Test rerankers with diverse queries to ensure robust performance.",
        "Monitor reranking latency in production and adjust batch sizes if needed.",
        "Consider A/B testing to measure reranking impact on user satisfaction."
    ])

    # Save summary (create ./artifacts if no earlier cell did)
    summary_file = Path("./artifacts/training_summary.json")
    summary_file.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"\n📋 Training Summary:")
    print(f"  Models trained: {len(summary['models_trained'])}")
    print(f"  Training examples: {summary['data_statistics']['training_examples']}")
    print(f"  Training pairs: {summary['data_statistics']['training_pairs']}")

    print(f"\n💡 Recommendations:")
    for i, rec in enumerate(summary['recommendations'], 1):
        print(f"  {i}. {rec}")

    print(f"\n✅ Summary saved to: {summary_file}")
    return summary

# Write the deployment config and the training summary, then print next steps
deployment_config = create_deployment_config()
training_summary = create_training_summary()

banner = "=" * 50
print("\n" + banner)
print("🎉 TRAINING & RERANKING COMPLETE!")
print(banner)

next_steps = [
    "1. Use 05_end_to_end_pipeline.ipynb to test complete workflow",
    "2. Test reranking performance in production",
    "3. Consider generating more training data for better results",
]
print("\nNext steps:")
for step in next_steps:
    print(step)

## Summary & Model Usage

This notebook has trained and evaluated reranking models to improve search quality:

### Models Trained:
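1. **Cross-Encoder Reranker**: Neural model that directly scores query-document pairs
2. **Gradient Boosting Reranker**: Feature-based model using handcrafted features

Note: the "documents" both models see are frame identifier strings of the form `<video_id>_frame_<frame_idx>`, not frame captions or visual content, so they can only learn from the query text and the ID text.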

### Generated Artifacts:
- `artifacts/cross_encoder_reranker/` - Trained cross-encoder model
- `artifacts/gbm_reranker.pkl` - Pickled dict holding the trained GBM model, its feature names, and training statistics
- `artifacts/reranking_config.json` - Deployment configuration
- `artifacts/training_summary.json` - Training results summary

### Usage in Production:
1. Load the reranking configuration
2. Initialize the QueryProcessingPipeline with `enable_reranking=True`
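3. Use the reranker named by `default_model` in `reranking_config.json` (the cross-encoder if it was trained, otherwise the GBM); confirm that `QueryProcessingPipeline` actually reads this setting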

### Performance Tips:
- Cross-encoder provides better quality but is slower
- GBM is faster and suitable for high-throughput scenarios
- Consider A/B testing to measure real-world impact
- Monitor inference latency and adjust batch sizes accordingly