<a href="https://colab.research.google.com/github/dimitarpg13/transformer_examples/blob/main/notebooks/sentence_transformers/bi_encoder_vs_cross_encoder_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bi-Encoders vs Cross-Encoders: Comprehensive Comparison and Use Cases

## Overview

This notebook provides a detailed comparison between bi-encoders and cross-encoders, demonstrating:
- Architectural differences and computational trade-offs
- Practical implementations using Sentence Transformers
- Contrastive examples showing optimal use cases for each
- Performance benchmarking and hybrid approaches
- Production deployment considerations

### Key Differences

**Bi-Encoders:**
- Encode queries and documents independently
- Enable pre-computation and caching of embeddings
- Fast similarity search via vector operations
- Suitable for large-scale retrieval

**Cross-Encoders:**
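- Process each query-document pair jointly in a single forward pass
- Cannot pre-compute document representations; every new query requires scoring each candidate pair from scratch
- Typically more accurate relevance scoring, because attention operates across query and document tokens together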
- Suitable for re-ranking small candidate sets
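The caching difference is easiest to see by counting model forward passes. The following is a back-of-the-envelope sketch. It assumes one forward pass per text for the bi-encoder and one per query-document pair for the cross-encoder, and it ignores batching and sequence length:

```python
def forward_passes(num_queries: int, num_docs: int) -> dict:
    """Count model forward passes needed to score every query against every document."""
    return {
        # Documents are encoded once (and can be cached); each query is encoded once.
        "bi_encoder": num_docs + num_queries,
        # Every (query, document) pair needs its own joint forward pass.
        "cross_encoder": num_queries * num_docs,
    }

costs = forward_passes(num_queries=1_000, num_docs=100_000)
print(costs)
```

The bi-encoder cost grows additively with queries and documents. The exhaustive cross-encoder cost grows multiplicatively, which is why it is normally restricted to re-ranking a short candidate list.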

In [None]:
# Installation
!pip install -q sentence-transformers faiss-cpu numpy pandas matplotlib seaborn tqdm
!pip install -q datasets torch torchvision scikit-learn

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Dict, Optional, Union
import time
import json
from dataclasses import dataclass, asdict
import warnings
warnings.filterwarnings('ignore')

from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
import faiss
from tqdm.auto import tqdm
from sklearn.metrics import ndcg_score, average_precision_score
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Bi-Encoder Implementation

Bi-encoders encode queries and documents separately, enabling efficient similarity search through vector operations.
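Once document embeddings are cached, scoring a query against the whole corpus reduces to a single matrix-vector product. The sketch below is a minimal NumPy version of the exact (brute-force) inner-product search that a flat FAISS index performs. It uses random unit vectors as stand-ins for real embeddings, and the 384 dimensions mirror `all-MiniLM-L6-v2`:

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row (or a single vector) to unit L2 norm so dot products equal cosine similarity."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def top_k_search(doc_matrix: np.ndarray, query_vec: np.ndarray, k: int):
    """Exact top-k by inner product against a pre-computed (N, d) document matrix."""
    scores = doc_matrix @ query_vec                        # one product scores all N documents
    k = min(k, scores.shape[0])
    candidates = np.argpartition(-scores, k - 1)[:k]       # unordered top-k in linear time
    order = candidates[np.argsort(-scores[candidates])]    # sort only the k winners
    return order, scores[order]

rng = np.random.default_rng(0)
docs = l2_normalize(rng.normal(size=(1000, 384)))   # stand-in for cached document embeddings
query = l2_normalize(rng.normal(size=384))          # stand-in for an encoded query
idx, sims = top_k_search(docs, query, k=5)
print(idx, sims)
```

`argpartition` avoids sorting all N scores when only k are needed. Approximate indexes go further and avoid computing most of the N scores at all.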

In [None]:
@dataclass
class BiEncoderConfig:
    """Configuration for bi-encoder models"""
    model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'
    embedding_dim: int = 384
    max_seq_length: int = 256
    normalize_embeddings: bool = True
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size: int = 32

class BiEncoderRetriever:
    """Production-ready bi-encoder retriever with FAISS indexing"""

    def __init__(self, config: BiEncoderConfig):
        self.config = config
        self.model = SentenceTransformer(config.model_name, device=config.device)
        self.model.max_seq_length = config.max_seq_length
        self.index = None
        self.documents = []
        self.embeddings = None

        logger.info(f"Initialized bi-encoder: {config.model_name}")

    def encode_documents(self, documents: List[str], show_progress: bool = True) -> np.ndarray:
        """Encode documents into embeddings"""
        logger.info(f"Encoding {len(documents)} documents...")

        self.documents = documents
        self.embeddings = self.model.encode(
            documents,
            convert_to_tensor=False,
            batch_size=self.config.batch_size,
            show_progress_bar=show_progress,
            normalize_embeddings=self.config.normalize_embeddings
        )

        # Build FAISS index
        self._build_index()

        return self.embeddings

    def _build_index(self):
        """Build an exact (flat) FAISS index for similarity search"""
        # Read the dimension from the embeddings themselves, so swapping in a model with a
        # different output size (e.g. all-mpnet-base-v2, 768-d) does not break the index
        dim = self.embeddings.shape[1]
        if self.config.normalize_embeddings:
            # Inner product on normalized vectors equals cosine similarity (higher = better)
            self.index = faiss.IndexFlatIP(dim)
        else:
            # Squared L2 distance for non-normalized vectors (lower = better)
            self.index = faiss.IndexFlatL2(dim)

        self.index.add(self.embeddings.astype('float32'))
        logger.info(f"Built FAISS index with {self.index.ntotal} vectors")

    def search(self, queries: Union[str, List[str]], top_k: int = 10) -> List[Dict]:
        """Search for most similar documents"""
        if isinstance(queries, str):
            queries = [queries]

        # Encode queries
        query_embeddings = self.model.encode(
            queries,
            convert_to_tensor=False,
            normalize_embeddings=self.config.normalize_embeddings
        )

        # Search
        scores, indices = self.index.search(query_embeddings.astype('float32'), top_k)

        results = []
        for q_idx, query in enumerate(queries):
            query_results = {
                'query': query,
                'results': [
                    {
                        'document': self.documents[idx],
                        'score': float(score),
                        'rank': rank + 1
                    }
                    for rank, (idx, score) in enumerate(zip(indices[q_idx], scores[q_idx]))
                    if idx != -1  # FAISS returns -1 for empty results
                ]
            }
            results.append(query_results)

        return results[0] if len(queries) == 1 else results
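The two branches of `_build_index` are closely related. For unit-length vectors, the squared distance satisfies ||q - d||² = 2 - 2(q · d), so inner product, cosine similarity, and squared L2 distance all produce the same ranking. The only difference is the direction: higher is better for similarity, lower is better for distance. The following NumPy check uses random vectors as an assumption-free stand-in for embeddings:

```python
import numpy as np

rng = np.random.default_rng(42)
q = rng.normal(size=64)
D = rng.normal(size=(200, 64))

# Normalize the query and every document to unit length
q_unit = q / np.linalg.norm(q)
D_unit = D / np.linalg.norm(D, axis=1, keepdims=True)

inner = D_unit @ q_unit                                    # inner product (higher = more similar)
cosine = (D @ q) / (np.linalg.norm(D, axis=1) * np.linalg.norm(q))
sq_l2 = np.sum((D_unit - q_unit) ** 2, axis=1)             # squared L2 distance (lower = more similar)

# For unit vectors ||q - d||^2 = 2 - 2 * (q . d), so the orderings agree
print(np.allclose(inner, cosine), np.allclose(sq_l2, 2 - 2 * inner))
print(np.array_equal(np.argsort(-inner), np.argsort(sq_l2)))
```

A practical consequence for `BiEncoderRetriever` is that the reported `score` means similarity when `normalize_embeddings=True` but distance otherwise.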

## 2. Cross-Encoder Implementation

Cross-encoders process query-document pairs jointly for more accurate relevance scoring.

In [None]:
@dataclass
class CrossEncoderConfig:
    """Configuration for cross-encoder models"""
    model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
    max_length: int = 512
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size: int = 16
    activation_fct: Optional[torch.nn.Module] = None  # None = model default; e.g. torch.nn.Sigmoid() to map logits to (0, 1)

class CrossEncoderReranker:
    """Production-ready cross-encoder reranker"""

    def __init__(self, config: CrossEncoderConfig):
        self.config = config
        self.model = CrossEncoder(
            config.model_name,
            max_length=config.max_length,
            device=config.device
        )
        logger.info(f"Initialized cross-encoder: {config.model_name}")

    def rerank(self, query: str, documents: List[str],
               initial_scores: Optional[List[float]] = None) -> List[Dict]:
        """Rerank documents for a query"""

        # Create query-document pairs
        pairs = [[query, doc] for doc in documents]

        # Get cross-encoder scores
        ce_scores = self.model.predict(
            pairs,
            batch_size=self.config.batch_size,
            show_progress_bar=False,
            activation_fct=self.config.activation_fct
        )

        # Combine with initial scores if provided (optional hybrid scoring)
        if initial_scores is not None:
            def min_max(scores):
                # Rescale to [0, 1]; a constant list carries no ranking signal, so map it to zeros
                scores = np.asarray(scores, dtype=float)
                span = scores.max() - scores.min()
                return (scores - scores.min()) / span if span > 0 else np.zeros_like(scores)

            # Normalize both score lists so neither scale dominates: cross-encoder outputs may be
            # raw logits, while bi-encoder scores are cosine similarities
            norm_ce = min_max(ce_scores)
            norm_initial = min_max(initial_scores)

            # Weighted combination (alpha can be tuned)
            alpha = 0.7  # Weight for cross-encoder scores
            final_scores = alpha * norm_ce + (1 - alpha) * norm_initial
        else:
            final_scores = ce_scores

        # Sort by scores
        results = [
            {
                'document': doc,
                'cross_encoder_score': float(ce_score),
                'final_score': float(final_score),
                'rank': rank + 1
            }
            for rank, (doc, ce_score, final_score) in enumerate(
                sorted(zip(documents, ce_scores, final_scores),
                       key=lambda x: x[2], reverse=True)
            )
        ]

        return results

    def batch_rerank(self, queries: List[str], documents_list: List[List[str]]) -> List[List[Dict]]:
        """Rerank multiple queries efficiently"""
        results = []
        for query, documents in tqdm(zip(queries, documents_list),
                                     total=len(queries), desc="Reranking"):
            results.append(self.rerank(query, documents))
        return results
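The optional hybrid scoring in `rerank` is a weighted linear fusion of two score lists that live on different scales. The following is a standalone sketch of that fusion. It min-max normalizes each list per candidate set and uses the same `alpha = 0.7` weight. The example scores are made up for illustration:

```python
import numpy as np

def min_max(scores) -> np.ndarray:
    """Rescale scores to [0, 1]; a constant list carries no ranking signal, so map it to zeros."""
    s = np.asarray(scores, dtype=float)
    span = s.max() - s.min()
    return (s - s.min()) / span if span > 0 else np.zeros_like(s)

def fuse(ce_scores, bi_scores, alpha: float = 0.7) -> np.ndarray:
    """Weighted linear fusion of normalized cross-encoder and bi-encoder scores."""
    return alpha * min_max(ce_scores) + (1 - alpha) * min_max(bi_scores)

ce = [8.1, -2.3, 4.0]      # hypothetical raw cross-encoder logits
bi = [0.62, 0.71, 0.55]    # hypothetical cosine similarities from stage 1
fused = fuse(ce, bi)
ranking = np.argsort(-fused)
print(fused, ranking)
```

Because the normalization is computed per query, fused scores rank candidates within one query but should not be compared across queries.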

## 3. Contrastive Examples: When to Use Each

Let's create realistic scenarios that highlight the strengths of each approach.

In [None]:
# Create sample datasets for different use cases

# Use Case 1: Large-scale document retrieval (Bi-encoder strength)
technical_documents = [
    "RAG systems combine retrieval mechanisms with generative models to provide grounded responses.",
    "Vector databases like Pinecone and Weaviate enable efficient similarity search at scale.",
    "LangChain provides abstractions for building LLM-powered applications with retrieval.",
    "Embedding models convert text into dense vector representations for semantic search.",
    "Cross-attention mechanisms allow models to attend to external knowledge during generation.",
    "Hybrid search combines dense and sparse retrieval methods for improved accuracy.",
    "Knowledge graphs provide structured representations of entities and relationships.",
    "Fine-tuning embedding models on domain-specific data improves retrieval quality.",
    "Chunking strategies affect the granularity of retrieved information in RAG systems.",
    "Reranking with cross-encoders improves precision at the cost of computational efficiency.",
    "GraphRAG extends traditional RAG with graph-based knowledge representation.",
    "Multi-hop reasoning requires iterative retrieval and reasoning steps.",
    "Semantic caching reduces latency by storing embeddings of frequent queries.",
    "Dense passage retrieval outperforms BM25 on many question-answering benchmarks.",
    "Contrastive learning improves the quality of learned embeddings for retrieval."
]

# Use Case 2: Nuanced similarity (Cross-encoder strength)
nuanced_pairs = [
    ("How do transformers work?", [
        "Transformers use self-attention to process sequences in parallel.",
        "Electrical transformers convert voltage levels in power systems.",
        "The transformer architecture revolutionized NLP in 2017.",
        "Attention mechanisms allow models to focus on relevant parts of input.",
        "BERT and GPT are both based on the transformer architecture."
    ]),
    ("What are the benefits of exercise?", [
        "Regular physical activity improves cardiovascular health and mental wellbeing.",
        "Exercise can be challenging to maintain without proper motivation.",
        "Many people exercise to lose weight and build muscle.",
        "The benefits of exercise include reduced risk of chronic diseases.",
        "Some exercises are better for flexibility while others build strength."
    ])
]

# Use Case 3: FAQ matching (Bi-encoder strength)
faq_documents = [
    "How do I reset my password? Click on 'Forgot Password' on the login page.",
    "What payment methods do you accept? We accept credit cards, PayPal, and bank transfers.",
    "How long does shipping take? Standard shipping takes 5-7 business days.",
    "Can I return an item? Yes, returns are accepted within 30 days of purchase.",
    "How do I track my order? Use the tracking number sent to your email.",
    "Is international shipping available? Yes, we ship to over 50 countries.",
    "How do I contact customer support? Email support@example.com or call 1-800-EXAMPLE.",
    "What is your refund policy? Full refunds are provided for unused items.",
    "Do you offer discounts? Check our promotions page for current offers.",
    "How do I create an account? Click 'Sign Up' and fill in your details."
]

# Use Case 4: Semantic textual similarity (Cross-encoder strength)
semantic_pairs = [
    ("The weather is beautiful today.", "It's such a lovely day outside."),
    ("The weather is beautiful today.", "The forecast shows rain all week."),
    ("I need to buy groceries.", "Time to go shopping for food."),
    ("I need to buy groceries.", "The store closes at 9 PM."),
    ("The meeting was productive.", "We accomplished a lot in the meeting."),
    ("The meeting was productive.", "The meeting lasted two hours."),
]
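`PerformanceBenchmark.benchmark_accuracy` expects separate lists of matching and non-matching pairs. `semantic_pairs` alternates a paraphrase with a non-paraphrase for each anchor sentence, so a stride-2 slice splits it. The following is a minimal helper, assuming that alternating layout is intentional (the helper name is hypothetical):

```python
from typing import List, Tuple

Pair = Tuple[str, str]

def split_alternating(pairs: List[Pair]) -> Tuple[List[Pair], List[Pair]]:
    """Split a list laid out as [positive, negative, positive, negative, ...]."""
    return pairs[0::2], pairs[1::2]

demo = [
    ("The weather is beautiful today.", "It's such a lovely day outside."),
    ("The weather is beautiful today.", "The forecast shows rain all week."),
]
positives, negatives = split_alternating(demo)
```

With the benchmark object created in the next cell, `benchmark.benchmark_accuracy(*split_alternating(semantic_pairs))` would run the accuracy comparison on the six pairs above.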

## 4. Performance Comparison: Speed vs Accuracy Trade-offs

In [None]:
class PerformanceBenchmark:
    """Compare bi-encoder and cross-encoder performance"""

    def __init__(self):
        self.bi_encoder = BiEncoderRetriever(BiEncoderConfig())
        self.cross_encoder = CrossEncoderReranker(CrossEncoderConfig())
        self.results = {}

    def benchmark_retrieval_speed(self, documents: List[str], queries: List[str], top_k: int = 10):
        """Compare retrieval speed"""
        print("\n" + "="*50)
        print("SPEED BENCHMARK: Retrieval Performance")
        print("="*50)

        # Bi-encoder: Encoding + Search
        print("\n1. Bi-Encoder Performance:")

        # Document encoding (one-time cost)
        start_time = time.time()
        self.bi_encoder.encode_documents(documents, show_progress=False)
        encoding_time = time.time() - start_time
        print(f"   Document encoding time: {encoding_time:.3f}s for {len(documents)} docs")
        print(f"   Average per document: {encoding_time/len(documents)*1000:.2f}ms")

        # Query search (recurring cost)
        start_time = time.time()
        bi_results = [self.bi_encoder.search(q, top_k=top_k) for q in queries]
        search_time = time.time() - start_time
        print(f"   Query search time: {search_time:.3f}s for {len(queries)} queries")
        print(f"   Average per query: {search_time/len(queries)*1000:.2f}ms")

        # Cross-encoder: Direct scoring
        print("\n2. Cross-Encoder Performance:")

        start_time = time.time()
        ce_results = []
        for query in queries:
            # Score all documents (no pre-computation possible)
            ce_results.append(self.cross_encoder.rerank(query, documents)[:top_k])
        ce_time = time.time() - start_time
        print(f"   Total scoring time: {ce_time:.3f}s for {len(queries)} queries")
        print(f"   Average per query: {ce_time/len(queries)*1000:.2f}ms")

        # Comparison
        print("\n3. Performance Comparison:")
        print(f"   Bi-encoder (query encoding + search, excluding one-time document encoding): {search_time/len(queries)*1000:.2f}ms per query")
        print(f"   Cross-encoder: {ce_time/len(queries)*1000:.2f}ms per query")
        print(f"   Cross-encoder / bi-encoder time ratio: {ce_time/search_time:.1f}x")

        # Store results
        self.results['speed'] = {
            'bi_encoder': {'encoding': encoding_time, 'search': search_time},
            'cross_encoder': {'total': ce_time},
            'queries': len(queries),
            'documents': len(documents)
        }

        return bi_results, ce_results

    def benchmark_accuracy(self, test_pairs: List[Tuple[str, str]],
                          negative_pairs: List[Tuple[str, str]]):
        """Compare accuracy on semantic similarity task"""
        print("\n" + "="*50)
        print("ACCURACY BENCHMARK: Semantic Similarity")
        print("="*50)

        all_pairs = test_pairs + negative_pairs
        labels = [1] * len(test_pairs) + [0] * len(negative_pairs)

        # Bi-encoder scores (reuse the model loaded in __init__ instead of loading it again)
        bi_model = self.bi_encoder.model
        embeddings1 = bi_model.encode([p[0] for p in all_pairs])
        embeddings2 = bi_model.encode([p[1] for p in all_pairs])
        bi_scores = util.cos_sim(embeddings1, embeddings2).diagonal().cpu().numpy()

        # Cross-encoder scores. The ms-marco checkpoint is trained for query-passage relevance
        # rather than paraphrase detection, so treat this as a rough comparison
        ce_model = self.cross_encoder.model
        ce_scores = ce_model.predict(all_pairs)

        # Calculate metrics
        from sklearn.metrics import roc_auc_score, accuracy_score

        bi_auc = roc_auc_score(labels, bi_scores)
        ce_auc = roc_auc_score(labels, ce_scores)

        print(f"\n1. ROC-AUC Scores:")
        print(f"   Bi-encoder: {bi_auc:.3f}")
        print(f"   Cross-encoder: {ce_auc:.3f}")
        print(f"   Difference: {(ce_auc - bi_auc) * 100:+.1f} percentage points")

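        # Threshold-based accuracy (0.5 is an illustrative cut-off; the right threshold depends
        # on each model's score scale, e.g. cosine similarity vs raw logits vs probabilities)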
        bi_preds = (bi_scores > 0.5).astype(int)
        ce_preds = (ce_scores > 0.5).astype(int)

        bi_acc = accuracy_score(labels, bi_preds)
        ce_acc = accuracy_score(labels, ce_preds)

        print(f"\n2. Binary Classification Accuracy:")
        print(f"   Bi-encoder: {bi_acc:.3f}")
        print(f"   Cross-encoder: {ce_acc:.3f}")

        self.results['accuracy'] = {
            'bi_encoder': {'auc': bi_auc, 'accuracy': bi_acc},
            'cross_encoder': {'auc': ce_auc, 'accuracy': ce_acc}
        }

        return bi_scores, ce_scores, labels

# Run benchmarks
benchmark = PerformanceBenchmark()

# Speed benchmark
test_queries = [
    "How does RAG improve LLM responses?",
    "What is semantic search?",
    "Explain vector databases"
]

bi_results, ce_results = benchmark.benchmark_retrieval_speed(
    technical_documents,
    test_queries,
    top_k=5
)
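The speed benchmark above times each stage once with `time.time()`, so the first call also absorbs one-off costs such as lazy initialization. The following sketch shows a slightly more careful timing helper. It discards warm-up runs and reports the median of several repeats using `time.perf_counter`. The helper name and the default repeat counts are assumptions, not part of the notebook:

```python
import statistics
import time
from typing import Callable

def time_call(fn: Callable[[], object], warmup: int = 1, repeats: int = 5) -> float:
    """Median wall-clock seconds for fn(), after discarding warm-up runs."""
    for _ in range(warmup):
        fn()                                   # absorb one-off costs (lazy init, caches)
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()            # monotonic, high-resolution clock
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

For example, `time_call(lambda: benchmark.bi_encoder.search("What is semantic search?", top_k=5))`. On a GPU, work may be queued asynchronously, so calling `torch.cuda.synchronize()` inside the timed function before it returns gives more faithful numbers.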

## 5. Hybrid Approach: Two-Stage Retrieval Pipeline

Combine bi-encoders for initial retrieval with cross-encoders for reranking.
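The control flow of the two-stage pipeline is independent of the specific models. The following is a toy sketch in which hand-made 2-D vectors play the bi-encoder embeddings. A hypothetical `word_overlap` function stands in for the cross-encoder; it is a Jaccard word overlap, not a real relevance model. The point is that the expensive pairwise scorer runs only on the `k_retrieve` survivors of stage 1:

```python
import numpy as np
from typing import Callable, List, Tuple

def two_stage_search(query: str, query_vec: np.ndarray, doc_vecs: np.ndarray, docs: List[str],
                     pair_scorer: Callable[[str, str], float],
                     k_retrieve: int = 100, k_final: int = 10) -> List[Tuple[str, float]]:
    """Stage 1: vector top-k over all docs; stage 2: pairwise rescoring of the survivors only."""
    sims = doc_vecs @ query_vec                          # cheap pass over N pre-computed vectors
    candidates = np.argsort(-sims)[:k_retrieve]          # keep only k_retrieve candidates
    rescored = [(docs[i], pair_scorer(query, docs[i])) for i in candidates]
    rescored.sort(key=lambda pair: pair[1], reverse=True)
    return rescored[:k_final]

def word_overlap(query: str, doc: str) -> float:
    """Hypothetical stand-in for a cross-encoder: Jaccard overlap of lowercase word sets."""
    q, d = set(query.lower().split()), set(doc.lower().split())
    return len(q & d) / len(q | d) if q | d else 0.0

docs = ["noise cancelling headphones", "wireless earbuds",
        "studio monitor headphones", "bluetooth speaker"]
doc_vecs = np.array([[0.8, 0.6], [1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])  # toy embeddings
query = "wireless noise cancelling headphones"
query_vec = np.array([1.0, 0.0])

results = two_stage_search(query, query_vec, doc_vecs, docs, word_overlap, k_retrieve=3, k_final=2)
print(results)
```

Stage 1 ranks "wireless earbuds" first. Stage 2 promotes "noise cancelling headphones" above it, and "bluetooth speaker" is never passed to the pairwise scorer.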

In [None]:
class HybridRetriever:
    """Two-stage retrieval: Bi-encoder retrieval + Cross-encoder reranking"""

    def __init__(self, bi_encoder_config: BiEncoderConfig = None,
                 cross_encoder_config: CrossEncoderConfig = None):
        self.bi_encoder = BiEncoderRetriever(bi_encoder_config or BiEncoderConfig())
        self.cross_encoder = CrossEncoderReranker(cross_encoder_config or CrossEncoderConfig())
        self.metrics = {'retrieval_time': [], 'rerank_time': [], 'total_time': []}

    def index_documents(self, documents: List[str]):
        """Index documents for retrieval"""
        self.bi_encoder.encode_documents(documents)
        logger.info(f"Indexed {len(documents)} documents")

    def retrieve(self, query: str, initial_top_k: int = 100, final_top_k: int = 10,
                 rerank_top_n: Optional[int] = None) -> Dict:
        """
        Two-stage retrieval pipeline

        Args:
            query: Search query
            initial_top_k: Number of candidates from bi-encoder
            final_top_k: Final number of results after reranking
            rerank_top_n: Number of top candidates to rerank (default: all)
        """
        total_start = time.time()

        # Stage 1: Fast retrieval with bi-encoder
        retrieval_start = time.time()
        initial_results = self.bi_encoder.search(query, top_k=initial_top_k)
        retrieval_time = time.time() - retrieval_start

        # Extract candidates for reranking
        candidates = initial_results['results']
        if rerank_top_n:
            candidates = candidates[:rerank_top_n]

        candidate_docs = [r['document'] for r in candidates]
        initial_scores = [r['score'] for r in candidates]

        # Stage 2: Accurate reranking with cross-encoder
        rerank_start = time.time()
        reranked = self.cross_encoder.rerank(query, candidate_docs, initial_scores)
        rerank_time = time.time() - rerank_start

        # Select top-k after reranking
        final_results = reranked[:final_top_k]

        total_time = time.time() - total_start

        # Store metrics
        self.metrics['retrieval_time'].append(retrieval_time)
        self.metrics['rerank_time'].append(rerank_time)
        self.metrics['total_time'].append(total_time)

        return {
            'query': query,
            'results': final_results,
            'metrics': {
                'retrieval_time': retrieval_time,
                'rerank_time': rerank_time,
                'total_time': total_time,
                'initial_candidates': len(candidates),
                'final_results': len(final_results)
            }
        }

    def evaluate_pipeline(self, queries: List[str], relevant_docs: List[List[int]],
                         initial_top_k: int = 100, final_top_k: int = 10):
        """Evaluate the hybrid pipeline with metrics"""
        results = []

        for query, relevant in tqdm(zip(queries, relevant_docs), total=len(queries)):
            result = self.retrieve(query, initial_top_k, final_top_k)

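            # Map results back to corpus positions. list.index returns the first match, so exact
            # duplicate documents collapse to one id; de-duplicate to avoid double-counting hits
            retrieved_indices = list(dict.fromkeys(
                self.bi_encoder.documents.index(r['document'])
                for r in result['results']
                if r['document'] in self.bi_encoder.documents
            ))

            # Precision@k (guard against an empty result list)
            hits = sum(1 for idx in retrieved_indices[:final_top_k] if idx in relevant)
            precision = hits / min(final_top_k, len(retrieved_indices)) if retrieved_indices else 0.0

            # Recall@k
            recall = hits / len(relevant) if relevant else 0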

            result['evaluation'] = {
                'precision': precision,
                'recall': recall,
                'f1': 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            }
            results.append(result)

        # Aggregate metrics
        avg_metrics = {
            'avg_precision': np.mean([r['evaluation']['precision'] for r in results]),
            'avg_recall': np.mean([r['evaluation']['recall'] for r in results]),
            'avg_f1': np.mean([r['evaluation']['f1'] for r in results]),
            'avg_retrieval_time': np.mean(self.metrics['retrieval_time']),
            'avg_rerank_time': np.mean(self.metrics['rerank_time']),
            'avg_total_time': np.mean(self.metrics['total_time'])
        }

        return results, avg_metrics
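`evaluate_pipeline` divides hits by `min(final_top_k, len(retrieved))`. The more common textbook definition of Precision@k divides by k itself. The following is a standalone sketch of that standard variant over integer document ids. It counts repeated ids once so that duplicated documents cannot inflate hits:

```python
from typing import Iterable, Sequence, Tuple

def precision_recall_at_k(retrieved: Sequence[int], relevant: Iterable[int],
                          k: int) -> Tuple[float, float]:
    """Precision@k (hits / k) and Recall@k (hits / |relevant|); repeated ids count once."""
    relevant_set = set(relevant)
    top_k = list(dict.fromkeys(retrieved))[:k]    # order-preserving de-duplication, then cut at k
    hits = sum(1 for doc_id in top_k if doc_id in relevant_set)
    precision = hits / k if k > 0 else 0.0
    recall = hits / len(relevant_set) if relevant_set else 0.0
    return precision, recall

print(precision_recall_at_k([3, 7, 3, 1, 9], {1, 3, 4}, k=3))
```

The two definitions differ only when fewer than k results come back. In that case, dividing by k penalizes the short list, whereas the notebook's version does not.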

## 6. Practical Use Case Demonstrations

In [None]:
def demonstrate_use_cases():
    """Demonstrate when to use each approach"""

    print("\n" + "="*70)
    print("USE CASE DEMONSTRATIONS")
    print("="*70)

    # Use Case 1: Large-scale FAQ Search (Bi-encoder optimal)
    print("\n" + "─"*70)
    print("USE CASE 1: FAQ Search System (10,000+ FAQs)")
    print("OPTIMAL: Bi-encoder (pre-computed embeddings + fast search)")
    print("─"*70)

    bi_config = BiEncoderConfig(model_name='sentence-transformers/all-mpnet-base-v2', embedding_dim=768)
    faq_retriever = BiEncoderRetriever(bi_config)

    # Simulate large FAQ database
    large_faq_db = faq_documents * 100  # 1000 FAQs
    print(f"\nIndexing {len(large_faq_db)} FAQs...")

    start = time.time()
    faq_retriever.encode_documents(large_faq_db, show_progress=False)
    print(f"Indexing completed in {time.time() - start:.2f}s")

    # Test queries
    test_queries = [
        "How can I get my money back?",
        "I forgot my login credentials",
        "Where is my package?"
    ]

    print("\nQuery results:")
    for query in test_queries:
        start = time.time()
        results = faq_retriever.search(query, top_k=3)
        query_time = (time.time() - start) * 1000

        print(f"\nQuery: '{query}' (Time: {query_time:.1f}ms)")
        for r in results['results'][:2]:
            print(f"  - Score: {r['score']:.3f} | {r['document'][:60]}...")

    # Use Case 2: Legal Document Ranking (Cross-encoder optimal)
    print("\n" + "─"*70)
    print("USE CASE 2: Legal Document Relevance (High Precision Required)")
    print("OPTIMAL: Cross-encoder (nuanced understanding needed)")
    print("─"*70)

    legal_documents = [
        "The defendant breached the contract by failing to deliver goods on the agreed date.",
        "Contract law requires mutual consideration for a valid agreement.",
        "The breach resulted in significant financial damages to the plaintiff.",
        "Force majeure clauses may excuse performance under certain circumstances.",
        "The court found that the contract was unconscionable and therefore unenforceable."
    ]

    legal_query = "Was there a material breach that caused monetary harm?"

    ce_config = CrossEncoderConfig(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2')
    legal_reranker = CrossEncoderReranker(ce_config)

    print(f"\nQuery: '{legal_query}'")
    print("\nCross-encoder relevance scores:")

    results = legal_reranker.rerank(legal_query, legal_documents)
    for r in results:
        print(f"  Score: {r['cross_encoder_score']:.3f} | {r['document'][:70]}...")

    # Use Case 3: E-commerce Search (Hybrid approach optimal)
    print("\n" + "─"*70)
    print("USE CASE 3: E-commerce Product Search (Speed + Relevance)")
    print("OPTIMAL: Hybrid (Bi-encoder retrieval + Cross-encoder reranking)")
    print("─"*70)

    product_catalog = [
        "Sony WH-1000XM4 Wireless Noise Canceling Headphones - Black",
        "Bose QuietComfort 35 II Wireless Bluetooth Headphones",
        "Apple AirPods Pro with Active Noise Cancellation",
        "Beats Studio3 Wireless Over-Ear Headphones - Matte Black",
        "Sennheiser HD 650 Open Back Professional Headphones",
        "Audio-Technica ATH-M50x Professional Studio Monitor Headphones",
        "JBL Tune 750BTNC Wireless Over-Ear Headphones with Noise Cancellation",
        "Anker Soundcore Life Q20 Hybrid Active Noise Cancelling Headphones",
        "Skullcandy Crusher Wireless Over-Ear Headphones with Bass",
        "Plantronics BackBeat Pro 2 Wireless Noise Cancelling Headphones"
    ]  # Unique listings: duplicating the catalog would fill the top-20 candidates with copies of one product

    hybrid = HybridRetriever()
    hybrid.index_documents(product_catalog)

    search_query = "wireless headphones with best noise cancellation under $300"

    print(f"\nSearch query: '{search_query}'")
    result = hybrid.retrieve(search_query, initial_top_k=20, final_top_k=5)

    print(f"\nPipeline metrics:")
    print(f"  Stage 1 (Bi-encoder): {result['metrics']['retrieval_time']*1000:.1f}ms")
    print(f"  Stage 2 (Cross-encoder reranking): {result['metrics']['rerank_time']*1000:.1f}ms")
    print(f"  Total time: {result['metrics']['total_time']*1000:.1f}ms")

    print(f"\nTop results:")
    for i, r in enumerate(result['results'][:3], 1):
        print(f"  {i}. Score: {r['final_score']:.3f} | {r['document'][:60]}...")

# Run demonstrations
demonstrate_use_cases()
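The FAQ demo indexes `faq_documents * 100`, so each FAQ appears 100 times, and its top-k list can consist of identical copies of one answer. When a real corpus contains exact duplicates, an order-preserving de-duplication step before reranking keeps the candidate budget for distinct texts. The following is a minimal sketch that operates on result dicts shaped like those returned by `BiEncoderRetriever.search` (the helper name is hypothetical):

```python
from typing import Dict, List

def dedupe_candidates(results: List[Dict], key: str = "document") -> List[Dict]:
    """Keep the first (highest-ranked) occurrence of each document text."""
    seen = set()
    unique = []
    for r in results:
        if r[key] not in seen:
            seen.add(r[key])
            unique.append(r)
    return unique
```

De-duplication shrinks the list, so stage 1 would need to over-fetch (a larger `initial_top_k`) to still hand the reranker the desired number of candidates.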

## 7. Decision Framework: Choosing the Right Approach

In [None]:
def create_decision_framework():
    """Create a decision framework for choosing between approaches"""

    framework = pd.DataFrame({
        'Scenario': [
            'Large-scale search (>10K docs)',
            'Real-time search requirements',
            'High precision needed',
            'Complex semantic understanding',
            'FAQ/Knowledge base search',
            'Document similarity at scale',
            'Passage ranking for QA',
            'Legal/Medical document ranking',
            'E-commerce search',
            'Semantic textual similarity',
            'Information retrieval first-stage',
            'Re-ranking top candidates',
            'Clustering/Classification',
            'Zero-shot classification'
        ],
        'Recommended': [
            'Bi-encoder',
            'Bi-encoder',
            'Cross-encoder',
            'Cross-encoder',
            'Bi-encoder',
            'Bi-encoder',
            'Cross-encoder',
            'Cross-encoder',
            'Hybrid',
            'Cross-encoder',
            'Bi-encoder',
            'Cross-encoder',
            'Bi-encoder',
            'Cross-encoder'
        ],
        'Reasoning': [
            'Pre-computed embeddings enable sub-second search',
            'Fast vector similarity search meets latency requirements',
            'Joint encoding captures nuanced relationships',
            'Cross-attention enables deep semantic understanding',
            'Efficient similarity search with cached embeddings',
            'Scalable with approximate nearest neighbor search',
            'Accurate relevance scoring for answer extraction',
            'Precision critical; computational cost acceptable',
            'Balance speed (bi) with relevance (cross)',
            'Direct comparison yields highest accuracy',
            'Efficient candidate generation from large corpus',
            'Improve precision on manageable candidate set',
            'Dense embeddings enable efficient clustering',
            'Accurate classification without training data'
        ]
    })

    print("\n" + "="*100)
    print("DECISION FRAMEWORK: When to Use Each Approach")
    print("="*100)
    print(framework.to_string(index=False))

    # Visualize trade-offs
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Speed vs Accuracy trade-off
    approaches = ['Bi-encoder\n(Retrieval)', 'Cross-encoder\n(Reranking)', 'Hybrid\n(Two-stage)']
    # Illustrative relative values (not measured), chosen to show the qualitative trade-off
    speed = [95, 20, 75]  # Relative speed (higher is faster)
    accuracy = [70, 95, 90]  # Relative accuracy

    ax1.scatter(speed, accuracy, s=500, alpha=0.6, c=['blue', 'red', 'green'])
    for i, txt in enumerate(approaches):
        ax1.annotate(txt, (speed[i], accuracy[i]), ha='center', va='center')

    ax1.set_xlabel('Speed (relative)', fontsize=12)
    ax1.set_ylabel('Accuracy (relative)', fontsize=12)
    ax1.set_title('Speed vs Accuracy Trade-off', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, 100)
    ax1.set_ylim(60, 100)

    # Scalability comparison
    # Illustrative (not measured) query times in seconds. The bi-encoder and hybrid curves assume
    # ANN search; exhaustive cross-encoder scoring grows linearly with the number of documents
    doc_counts = [100, 1000, 10000, 100000, 1000000]
    bi_times = [0.01, 0.02, 0.05, 0.1, 0.2]
    ce_times = [0.1, 1, 10, 100, 1000]
    hybrid_times = [0.05, 0.1, 0.15, 0.2, 0.3]

    ax2.loglog(doc_counts, bi_times, 'b-o', label='Bi-encoder', linewidth=2)
    ax2.loglog(doc_counts, ce_times, 'r-s', label='Cross-encoder', linewidth=2)
    ax2.loglog(doc_counts, hybrid_times, 'g-^', label='Hybrid', linewidth=2)

    ax2.set_xlabel('Number of Documents', fontsize=12)
    ax2.set_ylabel('Query Time (seconds)', fontsize=12)
    ax2.set_title('Scalability Comparison', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3, which='both')
    # Add practical latency thresholds (drawn before the legend so their labels appear in it)
    ax2.axhline(y=0.1, color='orange', linestyle='--', alpha=0.5, label='100ms (good UX)')
    ax2.axhline(y=1, color='red', linestyle='--', alpha=0.5, label='1s (acceptable)')

    ax2.legend(loc='upper left')

    plt.tight_layout()
    plt.show()

    return framework

decision_framework = create_decision_framework()
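The scalability plot treats the hybrid budget qualitatively. One way to make it concrete is to work backwards from a latency target. Subtract the stage-1 retrieval time, then ask how many cross-encoder batches fit in what remains. The following sketch assumes reranking cost grows in whole batches. The example numbers are hypothetical, not measurements:

```python
import math

def max_rerank_candidates(budget_ms: float, retrieval_ms: float,
                          ms_per_batch: float, batch_size: int) -> int:
    """Largest candidate count whose reranking fits in the remaining latency budget.

    Assumes reranking n candidates costs ceil(n / batch_size) * ms_per_batch.
    """
    remaining = budget_ms - retrieval_ms
    if remaining <= 0:
        return 0
    full_batches = math.floor(remaining / ms_per_batch)
    return full_batches * batch_size

# Hypothetical: 100 ms target, 15 ms retrieval, 20 ms per cross-encoder batch of 16 pairs
print(max_rerank_candidates(100, 15, 20, 16))
```

In practice, `ms_per_batch` would come from profiling the deployed cross-encoder at the chosen `batch_size`, and the resulting count bounds `initial_top_k` (or `rerank_top_n`) in `HybridRetriever.retrieve`.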

## 8. Production Deployment Considerations
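The bi-encoder deployment checklist in this section lists quantization as a way to reduce memory footprint. The following sketch shows one simple form: symmetric per-vector int8 scalar quantization of embeddings in NumPy, which stores one float32 scale per vector. It is an illustration only. A given vector database may use a different scheme, and FAISS, for instance, provides its own scalar-quantizer index types. The random unit vectors stand in for real embeddings:

```python
import numpy as np

def quantize_int8(emb: np.ndarray):
    """Symmetric per-vector int8 quantization: emb is approximately codes * scale[:, None]."""
    scale = np.abs(emb).max(axis=1) / 127.0
    scale[scale == 0] = 1.0                       # avoid division by zero for all-zero rows
    codes = np.round(emb / scale[:, None]).astype(np.int8)
    return codes, scale.astype(np.float32)

def dequantize_int8(codes: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Approximate reconstruction of the original float32 vectors."""
    return codes.astype(np.float32) * scale[:, None]

rng = np.random.default_rng(0)
emb = rng.normal(size=(1000, 384)).astype(np.float32)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)       # unit-length, like normalized embeddings
codes, scale = quantize_int8(emb)
recon = dequantize_int8(codes, scale)

print(emb.nbytes, codes.nbytes + scale.nbytes)           # float32 vs int8 + per-vector scales
```

Storage drops to roughly a quarter of float32. The per-element reconstruction error is bounded by half a quantization step (`scale / 2`).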

In [None]:
class ProductionGuidelines:
    """Production deployment guidelines and monitoring"""

    @staticmethod
    def print_deployment_checklist():
        """Print deployment considerations for each approach"""

        print("\n" + "="*80)
        print("PRODUCTION DEPLOYMENT GUIDELINES")
        print("="*80)

        print("\n1. BI-ENCODER DEPLOYMENT")
        print("─" * 40)
        print("""
        Infrastructure Requirements:
        - Vector database (Pinecone, Weaviate, Milvus, or FAISS)
        - GPU for encoding (optional but recommended)
        - Caching layer for frequent queries

        Optimization Strategies:
        - Batch encoding for document updates
        - Approximate nearest neighbor (ANN) for large scale
        - Quantization to reduce memory footprint
        - Asynchronous embedding updates

        Monitoring Metrics:
        - Encoding latency (P50, P95, P99)
        - Search latency
        - Index size and memory usage
        - Cache hit rate
        """)

        print("\n2. CROSS-ENCODER DEPLOYMENT")
        print("─" * 40)
        print("""
        Infrastructure Requirements:
        - GPU inference servers (strongly recommended)
        - Model serving framework (TorchServe, Triton)
        - Request batching system

        Optimization Strategies:
        - Dynamic batching for concurrent requests
        - Model quantization (INT8/FP16)
        - Result caching for repeated queries
        - Distributed inference for high load

        Monitoring Metrics:
        - Inference latency per batch size
        - GPU utilization
        - Queue depth and wait times
        - Model load distribution
        """)

        print("\n3. HYBRID PIPELINE DEPLOYMENT")
        print("─" * 40)
        print("""
        Infrastructure Requirements:
        - Vector database for bi-encoder
        - GPU inference for cross-encoder
        - Orchestration layer (e.g., Ray, Celery)
        - Results caching infrastructure

        Optimization Strategies:
        - Adaptive candidate selection based on query complexity
        - Parallel processing of reranking batches
        - Progressive result streaming
        - Caching of intermediate results (query embeddings, candidate lists, scores)

        Monitoring Metrics:
        - End-to-end latency breakdown
        - Candidate set size distribution
        - Reranking impact on relevance
        - Resource utilization per stage
        """)

    @staticmethod
    def generate_monitoring_dashboard():
        """Generate sample monitoring metrics"""

        # Synthetic metrics over a 24-hour day (illustrative only, not real measurements)
        rng = np.random.default_rng(42)  # fixed seed so the plots are reproducible
        hours = np.arange(24)

        # Bi-encoder metrics
        bi_latency = 20 + 5 * np.sin(hours * np.pi / 12) + rng.normal(0, 2, 24)
        bi_qps = 1000 + 300 * np.sin((hours - 6) * np.pi / 12) + rng.normal(0, 50, 24)

        # Cross-encoder metrics
        ce_latency = 200 + 50 * np.sin(hours * np.pi / 12) + rng.normal(0, 20, 24)
        ce_qps = 100 + 30 * np.sin((hours - 6) * np.pi / 12) + rng.normal(0, 10, 24)

        # Create dashboard
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))

        # Latency comparison
        ax1.plot(hours, bi_latency, 'b-', label='Bi-encoder', linewidth=2)
        ax1.plot(hours, ce_latency, 'r-', label='Cross-encoder', linewidth=2)
        ax1.set_xlabel('Hour of Day')
        ax1.set_ylabel('Latency (ms)')
        ax1.set_title('Query Latency Over Time', fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # QPS comparison
        ax2.plot(hours, bi_qps, 'b-', label='Bi-encoder', linewidth=2)
        ax2.plot(hours, ce_qps, 'r-', label='Cross-encoder', linewidth=2)
        ax2.set_xlabel('Hour of Day')
        ax2.set_ylabel('Queries Per Second')
        ax2.set_title('Throughput Over Time', fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

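        # Cost efficiency (hypothetical figures for illustration, not benchmarks)
        doc_counts = [1000, 10000, 100000, 1000000]
        bi_cost = [0.01, 0.05, 0.2, 0.5]  # Dollars per 1000 queries
        ce_cost = [0.1, 1, 10, 100]       # Grows linearly: every document is scored per query
        hybrid_cost = [0.03, 0.15, 0.4, 0.8]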

        ax3.loglog(doc_counts, bi_cost, 'b-o', label='Bi-encoder', linewidth=2)
        ax3.loglog(doc_counts, ce_cost, 'r-s', label='Cross-encoder', linewidth=2)
        ax3.loglog(doc_counts, hybrid_cost, 'g-^', label='Hybrid', linewidth=2)
        ax3.set_xlabel('Corpus Size')
        ax3.set_ylabel('Cost per 1000 Queries ($)')
        ax3.set_title('Cost Efficiency Analysis', fontweight='bold')
        ax3.legend()
        ax3.grid(True, alpha=0.3, which='both')

        # Accuracy vs latency trade-off (assumed values for illustration, not measured)
        configs = [
            ('Bi-encoder\nOnly', 20, 0.75),
            ('Cross-encoder\nOnly', 200, 0.92),
            ('Hybrid\nTop-100', 40, 0.88),
            ('Hybrid\nTop-50', 35, 0.86),
            ('Hybrid\nTop-20', 30, 0.83),
        ]

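        for config, latency, accuracy in configs:
            ax4.scatter(latency, accuracy, s=200, alpha=0.7)
            # Offset each label above its marker so text and point don't overlap
            ax4.annotate(config, (latency, accuracy), textcoords='offset points',
                         xytext=(0, 12), ha='center', fontsize=9)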

        ax4.set_xlabel('Latency (ms)')
        ax4.set_ylabel('Accuracy (nDCG@10)')
        ax4.set_title('Configuration Trade-offs', fontweight='bold')
        ax4.grid(True, alpha=0.3)

        plt.suptitle('Production Monitoring Dashboard (Simulated Data)', fontsize=16, fontweight='bold', y=1.02)
        plt.tight_layout()
        plt.show()

# Generate production guidelines
guidelines = ProductionGuidelines()
guidelines.print_deployment_checklist()
guidelines.generate_monitoring_dashboard()
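
The monitoring checklists printed above track latency as P50/P95/P99 rather than as an average. A minimal sketch of how those percentiles can be computed from raw per-request timings with NumPy follows; `LatencyTracker` and its methods are hypothetical helpers for illustration, not part of Sentence Transformers or any serving framework, and the timed `sum(range(...))` calls merely stand in for real bi-encoder and cross-encoder work.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator, List

import numpy as np


class LatencyTracker:
    """Hypothetical helper: collects per-request latencies (in ms) for each pipeline stage."""

    def __init__(self) -> None:
        self._samples: Dict[str, List[float]] = {}

    def record(self, stage: str, latency_ms: float) -> None:
        # Store one latency observation for the given stage
        self._samples.setdefault(stage, []).append(latency_ms)

    @contextmanager
    def timer(self, stage: str) -> Iterator[None]:
        # Time the enclosed block with a monotonic high-resolution clock
        start = time.perf_counter()
        try:
            yield
        finally:
            self.record(stage, (time.perf_counter() - start) * 1000.0)

    def percentiles(self, stage: str) -> Dict[str, float]:
        # P50/P95/P99 via NumPy's default (linear interpolation) percentile method
        samples = np.asarray(self._samples[stage], dtype=float)
        p50, p95, p99 = np.percentile(samples, [50, 95, 99])
        return {"p50": float(p50), "p95": float(p95), "p99": float(p99)}


# Example: time two stages of a hypothetical hybrid pipeline
tracker = LatencyTracker()
for _ in range(20):
    with tracker.timer("retrieve"):
        sum(range(10_000))   # stand-in for bi-encoder search
    with tracker.timer("rerank"):
        sum(range(100_000))  # stand-in for cross-encoder scoring
for stage in ("retrieve", "rerank"):
    print(stage, tracker.percentiles(stage))
```

Tail percentiles are the usual choice because a few slow requests (for example, a cold cache or a large rerank batch) can dominate perceived latency while barely moving the mean.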

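Both the cross-encoder and hybrid checklists above recommend caching repeated work. A minimal sketch of an LRU cache keyed on (query, document) pairs follows. `CachedPairScorer` is a hypothetical name, and `score_fn` is assumed to be any callable that maps a list of pairs to a sequence of scores, such as the `predict` method of a Sentence Transformers `CrossEncoder`; here a toy length-based scorer stands in for the model. All cache misses are scored in one batched call so the model can still batch efficiently.

```python
from collections import OrderedDict
from typing import Callable, Dict, List, Sequence, Tuple

Pair = Tuple[str, str]  # (query, document)


class CachedPairScorer:
    """Hypothetical LRU cache in front of a batched (query, document) scoring function."""

    def __init__(self, score_fn: Callable[[List[Pair]], Sequence[float]], max_size: int = 10_000) -> None:
        self.score_fn = score_fn
        self.max_size = max_size
        self._cache: "OrderedDict[Pair, float]" = OrderedDict()
        self.hits = 0
        self.misses = 0

    def score(self, pairs: List[Pair]) -> List[float]:
        results: Dict[Pair, float] = {}
        missing: List[Pair] = []
        for pair in dict.fromkeys(pairs):  # unique pairs, original order preserved
            if pair in self._cache:
                self._cache.move_to_end(pair)  # mark as most recently used
                results[pair] = self._cache[pair]
                self.hits += 1
            else:
                missing.append(pair)
        if missing:
            self.misses += len(missing)
            # One batched call for all misses
            for pair, value in zip(missing, self.score_fn(missing)):
                results[pair] = float(value)
                self._cache[pair] = float(value)
                if len(self._cache) > self.max_size:
                    self._cache.popitem(last=False)  # evict the least recently used pair
        return [results[pair] for pair in pairs]

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0


# Toy scorer standing in for a cross-encoder: score = combined text length
def toy_score(batch: List[Pair]) -> List[float]:
    return [float(len(q) + len(d)) for q, d in batch]


scorer = CachedPairScorer(toy_score, max_size=2)
print(scorer.score([("a", "bb"), ("a", "ccc")]))   # two misses, scored in one call
print(scorer.score([("a", "bb"), ("a", "dddd")]))  # one hit, one miss
print(f"hit rate: {scorer.hit_rate:.2f}")
```

With a real model this could wrap a loaded instance, e.g. `CachedPairScorer(cross_encoder.predict)` where `cross_encoder` is a hypothetical `CrossEncoder` variable. The cache only pays off when identical pairs recur, so `hit_rate` is worth monitoring next to latency.
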
## 9. Advanced Techniques and Best Practices

In [None]:
# Advanced configuration examples
print("\n" + "="*80)
print("ADVANCED CONFIGURATION EXAMPLES")
print("="*80)

print("""
1. DOMAIN-SPECIFIC MODEL SELECTION
───────────────────────────────────

Bi-Encoders:
• General: 'sentence-transformers/all-mpnet-base-v2' (Best quality)
• Fast: 'sentence-transformers/all-MiniLM-L6-v2' (about 5x faster, still good quality)
• Multilingual: 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
• Scientific: 'allenai/specter2' (scientific papers; an adapter over 'allenai/specter2_base', see model card for loading)
• Code: 'microsoft/codebert-base' (pre-trained only; fine-tune before using it for code search)

Cross-Encoders:
• General: 'cross-encoder/ms-marco-MiniLM-L-12-v2' (Best balance)
• High Accuracy: 'cross-encoder/ms-marco-electra-base' (Slower but accurate)
• Fast: 'cross-encoder/ms-marco-TinyBERT-L-2-v2' (much faster than MiniLM-L-12, lower accuracy)
• Multilingual: 'cross-encoder/mmarco-mMiniLMv2-L12-H384-v1'

2. OPTIMIZATION TECHNIQUES
──────────────────────────

Bi-Encoder Optimizations:
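• Use product quantization (PQ) for large indexes (e.g. a 768-dim float32 vector takes 3,072 bytes; a 96-byte PQ code is 32x smaller)
• Use IVF (inverted file) indexing, typically combined with PQ (IVF-PQ) at billion scale, so each query scans only a few clusters
• Apply PCA to reduce embedding dimensionality for faster search (re-check recall afterwards)
• Cache embeddings of frequent queries in Redis/Memcached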

Cross-Encoder Optimizations:
• Quantize models to INT8 (speedup and accuracy loss vary by model and hardware; benchmark before deploying)
• Use ONNX Runtime for optimized inference
• Batch requests dynamically, padding each batch only to its longest sequence
• Deploy multiple model replicas with load balancing

3. HYBRID PIPELINE TUNING
─────────────────────────

Adaptive Strategies:
• Adjust initial_top_k based on query complexity (10-200)
• Use query classification to skip reranking for simple queries
• Implement cascade reranking with multiple cross-encoders
• Cache reranking results for popular query patterns

4. EVALUATION METRICS
─────────────────────

Retrieval Metrics:
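• MRR (Mean Reciprocal Rank): Mean over queries of 1/rank of the first relevant result
• nDCG@k: Discounted gain of graded relevance in the top k, normalized by the ideal ranking
• MAP (Mean Average Precision): Mean over queries of average precision (precision at each relevant hit)
• Recall@k: Fraction of all relevant documents that appear in the top k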

Production Metrics:
• P50/P95/P99 latencies
• Queries per second (QPS)
• GPU/CPU utilization
• Memory consumption
• Cache hit rates
""")

# Summary recommendations
print("\n" + "="*80)
print("SUMMARY: KEY RECOMMENDATIONS")
print("="*80)

recommendations = pd.DataFrame({
    'Use Case': [
        'Semantic Search (1M+ docs)',
        'Question Answering',
        'Duplicate Detection',
        'Document Ranking',
        'Real-time Search',
        'Clustering',
        'Zero-shot Classification'
    ],
    'Approach': [
        'Bi-encoder + FAISS',
        'Hybrid (Bi + Cross)',
        'Bi-encoder',
        'Cross-encoder',
        'Bi-encoder',
        'Bi-encoder',
        'Cross-encoder'
    ],
    'Key Consideration': [
        'Pre-compute embeddings, use ANN index',
        'Balance retrieval recall with ranking precision',
        'Threshold on cosine similarity',
        'Accuracy over speed; score a small candidate set',
        'Latency critical, cache aggressively',
        'Generate embeddings once, cluster offline',
        'Score text against label descriptions (e.g., NLI model), no task training'
    ]
})

print("\n" + recommendations.to_string(index=False))

print("\n" + "="*80)
print("Notebook execution completed successfully!")
print("="*80)
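
The one-line metric definitions in the "4. EVALUATION METRICS" printout are easy to misread, so a minimal pure-Python sketch of the per-query quantities follows. It assumes each query comes with a ranked list of document IDs plus a set of relevant IDs (or, for nDCG, a dict of graded gains); the function names are illustrative, not from a library. MRR and MAP are simply the means of `reciprocal_rank` and `average_precision` over all queries, and `ndcg_at_k` uses the linear-gain form, which is also what scikit-learn's `ndcg_score` (imported at the top of the notebook) computes.

```python
import math
from typing import Mapping, Sequence, Set


def reciprocal_rank(ranked: Sequence[str], relevant: Set[str]) -> float:
    """1 / rank of the first relevant document (1-based); 0.0 if none is retrieved."""
    for rank, doc_id in enumerate(ranked, start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0


def recall_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of all relevant documents that appear in the top k."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & relevant) / len(relevant)


def average_precision(ranked: Sequence[str], relevant: Set[str]) -> float:
    """Precision@rank summed at each relevant hit, divided by the number of relevant docs."""
    if not relevant:
        return 0.0
    hits, precision_sum = 0, 0.0
    for rank, doc_id in enumerate(ranked, start=1):
        if doc_id in relevant:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / len(relevant)  # unretrieved relevant docs contribute 0


def ndcg_at_k(ranked: Sequence[str], gains: Mapping[str, float], k: int) -> float:
    """Linear-gain nDCG@k: DCG of the top k divided by the DCG of the ideal ordering."""
    dcg = sum(gains.get(d, 0.0) / math.log2(r + 1) for r, d in enumerate(ranked[:k], start=1))
    ideal = sorted(gains.values(), reverse=True)[:k]
    idcg = sum(g / math.log2(r + 1) for r, g in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0


# Toy example: two queries, each with a ranked list and its set of relevant IDs
runs = [
    (["d3", "d1", "d7", "d2"], {"d1", "d2"}),
    (["d5", "d6"], {"d5"}),
]
mrr = sum(reciprocal_rank(r, rel) for r, rel in runs) / len(runs)
map_score = sum(average_precision(r, rel) for r, rel in runs) / len(runs)
print(f"MRR={mrr:.3f}  MAP={map_score:.3f}")
print(f"Recall@2 (q1)={recall_at_k(runs[0][0], runs[0][1], 2):.3f}")
print(f"nDCG@3 (q1)={ndcg_at_k(runs[0][0], {'d1': 3.0, 'd2': 1.0}, 3):.3f}")
```

Note how Recall@k ignores order inside the top k while nDCG@k rewards placing the highest-gain documents first, which is why rerankers are usually judged on nDCG@10 and first-stage retrievers on Recall@k.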