<a href="https://colab.research.google.com/github/Rongxuan-Zhou/CS6120_project/blob/index_construction-%26-hybrid_retrieval/notebooks/index-construction%20%26%20hybrid-retrieval-2.0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 1. Environment Setup
!pip install -q faiss-cpu sentence-transformers nltk rank-bm25 hnswlib scikit-learn datasets pytrec_eval tqdm
from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import time
import json
import pickle
import numpy as np
import torch
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
import rank_bm25
from collections import Counter
import random
from typing import Dict, List, Tuple, Any, Optional, Union

# Set project path
PROJECT_PATH = "/content/drive/MyDrive/CS6120_project"
os.chdir(PROJECT_PATH)

# Create necessary directories
os.makedirs("models/indexes", exist_ok=True)

# Check GPU availability
print(f"Available GPU: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU for processing")

# Download NLTK resources
import nltk
for resource in ['punkt', 'stopwords']:
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource)

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m71.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m30.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [2]:
# 2. Helper Classes & Functions

class Timer:
    """Simple timer for benchmarking"""
    def __init__(self, name="Operation"):
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.interval = self.end - self.start
        print(f"{self.name} completed in {self.interval:.2f} seconds")


def preprocess_text(text: str) -> List[str]:
    """Preprocess text for BM25 indexing"""
    if not isinstance(text, str):
        return []

    # Tokenize
    tokens = word_tokenize(text.lower())

    # Remove stopwords
    stop_words = set(stopwords.words('english'))
    tokens = [token for token in tokens if token not in stop_words and token.isalnum()]

    return tokens


def normalize_scores(scores: np.ndarray) -> np.ndarray:
    """Normalize scores to [0,1] range with safe handling of edge cases"""
    if len(scores) == 0:
        return scores

    min_val = np.min(scores)
    max_val = np.max(scores)

    if max_val == min_val:
        return np.ones_like(scores)

    return (scores - min_val) / (max_val - min_val + 1e-8)


def load_models(model_paths: Dict[str, str]) -> Tuple[Dict[str, SentenceTransformer], Dict[str, int]]:
    """Load all models and return them with their embedding dimensions"""
    models = {}
    dimensions = {}

    for model_name, model_path in model_paths.items():
        print(f"Loading model: {model_name}")
        try:
            models[model_name] = SentenceTransformer(model_path)
            models[model_name].to(device)
            dimensions[model_name] = models[model_name].get_sentence_embedding_dimension()

            print(f"  - Model path: {model_path}")
            print(f"  - Model architecture: {dimensions[model_name]}d embedding dimension")
            print(f"  - Model details: {models[model_name]}")
            print("")
        except Exception as e:
            print(f"Error loading model {model_name} from {model_path}: {str(e)}")
            continue

    if not models:
        raise ValueError("No models could be loaded. Please check model paths.")

    return models, dimensions

In [3]:
# 3. MS MARCO Data Loading
from datasets import load_dataset

def load_msmarco_data(max_samples: int = 5000, seed: int = 42) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, Dict[str, int]]]:
    """Load MS MARCO dataset with proper format for MRR evaluation

    Args:
        max_samples: Maximum samples to load
        seed: Random seed

    Returns:
        corpus: Dictionary {doc_id: document_text}
        queries: Dictionary {query_id: query_text}
        qrels: Dictionary {query_id: {doc_id: relevance}}
    """
    print("Loading MS MARCO dataset...")
    try:
        dataset = load_dataset("ms_marco", "v1.1")
        dev_data = dataset["validation"].shuffle(seed=seed).select(range(max_samples))
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        raise

    queries = {}
    corpus = {}
    qrels = {}

    # Process each sample
    for example in dev_data:
        # Get query_id and query text
        qid = str(example["query_id"])
        query_text = example["query"]
        queries[qid] = query_text

        # Get passages information
        passages_info = example["passages"]
        passage_texts = passages_info.get("passage_text", [])
        is_selecteds = passages_info.get("is_selected", [])

        # Process each passage
        for i, (text, is_sel) in enumerate(zip(passage_texts, is_selecteds)):
            # Generate unique document ID as "qid_i"
            doc_id = f"{qid}_{i}"
            corpus[doc_id] = text

            # If passage is relevant, add to qrels
            if is_sel == 1:
                if qid not in qrels:
                    qrels[qid] = {}
                qrels[qid][doc_id] = 1

    # Check positive counts
    check_positive_counts(queries, qrels)

    print(f"Loaded {len(corpus)} documents, {len(queries)} queries, {len(qrels)} qrels.")
    return corpus, queries, qrels


def check_positive_counts(queries: Dict[str, str], qrels: Dict[str, Dict[str, int]]):
    """Analyze distribution of relevant documents per query"""
    # Count positive examples per query
    positive_counts = []
    for qid in queries:
        if qid in qrels:
            positive_counts.append(len(qrels[qid]))
        else:
            positive_counts.append(0)

    # Count distribution
    counter = Counter(positive_counts)
    print("Positive examples distribution (count: queries):")
    for num_pos, num_queries in sorted(counter.items()):
        print(f"{num_pos} positive examples: {num_queries} queries")

    # Count queries without positives
    total_queries = len(queries)
    no_positive = counter.get(0, 0)
    print(f"\nTotal queries: {total_queries}")
    print(f"Queries without positives: {no_positive} ({no_positive/total_queries*100:.2f}%)")

In [4]:
# 4. Load Dataset and Models
# Load MS MARCO dataset
try:
    corpus, queries, qrels = load_msmarco_data(max_samples=5000, seed=42)

    # Extract important data structures
    corpus_texts = list(corpus.values())
    doc_ids = list(corpus.keys())

    # Define model paths
    model_paths = {
        "msmarco_stsb": os.path.join(PROJECT_PATH, "model/msmarco_stsb_finetuned_model"),
        "stsb": os.path.join(PROJECT_PATH, "model/stsb_finetuned_model")
    }

    # Load models
    models, dimensions = load_models(model_paths)

    # Models for primary retrieval and fallback
    primary_model = "msmarco_stsb"
    fallback_model = "stsb"

    print(f"Primary model: {primary_model}")
    print(f"Fallback model: {fallback_model}")

except Exception as e:
    print(f"Error during dataset or model loading: {str(e)}")
    print("Please check your environment setup and model paths.")

Loading MS MARCO dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/9.48k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/175M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/10047 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/82326 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9650 [00:00<?, ? examples/s]

Positive examples distribution (count: queries):
0 positive examples: 160 queries
1 positive examples: 4371 queries
2 positive examples: 425 queries
3 positive examples: 38 queries
4 positive examples: 6 queries

Total queries: 5000
Queries without positives: 160 (3.20%)
Loaded 41070 documents, 5000 queries, 4840 qrels.
Loading model: msmarco_stsb
  - Model path: /content/drive/MyDrive/CS6120_project/model/msmarco_stsb_finetuned_model
  - Model architecture: 768d embedding dimension
  - Model details: SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

Loading model: stsb
  - Model path: /content/drive/MyDrive/CS6120_proje

In [7]:
import nltk
nltk.download('punkt_tab')

# 5. Build BM25 and FAISS indexes
# Clear GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Create BM25 index with improved error handling
def build_bm25_index(corpus_texts, k1=0.9, b=0.6):
    """Build BM25 index with improved error handling"""
    print("Creating BM25 index...")
    try:
        with Timer("BM25 preprocessing"):
            tokenized_corpus = [preprocess_text(text) for text in tqdm(corpus_texts, desc="Preprocessing documents")]

        # Filter out empty documents to prevent BM25 errors
        empty_docs = [i for i, tokens in enumerate(tokenized_corpus) if not tokens]
        if empty_docs:
            print(f"Warning: Found {len(empty_docs)} empty documents after preprocessing")
            # Add at least one token to empty documents to prevent BM25 errors
            for i in empty_docs:
                tokenized_corpus[i] = ["_empty_"]

        with Timer("BM25 index construction"):
            bm25 = rank_bm25.BM25Okapi(tokenized_corpus, k1=k1, b=b)

        # Create BM25 related information
        avg_doc_len = sum(len(doc) for doc in tokenized_corpus) / len(tokenized_corpus)
        bm25_info = {
            "corpus_size": len(tokenized_corpus),
            "avg_doc_len": avg_doc_len,
            "idf_avg": sum(bm25.idf.values()) / len(bm25.idf) if bm25.idf else 0,
            "k1": k1,
            "b": b,
            "empty_docs": len(empty_docs)
        }

        return bm25, bm25_info, tokenized_corpus

    except Exception as e:
        print(f"Error building BM25 index: {str(e)}")
        # Print the error details
        import traceback
        traceback.print_exc() # Print the full traceback

        # Return a simple BM25 index with dummy data in case of error
        dummy_corpus = [["dummy"] for _ in range(len(corpus_texts))]
        dummy_bm25 = rank_bm25.BM25Okapi(dummy_corpus)
        dummy_info = {"error": str(e)}
        return dummy_bm25, dummy_info, dummy_corpus

# Build BM25 index
with Timer("Total BM25 index building"):
    bm25, bm25_info, tokenized_corpus = build_bm25_index(corpus_texts)

# Check if BM25 index creation was successful before printing average document length
if 'avg_doc_len' in bm25_info:
    print("BM25 index created successfully")
    print(f"Average document length: {bm25_info['avg_doc_len']:.2f} tokens")
else:
    print("BM25 index creation failed. Check the error message above.")
    # You might want to exit or handle the error here

# Save BM25 info (only if successful)
if 'avg_doc_len' in bm25_info:
    with open(os.path.join("models/indexes", "bm25_info.json"), 'w') as f:
        json.dump(bm25_info, f)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


Creating BM25 index...


Preprocessing documents: 100%|██████████| 41070/41070 [00:25<00:00, 1605.70it/s]


BM25 preprocessing completed in 25.58 seconds
BM25 index construction completed in 0.59 seconds
Total BM25 index building completed in 26.18 seconds
BM25 index created successfully
Average document length: 40.22 tokens


In [9]:
# Initialize all_indexes and all_configs dictionaries
all_indexes = {}
all_configs = {}

# Build FAISS indexes for each model
for model_name, model in models.items():
    print(f"\nBuilding FAISS indexes for {model_name}...")
    with Timer(f"FAISS index building for {model_name}"):
        # Get embeddings
        embeddings = model.encode(corpus_texts, show_progress_bar=True)
        embeddings = embeddings.astype('float32')  # Convert to float32
        faiss.normalize_L2(embeddings)  # Normalize embeddings

        # Build indexes
        index_config = {
            "dimension": dimensions[model_name],
            "nlist": 100,  # Number of clusters for IVF
            "m": 4,  # Number of bytes per vector for PQ
            "hnsw_m": 16  # Number of neighbors for HNSW
        }
        all_configs[model_name] = index_config  # Store index configuration

        # Create and store different index types
        all_indexes[model_name] = {
            "flat": faiss.IndexFlatL2(dimensions[model_name]),
            "hnsw": faiss.IndexHNSWFlat(dimensions[model_name], index_config["hnsw_m"]),
            "ivfpq": faiss.IndexIVFPQ(
                faiss.IndexFlatL2(dimensions[model_name]),
                dimensions[model_name],
                index_config["nlist"],
                index_config["m"],
                8  # Number of bits per code
            )
        }

        # Train and add data to indexes
        all_indexes[model_name]["flat"].add(embeddings)
        all_indexes[model_name]["hnsw"].add(embeddings)
        all_indexes[model_name]["ivfpq"].train(embeddings)
        all_indexes[model_name]["ivfpq"].add(embeddings)

        print(f"FAISS indexes built and populated for {model_name}")


Building FAISS indexes for msmarco_stsb...


Batches:   0%|          | 0/1284 [00:00<?, ?it/s]

FAISS indexes built and populated for msmarco_stsb
FAISS index building for msmarco_stsb completed in 66.82 seconds

Building FAISS indexes for stsb...


Batches:   0%|          | 0/1284 [00:00<?, ?it/s]

FAISS indexes built and populated for stsb
FAISS index building for stsb completed in 66.64 seconds


In [10]:
# 6. Save indexes and corpus information
print("Saving all model indexes...")
for model_name in models:
    if model_name not in all_indexes:
        print(f"Skipping {model_name} - no indexes built")
        continue

    model_dir = os.path.join(PROJECT_PATH, "models/indexes", model_name)
    os.makedirs(model_dir, exist_ok=True)

    model_indexes = all_indexes[model_name]

    # Save all index types
    print(f"\nSaving {model_name} indexes...")

    try:
        print(f"Saving {model_name} Flat index...")
        faiss.write_index(model_indexes["flat"], os.path.join(model_dir, "flat_index.faiss"))

        print(f"Saving {model_name} HNSW index...")
        faiss.write_index(model_indexes["hnsw"], os.path.join(model_dir, "hnsw_index.faiss"))

        print(f"Saving {model_name} IVF-PQ index...")
        faiss.write_index(model_indexes["ivfpq"], os.path.join(model_dir, "ivfpq_index.faiss"))

        # Save index configuration information
        with open(os.path.join(model_dir, "index_config.json"), 'w') as f:
            json.dump(all_configs[model_name], f)

        print(f"{model_name} indexes successfully saved to: {model_dir}")
    except Exception as e:
        print(f"Error saving indexes for {model_name}: {str(e)}")

# Save the corpus and document IDs
corpus_dir = os.path.join(PROJECT_PATH, "models/indexes")
print("\nSaving corpus data...")

try:
    # Save corpus
    with open(os.path.join(corpus_dir, "corpus.json"), 'w') as f:
        json.dump(corpus, f)

    # Save doc IDs
    with open(os.path.join(corpus_dir, "doc_ids.json"), 'w') as f:
        json.dump(doc_ids, f)

    # Save BM25 model using pickle
    print("Saving BM25 model...")
    with open(os.path.join(corpus_dir, "bm25_model.pkl"), 'wb') as f:
        pickle.dump(bm25, f)

    # Also save tokenized corpus for future use
    with open(os.path.join(corpus_dir, "tokenized_corpus.pkl"), 'wb') as f:
        pickle.dump(tokenized_corpus, f)

except Exception as e:
    print(f"Error saving corpus data: {str(e)}")

print("\nAll indexes and data successfully saved")

Saving all model indexes...

Saving msmarco_stsb indexes...
Saving msmarco_stsb Flat index...
Saving msmarco_stsb HNSW index...
Saving msmarco_stsb IVF-PQ index...
msmarco_stsb indexes successfully saved to: /content/drive/MyDrive/CS6120_project/models/indexes/msmarco_stsb

Saving stsb indexes...
Saving stsb Flat index...
Saving stsb HNSW index...
Saving stsb IVF-PQ index...
stsb indexes successfully saved to: /content/drive/MyDrive/CS6120_project/models/indexes/stsb

Saving corpus data...
Saving BM25 model...

All indexes and data successfully saved


In [11]:
# 7. Enhanced Retrieval Functions

def bm25_retrieve(query: str, bm25_index, doc_ids: List[str], k: int = 10) -> Tuple[List[str], List[float]]:
    """Retrieve documents using BM25 with improved error handling

    Args:
        query: Query string
        bm25_index: BM25 index
        doc_ids: List of document IDs
        k: Number of results to return

    Returns:
        top_doc_ids: List of document IDs
        top_scores: List of corresponding BM25 scores
    """
    try:
        # Process query
        query_tokens = preprocess_text(query)
        if not query_tokens:
            print(f"Warning: Empty query tokens for query: '{query}'")
            query_tokens = ["_empty_"]

        # Get BM25 scores
        bm25_scores = bm25_index.get_scores(query_tokens)

        # Ensure length compatibility
        if len(bm25_scores) != len(doc_ids):
            print(f"Warning: BM25 scores count ({len(bm25_scores)}) does not match doc_ids count ({len(doc_ids)})")
            if len(bm25_scores) < len(doc_ids):
                # Extend with zeros
                bm25_scores = np.concatenate([bm25_scores, np.zeros(len(doc_ids) - len(bm25_scores))])
            else:
                # Truncate
                bm25_scores = bm25_scores[:len(doc_ids)]

        # Get top k results
        if len(bm25_scores) == 0:
            print("Warning: No BM25 scores available")
            return [], []

        # Sort by score in descending order
        top_indices = np.argsort(bm25_scores)[::-1][:k]
        top_scores = bm25_scores[top_indices]

        # Map indices to document IDs
        top_doc_ids = [doc_ids[idx] for idx in top_indices]

        return top_doc_ids, top_scores

    except Exception as e:
        print(f"Error in BM25 retrieval: {str(e)}")
        # Return empty results in case of error
        return [], []


# Query classifier for dynamic weighting
def train_query_classifier():
    """Train a simple query classifier for dynamic strategy selection"""
    # Sample queries for training
    test_queries = [
        "How does social media affect mental health?",
        "Best programming languages to learn",
        "Artificial intelligence applications",
        "Climate change solutions and mitigation strategies",
        "Nutrition advice for athletes performance"
    ]

    # Feature extraction function
    def extract_query_features(query):
        features = []
        # Length features
        features.append(len(query))  # Raw character length
        features.append(len(query.split()))  # Word count

        # Question features
        features.append(1 if "?" in query else 0)  # Is it a question
        features.append(1 if query.lower().startswith("how") else 0)  # How question
        features.append(1 if query.lower().startswith("what") else 0)  # What question
        features.append(1 if query.lower().startswith("why") else 0)  # Why question

        # Structure features
        features.append(1 if "," in query else 0)  # Has comma
        features.append(1 if ":" in query else 0)  # Has colon

        # Topic features (simple keyword checking)
        features.append(1 if any(word in query.lower() for word in ["programming", "code", "software"]) else 0)  # Tech
        features.append(1 if any(word in query.lower() for word in ["health", "medical", "disease"]) else 0)  # Health

        return features

    # Build training features
    X_train = np.array([extract_query_features(q) for q in test_queries])
    # Dummy labels (which model is better)
    # 0 for primary model (semantic-heavy queries)
    # 1 for fallback model (simpler queries that might work better with BM25)
    y_train = np.array([0, 1, 0, 0, 1])  # Simulated labels

    # Train a simple query classifier
    classifier = LogisticRegression(random_state=42)
    classifier.fit(X_train, y_train)
    print("Query classifier trained")

    # Add feature extractor to classifier object for convenience
    classifier.extract_features = extract_query_features

    return classifier

# Train the classifier
query_classifier = train_query_classifier()

Query classifier trained


In [12]:
# 8. Main Enhanced Hybrid Retrieval Function
def hybrid_retrieve_documents(
    query: str,
    query_id: str,
    top_k: int = 100,
    strategy: str = "sbert_bm25",
    alpha: float = 0.7,
    debug: bool = False
) -> Dict[str, Dict[str, float]]:
    """Enhanced hybrid retrieval function with improved error handling

    Args:
        query: Query string
        query_id: Query ID
        top_k: Number of results to return
        strategy: Retrieval strategy
        alpha: Weight for semantic score (1-alpha for BM25)
        debug: Whether to print debug information

    Returns:
        Dictionary {query_id: {doc_id: score}} for MRR evaluation
    """
    # Validate inputs
    if not query:
        print("Error: Empty query")
        return {query_id: {}}  # Return empty results

    if strategy not in ["single", "dynamic", "fallback", "ensemble", "sbert_bm25"]:
        print(f"Error: Unknown strategy '{strategy}', falling back to 'single' strategy")
        strategy = "single"

    # Initialize result format for MRR evaluation
    result = {query_id: {}}

    # Start retrieval process with full error handling
    try:
        # Initialize model list based on strategy
        if strategy in ["single", "fallback"]:
            model_list = [primary_model]
        else:
            model_list = [primary_model, fallback_model]

        # Encode query for each model
        query_embeddings = {}
        for model_name in model_list:
            if model_name not in models:
                print(f"Error: Model {model_name} not loaded")
                continue

            query_emb = models[model_name].encode([query])
            faiss.normalize_L2(query_emb)
            query_embeddings[model_name] = query_emb

        # If no query embeddings could be generated, return empty results
        if not query_embeddings:
            print("Error: No query embeddings could be generated")
            return {query_id: {}}

        # Process based on strategy
        if strategy == "single":
            # Just use primary model
            if primary_model not in all_indexes:
                print(f"Error: No indexes for {primary_model}")
                return {query_id: {}}

            index = all_indexes[primary_model]["hnsw"]
            D, I = index.search(query_embeddings[primary_model], top_k)

            # Build results in required format
            for i in range(min(top_k, len(I[0]))):
                if I[0][i] >= 0 and I[0][i] < len(doc_ids):  # Ensure valid index
                    doc_id = doc_ids[I[0][i]]
                    score = float(D[0][i])
                    result[query_id][doc_id] = score

        elif strategy == "dynamic":
            # Use query classifier to determine best model
            features = query_classifier.extract_features(query)
            features = np.array([features])

            # Predict which model to use
            model_idx = query_classifier.predict(features)[0]
            model_to_use = primary_model if model_idx == 0 else fallback_model

            if debug:
                print(f"Dynamic strategy selected model: {model_to_use}")

            if model_to_use not in all_indexes or model_to_use not in query_embeddings:
                # Fall back to primary model if selected model is not available
                print(f"Warning: Selected model {model_to_use} not available, using {primary_model} instead")
                model_to_use = primary_model

            # Use the selected model
            index = all_indexes[model_to_use]["hnsw"]
            D, I = index.search(query_embeddings[model_to_use], top_k)

            # Build results in required format
            for i in range(min(top_k, len(I[0]))):
                if I[0][i] >= 0 and I[0][i] < len(doc_ids):
                    doc_id = doc_ids[I[0][i]]
                    score = float(D[0][i])
                    result[query_id][doc_id] = score

        elif strategy == "fallback":
            # First try with primary model
            primary_index = all_indexes[primary_model]["hnsw"]
            D_primary, I_primary = primary_index.search(query_embeddings[primary_model], top_k)

            # Check confidence (average similarity score)
            confidence = np.mean(D_primary[0]) if len(D_primary[0]) > 0 else 0
            threshold = 0.3  # Confidence threshold

            if debug:
                print(f"Primary model confidence: {confidence:.4f} (threshold: {threshold})")

            # Decide which results to use
            if confidence > threshold:
                # Use primary model results
                D, I = D_primary, I_primary
                if debug:
                    print("Using primary model results")
            else:
                # Switch to fallback model
                if debug:
                    print(f"Switching to fallback model ({fallback_model})")

                fallback_index = all_indexes[fallback_model]["hnsw"]
                D, I = fallback_index.search(query_embeddings[fallback_model], top_k)

            # Build results in required format
            for i in range(min(top_k, len(I[0]))):
                if I[0][i] >= 0 and I[0][i] < len(doc_ids):
                    doc_id = doc_ids[I[0][i]]
                    score = float(D[0][i])
                    result[query_id][doc_id] = score

        elif strategy == "ensemble":
            # Get results from each model
            all_results = {}

            for model_name in model_list:
                if model_name not in all_indexes or model_name not in query_embeddings:
                    continue

                model_index = all_indexes[model_name]["hnsw"]
                D, I = model_index.search(query_embeddings[model_name], top_k * 2)  # Get more candidates

                # Save score for each document ID
                for j in range(len(I[0])):
                    idx = int(I[0][j])
                    if idx < 0 or idx >= len(doc_ids):  # Skip invalid indices
                        continue

                    doc_id = doc_ids[idx]
                    score = float(D[0][j])

                    if doc_id not in all_results:
                        all_results[doc_id] = {}

                    all_results[doc_id][model_name] = score

            # Compute combined scores using weights
            weights = {primary_model: alpha, fallback_model: 1.0-alpha}
            final_scores = {}

            for doc_id in all_results:
                final_scores[doc_id] = 0
                for model_name, weight in weights.items():
                    if model_name in all_results[doc_id]:
                        final_scores[doc_id] += all_results[doc_id][model_name] * weight

            # Sort and select top_k results
            sorted_results = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]

            # Build results in required format
            for doc_id, score in sorted_results:
                result[query_id][doc_id] = float(score)

        elif strategy == "sbert_bm25":
            # SBERT + BM25 hybrid approach - most thorough option
            if debug:
                print(f"Running sbert_bm25 strategy with alpha={alpha:.2f}")

            # 1. Get SBERT results
            sbert_index = all_indexes[primary_model]["hnsw"]
            D_sbert, I_sbert = sbert_index.search(query_embeddings[primary_model], top_k*2)

            # Get SBERT results in dictionary format
            sbert_results = {}
            for j in range(len(I_sbert[0])):
                idx = int(I_sbert[0][j])
                if idx < 0 or idx >= len(doc_ids):
                    continue

                doc_id = doc_ids[idx]
                score = float(D_sbert[0][j])
                sbert_results[doc_id] = score

            # 2. Get BM25 results
            bm25_doc_ids, bm25_scores = bm25_retrieve(query, bm25, doc_ids, k=top_k*2)

            # Get BM25 results in dictionary format
            bm25_results = {}
            for doc_id, score in zip(bm25_doc_ids, bm25_scores):
                bm25_results[doc_id] = float(score)

            if debug:
                print(f"SBERT found {len(sbert_results)} results")
                print(f"BM25 found {len(bm25_results)} results")

            # 3. Check if we have results
            if not sbert_results and not bm25_results:
                print("Warning: No results found from either SBERT or BM25")
                return {query_id: {}}

            # 4. Normalize scores
            if sbert_results and len(sbert_results) > 0:
                sbert_docs = list(sbert_results.keys())
                sbert_scores = np.array(list(sbert_results.values()))
                sbert_scores_norm = normalize_scores(sbert_scores)
                sbert_results = dict(zip(sbert_docs, sbert_scores_norm))

            if bm25_results and len(bm25_results) > 0:
                bm25_docs = list(bm25_results.keys())
                bm25_scores = np.array(list(bm25_results.values()))
                bm25_scores_norm = normalize_scores(bm25_scores)
                bm25_results = dict(zip(bm25_docs, bm25_scores_norm))

            # 5. Combine unique candidates
            all_candidates = set(sbert_results.keys()) | set(bm25_results.keys())

            if debug:
                print(f"Total unique candidates: {len(all_candidates)}")

            # 6. Calculate combined scores
            combined_scores = {}
            for doc_id in all_candidates:
                sbert_score = sbert_results.get(doc_id, 0.0)
                bm25_score = bm25_results.get(doc_id, 0.0)
                # Apply scaling factor and combine
                combined_scores[doc_id] = alpha * sbert_score + (1-alpha) * bm25_score

            # 7. Sort and take top k
            ranked_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]

            # 8. Build results in required format
            for doc_id, score in ranked_results:
                result[query_id][doc_id] = float(score)

    except Exception as e:
        print(f"Error in hybrid retrieval: {str(e)}")
        # Return empty results in case of error

    return result

In [13]:
# 9. Enhanced Evaluation Functions - 修改版

def compute_mrr_at_k(run: Dict[str, Dict[str, float]],
                     qrels: Dict[str, Dict[str, int]],
                     k: int = 100,
                     verbose: bool = False) -> float:
    """Compute MRR@k with improved error handling

    Args:
        run: Search results {query_id: {doc_id: score}}
        qrels: Ground truth {query_id: {doc_id: relevance}}
        k: Cutoff value
        verbose: Whether to print detailed information

    Returns:
        MRR value
    """
    total_rr = 0.0
    num_queries = 0
    queries_with_relevant_docs = 0
    queries_with_relevant_retrieved = 0

    for qid, relevant_docs in qrels.items():
        # Skip if no relevant docs
        if not relevant_docs:
            continue

        # Count this as a query with relevant docs
        queries_with_relevant_docs += 1

        # Skip if query not in run
        if qid not in run:
            if verbose:
                print(f"Query {qid} not in run")
            continue

        # Skip if run has no results for this query
        if not run[qid]:
            if verbose:
                print(f"No results for query {qid}")
            continue

        # Sort results by score in descending order, take top k
        try:
            sorted_docs = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)[:k]
        except Exception as e:
            print(f"Error sorting results for query {qid}: {str(e)}")
            print(f"Run[qid] type: {type(run[qid])}, value: {run[qid]}")
            continue

        rr = 0.0  # Reciprocal rank for current query
        for rank, (doc_id, score) in enumerate(sorted_docs, start=1):
            # If document is relevant
            if doc_id in relevant_docs and relevant_docs[doc_id] > 0:
                rr = 1.0 / rank
                queries_with_relevant_retrieved += 1
                break  # Only consider first relevant document

        total_rr += rr
        num_queries += 1

    if verbose:
        print(f"Queries processed: {num_queries}")
        print(f"Queries with relevant docs: {queries_with_relevant_docs}")
        print(f"Queries with relevant docs retrieved: {queries_with_relevant_retrieved}")

    if num_queries == 0:
        print("Warning: No queries with results were evaluated")
        return 0.0

    return total_rr / num_queries


def compute_recall_at_k(run: Dict[str, Dict[str, float]],
                       qrels: Dict[str, Dict[str, int]],
                       k: int = 100,
                       verbose: bool = False) -> float:
    """Compute Recall@K with improved error handling

    Args:
        run: Search results {query_id: {doc_id: score}}
        qrels: Ground truth {query_id: {doc_id: relevance}}
        k: Cutoff value
        verbose: Whether to print detailed information

    Returns:
        Recall value
    """
    total_recall = 0.0
    num_queries_with_rels = 0  # Only count queries with relevant documents
    recall_values = []

    for qid, rel_docs in qrels.items():
        # Get relevant document set
        relevant_docs = {doc_id for doc_id, rel in rel_docs.items() if rel > 0}
        if not relevant_docs:
            # Skip queries without relevant documents
            continue

        # Count this query
        num_queries_with_rels += 1

        # Skip if query not in run
        if qid not in run:
            recall_values.append(0.0)
            continue

        # Skip if run has no results for this query
        if not run[qid]:
            recall_values.append(0.0)
            continue

        try:
            # Get top k documents by score
            top_docs = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)[:k]
            top_docs_ids = {doc_id for doc_id, score in top_docs}

            # Compute recall: hits / total relevant
            hit_count = len(relevant_docs & top_docs_ids)
            recall_q = hit_count / len(relevant_docs)

            recall_values.append(recall_q)
            total_recall += recall_q

            if verbose and recall_q == 1.0:
                print(f"Query {qid} has perfect recall")

        except Exception as e:
            print(f"Error computing recall for query {qid}: {str(e)}")
            recall_values.append(0.0)

    if verbose:
        print(f"Queries with relevant docs: {num_queries_with_rels}")
        if recall_values:
            print(f"Min recall: {min(recall_values):.4f}")
            print(f"Max recall: {max(recall_values):.4f}")
            print(f"Median recall: {np.median(recall_values):.4f}")

    if num_queries_with_rels == 0:
        print("Warning: No queries with relevant documents were evaluated")
        return 0.0

    return total_recall / num_queries_with_rels


# Check if pytrec_eval is available and define NDCG function
try:
    import pytrec_eval

    def compute_ndcg_at_k(run, qrels, k=100, verbose=False):
        """Compute NDCG@k using pytrec_eval"""
        try:
            # Create evaluator
            evaluator = pytrec_eval.RelevanceEvaluator(qrels, {'ndcg'})

            # Evaluate
            results = evaluator.evaluate(run)

            # Extract NDCG scores
            ndcg_values = [results[qid].get('ndcg', 0.0) for qid in results]

            if verbose:
                print(f"Queries evaluated for NDCG: {len(ndcg_values)}")
                if ndcg_values:
                    print(f"Min NDCG: {min(ndcg_values):.4f}")
                    print(f"Max NDCG: {max(ndcg_values):.4f}")
                    print(f"Median NDCG: {np.median(ndcg_values):.4f}")

            # Calculate mean NDCG
            if not ndcg_values:
                return 0.0

            return sum(ndcg_values) / len(ndcg_values)
        except Exception as e:
            print(f"Error computing NDCG: {str(e)}")
            return 0.0

    has_pytrec_eval = True
    print("pytrec_eval loaded successfully for NDCG calculation")

except ImportError:
    has_pytrec_eval = False
    print("pytrec_eval not available, skipping NDCG calculation")

pytrec_eval loaded successfully for NDCG calculation


In [16]:
def test_sample_query(strategy="sbert_bm25", alpha=0.0, top_k=5):
    """Fetches a sample query and runs the hybrid search."""
    sample_qid = random.choice(list(queries.keys()))
    sample_query = queries[sample_qid]
    print(f"Running sample query (ID: {sample_qid}): {sample_query}")
    results = hybrid_retrieve_documents(
        sample_query, sample_qid, top_k=top_k, strategy=strategy, alpha=alpha
    )
    print(f"Results: {results}")
    return results

In [17]:
# 测试检索功能：运行这个单元格来获取输出结果
# 测试单个查询示例
print("Testing a sample query...")
sample_result = test_sample_query(strategy="sbert_bm25", alpha=0.0, top_k=5)

# 测试BM25检索功能
print("\n--- Testing BM25 Retrieval ---")
sample_qid = random.choice(list(queries.keys()))
sample_query = queries[sample_qid]
print(f"Query: {sample_query}")

bm25_doc_ids, bm25_scores = bm25_retrieve(sample_query, bm25, doc_ids, k=5)
print("\nBM25 retrieval results:")
for i, (doc_id, score) in enumerate(zip(bm25_doc_ids, bm25_scores)):
    print(f"Result {i+1}: {doc_id} (Score: {score:.4f})")
    print(f"Text: {corpus[doc_id][:100]}...")

# 验证MRR计算
print("\n--- Verifying MRR Calculation ---")
# 创建一个小型测试集
test_qid = sample_qid
test_query = sample_query
test_run = {}
test_run[test_qid] = {}

# 添加一些检索结果
test_results = hybrid_retrieve_documents(
    test_query, test_qid, top_k=10, strategy="sbert_bm25", alpha=0.0
)

print(f"Run format for MRR calculation: {type(test_results)}")
print(f"Example run entry: {list(test_results.items())[0]}")

# 计算MRR
if test_qid in qrels:
    test_mrr = compute_mrr_at_k({test_qid: test_results[test_qid]}, {test_qid: qrels[test_qid]}, k=10, verbose=True)
    print(f"Test MRR@10: {test_mrr:.4f}")
else:
    print(f"Query {test_qid} has no relevance judgments")

    # 找一个有相关文档的查询
    for alt_qid in qrels:
        if qrels[alt_qid]:
            print(f"Using alternative query {alt_qid} with {len(qrels[alt_qid])} relevant docs")
            alt_query = queries[alt_qid]
            alt_results = hybrid_retrieve_documents(
                alt_query, alt_qid, top_k=10, strategy="sbert_bm25", alpha=0.0
            )
            alt_mrr = compute_mrr_at_k(alt_results, {alt_qid: qrels[alt_qid]}, k=10, verbose=True)
            print(f"Alternative test MRR@10: {alt_mrr:.4f}")
            break

Testing a sample query...
Running sample query (ID: 19426): how did the Athenians differ from the spartans in their views on education and military training according to pericles
Results: {'19426': {'19426_3': 0.9999999993891144, '19426_8': 0.7346147976930142, '19426_1': 0.6886383514940043, '19426_4': 0.6272955878640288, '12343_1': 0.23965778986830608}}

--- Testing BM25 Retrieval ---
Query: calories in one pound of smoked turkey

BM25 retrieval results:
Result 1: 19188_4 (Score: 24.8400)
Text: 1 approx 608 calories in 1 pound of roast turkey breast, no skin. 2  approx 848 calories in 1 pound ...
Result 2: 19188_0 (Score: 24.7318)
Text: Calories in 1 pound of turkey breast. 1  approx 608 calories in 1 pound of roast turkey breast, no s...
Result 3: 19188_3 (Score: 24.1206)
Text: 1 approx 848 calories in 1 pound of roast turkey breast including the skin. 2  approx 1072 calories ...
Result 4: 10592_6 (Score: 24.0112)
Text: How to Cook Pre-Smoked Turkey Drumsticks in the Oven. A reheated 

In [19]:
def evaluate_strategies(max_queries=100, k_values=[10, 50, 100], strategies=["single", "dynamic", "fallback", "ensemble", "sbert_bm25"]):
    """Evaluates different retrieval strategies using MRR, Recall, and NDCG."""
    results = {}
    for strategy in strategies:
        results[strategy] = {}
        for k in k_values:
            print(f"\nEvaluating strategy: {strategy}, k={k}")
            run = {}
            queries_subset = random.sample(list(queries.keys()), min(max_queries, len(queries)))
            for qid in tqdm(queries_subset, desc="Processing queries"):
                query_text = queries[qid]
                results_for_query = hybrid_retrieve_documents(query_text, qid, top_k=k, strategy=strategy)
                run.update(results_for_query)

            mrr = compute_mrr_at_k(run, qrels, k=k)
            recall = compute_recall_at_k(run, qrels, k=k)
            results[strategy][k] = {"mrr": mrr, "recall": recall}

            if has_pytrec_eval:
                ndcg = compute_ndcg_at_k(run, qrels, k=k)
                results[strategy][k]["ndcg"] = ndcg

    return results

In [20]:
# 评估不同检索策略（小样本）
print("Evaluating retrieval strategies on a small sample...")
evaluation_results = evaluate_strategies(
    max_queries=5,  # 使用较小的样本量便于快速运行
    k_values=[10, 50, 100],
    strategies=["single", "sbert_bm25"]  # 只评估两种最重要的策略
)

# 显示结果摘要
print("\nResults summary:")
for strategy, results in evaluation_results.items():
    print(f"\nStrategy: {strategy}")
    for k, metrics in results.items():
        print(f"  k={k}: MRR={metrics['mrr']:.4f}, Recall={metrics['recall']:.4f}" +
              (f", NDCG={metrics['ndcg']:.4f}" if 'ndcg' in metrics and metrics['ndcg'] > 0 else ""))

Evaluating retrieval strategies on a small sample...

Evaluating strategy: single, k=10


Processing queries: 100%|██████████| 5/5 [00:00<00:00, 72.45it/s]



Evaluating strategy: single, k=50


Processing queries: 100%|██████████| 5/5 [00:00<00:00, 77.19it/s]



Evaluating strategy: single, k=100


Processing queries: 100%|██████████| 5/5 [00:00<00:00, 71.30it/s]



Evaluating strategy: sbert_bm25, k=10


Processing queries: 100%|██████████| 5/5 [00:00<00:00, 14.83it/s]



Evaluating strategy: sbert_bm25, k=50


Processing queries: 100%|██████████| 5/5 [00:00<00:00, 12.86it/s]



Evaluating strategy: sbert_bm25, k=100


Processing queries: 100%|██████████| 5/5 [00:00<00:00, 13.91it/s]


Results summary:

Strategy: single
  k=10: MRR=0.1131, Recall=0.0006, NDCG=0.2557
  k=50: MRR=0.0210, Recall=0.0010, NDCG=0.1785
  k=100: MRR=0.0116, Recall=0.0010, NDCG=0.1512

Strategy: sbert_bm25
  k=10: MRR=0.0000, Recall=0.0000
  k=50: MRR=0.0000, Recall=0.0000
  k=100: MRR=0.0000, Recall=0.0000





In [22]:
def tune_alpha_parameter(sample_size=10, alpha_values=np.arange(0, 1.1, 0.1), k=10):
    """Tunes the alpha parameter for the sbert_bm25 strategy."""
    results = {}
    queries_subset = random.sample(list(queries.keys()), min(sample_size, len(queries)))

    for alpha in alpha_values:
        print(f"\nEvaluating alpha={alpha:.1f}")
        run = {}
        for qid in tqdm(queries_subset, desc="Processing queries"):
            query_text = queries[qid]
            results_for_query = hybrid_retrieve_documents(
                query_text, qid, top_k=k, strategy="sbert_bm25", alpha=alpha
            )
            run.update(results_for_query)

        mrr = compute_mrr_at_k(run, qrels, k=k)
        results[alpha] = mrr

    best_alpha = max(results, key=results.get)
    print(f"\nBest alpha: {best_alpha:.1f} (MRR: {results[best_alpha]:.4f})")
    return results

In [23]:
# 测试alpha参数调优（使用非常小的样本以便快速运行）
print("Testing alpha parameter tuning on a very small sample...")
mini_alpha_results = tune_alpha_parameter(sample_size=3)

Testing alpha parameter tuning on a very small sample...

Evaluating alpha=0.0


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 12.62it/s]



Evaluating alpha=0.1


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 13.07it/s]



Evaluating alpha=0.2


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 13.60it/s]



Evaluating alpha=0.3


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 12.49it/s]



Evaluating alpha=0.4


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 11.87it/s]



Evaluating alpha=0.5


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 12.98it/s]



Evaluating alpha=0.6


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 13.57it/s]



Evaluating alpha=0.7


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 13.42it/s]



Evaluating alpha=0.8


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 13.61it/s]



Evaluating alpha=0.9


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 13.05it/s]



Evaluating alpha=1.0


Processing queries: 100%|██████████| 3/3 [00:00<00:00, 13.78it/s]


Best alpha: 0.3 (MRR: 0.2222)



