<a href="https://colab.research.google.com/github/Rongxuan-Zhou/CS6120_project/blob/index_construction-%26-hybrid_retrieval/notebooks/index-construction%20%26%20hybrid-retrieval-2.0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
# 1. Environment Setup
!pip install -q faiss-cpu sentence-transformers nltk rank-bm25 hnswlib scikit-learn datasets pytrec_eval tqdm
from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import time
import json
import pickle
import numpy as np
import torch
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
import rank_bm25
from collections import Counter
import random
from typing import Dict, List, Tuple, Any, Optional, Union

# Set project path
PROJECT_PATH = "/content/drive/MyDrive/CS6120_project"
os.chdir(PROJECT_PATH)

# Create necessary directories
os.makedirs("models/indexes", exist_ok=True)

# Check GPU availability
print(f"Available GPU: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU for processing")

# Download NLTK resources
import nltk
for resource in ['punkt', 'stopwords']:
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource)

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Available GPU: True
Using GPU: NVIDIA A100-SXM4-40GB


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [20]:
# 2. Helper Classes & Functions

class Timer:
    """Simple timer for benchmarking"""
    def __init__(self, name="Operation"):
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.interval = self.end - self.start
        print(f"{self.name} completed in {self.interval:.2f} seconds")


def preprocess_text(text: str) -> List[str]:
    """Preprocess text for BM25 indexing"""
    if not isinstance(text, str):
        return []

    # Tokenize
    tokens = word_tokenize(text.lower())

    # Remove stopwords
    stop_words = set(stopwords.words('english'))
    tokens = [token for token in tokens if token not in stop_words and token.isalnum()]

    return tokens


def normalize_scores(scores: np.ndarray) -> np.ndarray:
    """Normalize scores to [0,1] range with safe handling of edge cases"""
    if len(scores) == 0:
        return scores

    min_val = np.min(scores)
    max_val = np.max(scores)

    if max_val == min_val:
        return np.ones_like(scores)

    return (scores - min_val) / (max_val - min_val + 1e-8)


def load_models(model_paths: Dict[str, str]) -> Tuple[Dict[str, SentenceTransformer], Dict[str, int]]:
    """Load all models and return them with their embedding dimensions"""
    models = {}
    dimensions = {}

    for model_name, model_path in model_paths.items():
        print(f"Loading model: {model_name}")
        try:
            models[model_name] = SentenceTransformer(model_path)
            models[model_name].to(device)
            dimensions[model_name] = models[model_name].get_sentence_embedding_dimension()

            print(f"  - Model path: {model_path}")
            print(f"  - Model architecture: {dimensions[model_name]}d embedding dimension")
            print(f"  - Model details: {models[model_name]}")
            print("")
        except Exception as e:
            print(f"Error loading model {model_name} from {model_path}: {str(e)}")
            continue

    if not models:
        raise ValueError("No models could be loaded. Please check model paths.")

    return models, dimensions

In [21]:
# 3. MS MARCO Data Loading
from datasets import load_dataset

def load_msmarco_data(max_samples: int = 5000, seed: int = 42) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, Dict[str, int]]]:
    """Load MS MARCO dataset with proper format for MRR evaluation

    Args:
        max_samples: Maximum samples to load
        seed: Random seed

    Returns:
        corpus: Dictionary {doc_id: document_text}
        queries: Dictionary {query_id: query_text}
        qrels: Dictionary {query_id: {doc_id: relevance}}
    """
    print("Loading MS MARCO dataset...")
    try:
        dataset = load_dataset("ms_marco", "v1.1")
        dev_data = dataset["validation"].shuffle(seed=seed).select(range(max_samples))
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        raise

    queries = {}
    corpus = {}
    qrels = {}

    # Process each sample
    for example in dev_data:
        # Get query_id and query text
        qid = str(example["query_id"])
        query_text = example["query"]
        queries[qid] = query_text

        # Get passages information
        passages_info = example["passages"]
        passage_texts = passages_info.get("passage_text", [])
        is_selecteds = passages_info.get("is_selected", [])

        # Process each passage
        for i, (text, is_sel) in enumerate(zip(passage_texts, is_selecteds)):
            # Generate unique document ID as "qid_i"
            doc_id = f"{qid}_{i}"
            corpus[doc_id] = text

            # If passage is relevant, add to qrels
            if is_sel == 1:
                if qid not in qrels:
                    qrels[qid] = {}
                qrels[qid][doc_id] = 1

    # Check positive counts
    check_positive_counts(queries, qrels)

    print(f"Loaded {len(corpus)} documents, {len(queries)} queries, {len(qrels)} qrels.")
    return corpus, queries, qrels


def check_positive_counts(queries: Dict[str, str], qrels: Dict[str, Dict[str, int]]):
    """Analyze distribution of relevant documents per query"""
    # Count positive examples per query
    positive_counts = []
    for qid in queries:
        if qid in qrels:
            positive_counts.append(len(qrels[qid]))
        else:
            positive_counts.append(0)

    # Count distribution
    counter = Counter(positive_counts)
    print("Positive examples distribution (count: queries):")
    for num_pos, num_queries in sorted(counter.items()):
        print(f"{num_pos} positive examples: {num_queries} queries")

    # Count queries without positives
    total_queries = len(queries)
    no_positive = counter.get(0, 0)
    print(f"\nTotal queries: {total_queries}")
    print(f"Queries without positives: {no_positive} ({no_positive/total_queries*100:.2f}%)")

In [22]:
# 4. Load Dataset and Models
# Load MS MARCO dataset
try:
    corpus, queries, qrels = load_msmarco_data(max_samples=5000, seed=42)

    # Extract important data structures
    corpus_texts = list(corpus.values())
    doc_ids = list(corpus.keys())

    # Define model paths
    model_paths = {
        "msmarco_stsb": os.path.join(PROJECT_PATH, "model/msmarco_stsb_finetuned_model"),
        "stsb": os.path.join(PROJECT_PATH, "model/stsb_finetuned_model")
    }

    # Load models
    models, dimensions = load_models(model_paths)

    # Models for primary retrieval and fallback
    primary_model = "msmarco_stsb"
    fallback_model = "stsb"

    print(f"Primary model: {primary_model}")
    print(f"Fallback model: {fallback_model}")

except Exception as e:
    print(f"Error during dataset or model loading: {str(e)}")
    print("Please check your environment setup and model paths.")

Loading MS MARCO dataset...
Positive examples distribution (count: queries):
0 positive examples: 160 queries
1 positive examples: 4371 queries
2 positive examples: 425 queries
3 positive examples: 38 queries
4 positive examples: 6 queries

Total queries: 5000
Queries without positives: 160 (3.20%)
Loaded 41070 documents, 5000 queries, 4840 qrels.
Loading model: msmarco_stsb
  - Model path: /content/drive/MyDrive/CS6120_project/model/msmarco_stsb_finetuned_model
  - Model architecture: 768d embedding dimension
  - Model details: SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

Loading model: stsb
  - Model path: /conten

In [46]:
# 5. Build BM25 and FAISS indexes
# Clear GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Create BM25 index with improved error handling
def build_bm25_index(corpus_texts, k1=0.9, b=0.6):
    """Build BM25 index with improved error handling"""
    print("Creating BM25 index...")
    try:
        with Timer("BM25 preprocessing"):
            tokenized_corpus = [preprocess_text(text) for text in tqdm(corpus_texts, desc="Preprocessing documents")]

        # Filter out empty documents to prevent BM25 errors
        empty_docs = [i for i, tokens in enumerate(tokenized_corpus) if not tokens]
        if empty_docs:
            print(f"Warning: Found {len(empty_docs)} empty documents after preprocessing")
            # Add at least one token to empty documents to prevent BM25 errors
            for i in empty_docs:
                tokenized_corpus[i] = ["_empty_"]

        with Timer("BM25 index construction"):
            bm25 = rank_bm25.BM25Okapi(tokenized_corpus, k1=k1, b=b)

        # Create BM25 related information
        avg_doc_len = sum(len(doc) for doc in tokenized_corpus) / len(tokenized_corpus)
        bm25_info = {
            "corpus_size": len(tokenized_corpus),
            "avg_doc_len": avg_doc_len,
            "idf_avg": sum(bm25.idf.values()) / len(bm25.idf) if bm25.idf else 0,
            "k1": k1,
            "b": b,
            "empty_docs": len(empty_docs)
        }

        return bm25, bm25_info, tokenized_corpus

    except Exception as e:
        print(f"Error building BM25 index: {str(e)}")
        # Return a simple BM25 index with dummy data in case of error
        dummy_corpus = [["dummy"] for _ in range(len(corpus_texts))]
        dummy_bm25 = rank_bm25.BM25Okapi(dummy_corpus)
        dummy_info = {"error": str(e)}
        return dummy_bm25, dummy_info, dummy_corpus

# Build BM25 index
with Timer("Total BM25 index building"):
    bm25, bm25_info, tokenized_corpus = build_bm25_index(corpus_texts)

print("BM25 index created successfully")
print(f"Average document length: {bm25_info['avg_doc_len']:.2f} tokens")

# Save BM25 info
with open(os.path.join("models/indexes", "bm25_info.json"), 'w') as f:
    json.dump(bm25_info, f)

# Create FAISS indexes for each model
def build_faiss_indexes(model, corpus_texts, model_name):
    """Build all FAISS indexes for a given model"""
    print(f"\nProcessing {model_name} model...")

    # Generate embeddings
    print(f"Generating {model_name} embeddings...")
    embeddings = None

    timer = Timer(f"{model_name} encoding")
    with timer:
        batch_size = 128
        embeddings_list = []

        for i in tqdm(range(0, len(corpus_texts), batch_size)):
            batch = corpus_texts[i:i+batch_size]
            try:
                batch_embeddings = model.encode(batch, show_progress_bar=False)
                embeddings_list.append(batch_embeddings)
            except Exception as e:
                print(f"Error encoding batch {i}-{i+batch_size}: {str(e)}")
                # Add zero embeddings for failed batch
                dimension = model.get_sentence_embedding_dimension()
                zero_embeddings = np.zeros((len(batch), dimension))
                embeddings_list.append(zero_embeddings)

        embeddings = np.vstack(embeddings_list)

    dimension = embeddings.shape[1]
    print(f"Generated {len(embeddings)} embeddings with dimension {dimension}")
    print(f"Processing speed: {len(corpus_texts)/timer.interval:.2f} docs/sec")

    # Normalize vectors for cosine similarity
    print(f"Normalizing {model_name} vectors...")
    faiss.normalize_L2(embeddings)

    # Create indexes
    model_indexes = {}

    # 1. Flat index (exact search)
    print(f"Building {model_name} Flat index...")
    index_flat = faiss.IndexFlatIP(dimension)
    index_flat.add(embeddings)
    print(f"{model_name} Flat index built with {index_flat.ntotal} vectors")
    model_indexes["flat"] = index_flat

    # 2. HNSW index (fast approximate search)
    print(f"Building {model_name} HNSW index...")
    M = 16  # Connections per node
    ef_construction = 200  # Search width during construction
    index_hnsw = faiss.IndexHNSWFlat(dimension, M)
    index_hnsw.hnsw.efConstruction = ef_construction
    index_hnsw.add(embeddings)
    print(f"{model_name} HNSW index built with {index_hnsw.ntotal} vectors")
    model_indexes["hnsw"] = index_hnsw

    # 3. IVF-PQ index (memory-efficient)
    print(f"Building {model_name} IVF-PQ index...")
    nlist = min(100, len(corpus_texts) // 50)  # Number of cluster centers
    m = 8  # Number of subvectors
    bits = 8  # Bits per subvector
    quantizer = faiss.IndexFlatL2(dimension)
    index_ivfpq = faiss.IndexIVFPQ(quantizer, dimension, nlist, m, bits)
    index_ivfpq.train(embeddings)
    index_ivfpq.add(embeddings)
    print(f"{model_name} IVF-PQ index built with {index_ivfpq.ntotal} vectors")
    model_indexes["ivfpq"] = index_ivfpq

    # Create index config
    index_config = {
        "model_name": model_name,
        "dimension": dimension,
        "flat_index": {"type": "IndexFlatIP", "dimension": dimension},
        "hnsw_index": {"type": "IndexHNSWFlat", "dimension": dimension, "M": M, "efConstruction": ef_construction},
        "ivfpq_index": {"type": "IndexIVFPQ", "dimension": dimension, "nlist": nlist, "m": m, "bits": bits, "recommended_nprobe": 30}
    }

    return {
        "embeddings": embeddings,
        "indexes": model_indexes,
        "config": index_config
    }

# Build FAISS indexes for each model
all_embeddings = {}
all_indexes = {}
all_configs = {}

for model_name, model in models.items():
    try:
        result = build_faiss_indexes(model, corpus_texts, model_name)
        all_embeddings[model_name] = result["embeddings"]
        all_indexes[model_name] = result["indexes"]
        all_configs[model_name] = result["config"]
    except Exception as e:
        print(f"Error building indexes for model {model_name}: {str(e)}")

# Save hybrid retrieval configuration
hybrid_config = {
    "primary_model": primary_model,
    "fallback_model": fallback_model,
    "corpus_size": len(corpus),
    "models": {
        model_name: {
            "dimension": dimensions[model_name],
            "index_types": list(all_indexes[model_name].keys()) if model_name in all_indexes else []
        } for model_name in models
    },
    "bm25_info": bm25_info
}

with open(os.path.join("models/indexes", "hybrid_config.json"), 'w') as f:
    json.dump(hybrid_config, f)

print("\nAll model indexes successfully built")

Creating BM25 index...



Preprocessing documents:   0%|          | 0/41070 [00:00<?, ?it/s][A
Preprocessing documents:   0%|          | 163/41070 [00:00<00:25, 1629.33it/s][A
Preprocessing documents:   1%|          | 326/41070 [00:00<00:26, 1549.44it/s][A
Preprocessing documents:   1%|          | 482/41070 [00:00<00:26, 1528.12it/s][A
Preprocessing documents:   2%|▏         | 644/41070 [00:00<00:25, 1563.20it/s][A
Preprocessing documents:   2%|▏         | 801/41070 [00:00<00:26, 1543.93it/s][A
Preprocessing documents:   2%|▏         | 956/41070 [00:00<00:26, 1536.81it/s][A
Preprocessing documents:   3%|▎         | 1110/41070 [00:00<00:26, 1535.26it/s][A
Preprocessing documents:   3%|▎         | 1279/41070 [00:00<00:25, 1583.28it/s][A
Preprocessing documents:   4%|▎         | 1440/41070 [00:00<00:24, 1589.01it/s][A
Preprocessing documents:   4%|▍         | 1606/41070 [00:01<00:24, 1609.49it/s][A
Preprocessing documents:   4%|▍         | 1774/41070 [00:01<00:24, 1630.11it/s][A
Preprocessing document

BM25 preprocessing completed in 25.41 seconds
BM25 index construction completed in 0.59 seconds
Total BM25 index building completed in 26.02 seconds
BM25 index created successfully
Average document length: 40.22 tokens

Processing msmarco_stsb model...
Generating msmarco_stsb embeddings...



  0%|          | 0/321 [00:00<?, ?it/s][A
  0%|          | 1/321 [00:00<01:29,  3.57it/s][A
  1%|          | 2/321 [00:00<01:14,  4.30it/s][A
  1%|          | 3/321 [00:00<01:08,  4.67it/s][A
  1%|          | 4/321 [00:00<01:06,  4.79it/s][A
  2%|▏         | 5/321 [00:01<01:04,  4.92it/s][A
  2%|▏         | 6/321 [00:01<01:03,  4.93it/s][A
  2%|▏         | 7/321 [00:01<01:02,  5.00it/s][A
  2%|▏         | 8/321 [00:01<01:01,  5.05it/s][A
  3%|▎         | 9/321 [00:01<01:01,  5.08it/s][A
  3%|▎         | 10/321 [00:02<01:00,  5.12it/s][A
  3%|▎         | 11/321 [00:02<01:01,  5.08it/s][A
  4%|▎         | 12/321 [00:02<00:59,  5.16it/s][A
  4%|▍         | 13/321 [00:02<00:59,  5.21it/s][A
  4%|▍         | 14/321 [00:02<00:58,  5.26it/s][A
  5%|▍         | 15/321 [00:02<00:58,  5.27it/s][A
  5%|▍         | 16/321 [00:03<00:58,  5.23it/s][A
  5%|▌         | 17/321 [00:03<00:58,  5.20it/s][A
  6%|▌         | 18/321 [00:03<00:58,  5.18it/s][A
  6%|▌         | 19/321 [00:0

msmarco_stsb encoding completed in 61.65 seconds
Generated 41070 embeddings with dimension 768
Processing speed: 666.15 docs/sec
Normalizing msmarco_stsb vectors...
Building msmarco_stsb Flat index...
msmarco_stsb Flat index built with 41070 vectors
Building msmarco_stsb HNSW index...
msmarco_stsb HNSW index built with 41070 vectors
Building msmarco_stsb IVF-PQ index...
msmarco_stsb IVF-PQ index built with 41070 vectors

Processing stsb model...
Generating stsb embeddings...



  0%|          | 0/321 [00:00<?, ?it/s][A
  0%|          | 1/321 [00:00<01:11,  4.48it/s][A
  1%|          | 2/321 [00:00<01:07,  4.75it/s][A
  1%|          | 3/321 [00:00<01:04,  4.94it/s][A
  1%|          | 4/321 [00:00<01:03,  4.97it/s][A
  2%|▏         | 5/321 [00:01<01:02,  5.04it/s][A
  2%|▏         | 6/321 [00:01<01:02,  5.01it/s][A
  2%|▏         | 7/321 [00:01<01:02,  5.05it/s][A
  2%|▏         | 8/321 [00:01<01:01,  5.08it/s][A
  3%|▎         | 9/321 [00:01<01:01,  5.10it/s][A
  3%|▎         | 10/321 [00:01<01:00,  5.13it/s][A
  3%|▎         | 11/321 [00:02<01:00,  5.10it/s][A
  4%|▎         | 12/321 [00:02<00:59,  5.18it/s][A
  4%|▍         | 13/321 [00:02<00:58,  5.23it/s][A
  4%|▍         | 14/321 [00:02<00:58,  5.28it/s][A
  5%|▍         | 15/321 [00:02<00:57,  5.30it/s][A
  5%|▍         | 16/321 [00:03<00:57,  5.28it/s][A
  5%|▌         | 17/321 [00:03<00:57,  5.25it/s][A
  6%|▌         | 18/321 [00:03<00:58,  5.22it/s][A
  6%|▌         | 19/321 [00:0

stsb encoding completed in 61.69 seconds
Generated 41070 embeddings with dimension 768
Processing speed: 665.73 docs/sec
Normalizing stsb vectors...
Building stsb Flat index...
stsb Flat index built with 41070 vectors
Building stsb HNSW index...
stsb HNSW index built with 41070 vectors
Building stsb IVF-PQ index...
stsb IVF-PQ index built with 41070 vectors

All model indexes successfully built


In [47]:
# 6. Save indexes and corpus information
print("Saving all model indexes...")
for model_name in models:
    if model_name not in all_indexes:
        print(f"Skipping {model_name} - no indexes built")
        continue

    model_dir = os.path.join(PROJECT_PATH, "models/indexes", model_name)
    os.makedirs(model_dir, exist_ok=True)

    model_indexes = all_indexes[model_name]

    # Save all index types
    print(f"\nSaving {model_name} indexes...")

    try:
        print(f"Saving {model_name} Flat index...")
        faiss.write_index(model_indexes["flat"], os.path.join(model_dir, "flat_index.faiss"))

        print(f"Saving {model_name} HNSW index...")
        faiss.write_index(model_indexes["hnsw"], os.path.join(model_dir, "hnsw_index.faiss"))

        print(f"Saving {model_name} IVF-PQ index...")
        faiss.write_index(model_indexes["ivfpq"], os.path.join(model_dir, "ivfpq_index.faiss"))

        # Save index configuration information
        with open(os.path.join(model_dir, "index_config.json"), 'w') as f:
            json.dump(all_configs[model_name], f)

        print(f"{model_name} indexes successfully saved to: {model_dir}")
    except Exception as e:
        print(f"Error saving indexes for {model_name}: {str(e)}")

# Save the corpus and document IDs
corpus_dir = os.path.join(PROJECT_PATH, "models/indexes")
print("\nSaving corpus data...")

try:
    # Save corpus
    with open(os.path.join(corpus_dir, "corpus.json"), 'w') as f:
        json.dump(corpus, f)

    # Save doc IDs
    with open(os.path.join(corpus_dir, "doc_ids.json"), 'w') as f:
        json.dump(doc_ids, f)

    # Save BM25 model using pickle
    print("Saving BM25 model...")
    with open(os.path.join(corpus_dir, "bm25_model.pkl"), 'wb') as f:
        pickle.dump(bm25, f)

    # Also save tokenized corpus for future use
    with open(os.path.join(corpus_dir, "tokenized_corpus.pkl"), 'wb') as f:
        pickle.dump(tokenized_corpus, f)

except Exception as e:
    print(f"Error saving corpus data: {str(e)}")

print("\nAll indexes and data successfully saved")

Saving all model indexes...

Saving msmarco_stsb indexes...
Saving msmarco_stsb Flat index...
Saving msmarco_stsb HNSW index...
Saving msmarco_stsb IVF-PQ index...
msmarco_stsb indexes successfully saved to: /content/drive/MyDrive/CS6120_project/models/indexes/msmarco_stsb

Saving stsb indexes...
Saving stsb Flat index...
Saving stsb HNSW index...
Saving stsb IVF-PQ index...
stsb indexes successfully saved to: /content/drive/MyDrive/CS6120_project/models/indexes/stsb

Saving corpus data...
Saving BM25 model...

All indexes and data successfully saved


In [48]:
# 7. Enhanced Retrieval Functions

def bm25_retrieve(query: str, bm25_index, doc_ids: List[str], k: int = 10) -> Tuple[List[str], List[float]]:
    """Retrieve documents using BM25 with improved error handling

    Args:
        query: Query string
        bm25_index: BM25 index
        doc_ids: List of document IDs
        k: Number of results to return

    Returns:
        top_doc_ids: List of document IDs
        top_scores: List of corresponding BM25 scores
    """
    try:
        # Process query
        query_tokens = preprocess_text(query)
        if not query_tokens:
            print(f"Warning: Empty query tokens for query: '{query}'")
            query_tokens = ["_empty_"]

        # Get BM25 scores
        bm25_scores = bm25_index.get_scores(query_tokens)

        # Ensure length compatibility
        if len(bm25_scores) != len(doc_ids):
            print(f"Warning: BM25 scores count ({len(bm25_scores)}) does not match doc_ids count ({len(doc_ids)})")
            if len(bm25_scores) < len(doc_ids):
                # Extend with zeros
                bm25_scores = np.concatenate([bm25_scores, np.zeros(len(doc_ids) - len(bm25_scores))])
            else:
                # Truncate
                bm25_scores = bm25_scores[:len(doc_ids)]

        # Get top k results
        if len(bm25_scores) == 0:
            print("Warning: No BM25 scores available")
            return [], []

        # Sort by score in descending order
        top_indices = np.argsort(bm25_scores)[::-1][:k]
        top_scores = bm25_scores[top_indices]

        # Map indices to document IDs
        top_doc_ids = [doc_ids[idx] for idx in top_indices]

        return top_doc_ids, top_scores

    except Exception as e:
        print(f"Error in BM25 retrieval: {str(e)}")
        # Return empty results in case of error
        return [], []


# Query classifier for dynamic weighting
def train_query_classifier():
    """Train a simple query classifier for dynamic strategy selection"""
    # Sample queries for training
    test_queries = [
        "How does social media affect mental health?",
        "Best programming languages to learn",
        "Artificial intelligence applications",
        "Climate change solutions and mitigation strategies",
        "Nutrition advice for athletes performance"
    ]

    # Feature extraction function
    def extract_query_features(query):
        features = []
        # Length features
        features.append(len(query))  # Raw character length
        features.append(len(query.split()))  # Word count

        # Question features
        features.append(1 if "?" in query else 0)  # Is it a question
        features.append(1 if query.lower().startswith("how") else 0)  # How question
        features.append(1 if query.lower().startswith("what") else 0)  # What question
        features.append(1 if query.lower().startswith("why") else 0)  # Why question

        # Structure features
        features.append(1 if "," in query else 0)  # Has comma
        features.append(1 if ":" in query else 0)  # Has colon

        # Topic features (simple keyword checking)
        features.append(1 if any(word in query.lower() for word in ["programming", "code", "software"]) else 0)  # Tech
        features.append(1 if any(word in query.lower() for word in ["health", "medical", "disease"]) else 0)  # Health

        return features

    # Build training features
    X_train = np.array([extract_query_features(q) for q in test_queries])
    # Dummy labels (which model is better)
    # 0 for primary model (semantic-heavy queries)
    # 1 for fallback model (simpler queries that might work better with BM25)
    y_train = np.array([0, 1, 0, 0, 1])  # Simulated labels

    # Train a simple query classifier
    classifier = LogisticRegression(random_state=42)
    classifier.fit(X_train, y_train)
    print("Query classifier trained")

    # Add feature extractor to classifier object for convenience
    classifier.extract_features = extract_query_features

    return classifier

# Train the classifier
query_classifier = train_query_classifier()

Query classifier trained


In [49]:
# 8. Main Enhanced Hybrid Retrieval Function
def hybrid_retrieve_documents(
    query: str,
    query_id: str,
    top_k: int = 100,
    strategy: str = "sbert_bm25",
    alpha: float = 0.7,
    debug: bool = False
) -> Dict[str, Dict[str, float]]:
    """Enhanced hybrid retrieval function with improved error handling

    Args:
        query: Query string
        query_id: Query ID
        top_k: Number of results to return
        strategy: Retrieval strategy
        alpha: Weight for semantic score (1-alpha for BM25)
        debug: Whether to print debug information

    Returns:
        Dictionary {query_id: {doc_id: score}} for MRR evaluation
    """
    # Validate inputs
    if not query:
        print("Error: Empty query")
        return {query_id: {}}  # Return empty results

    if strategy not in ["single", "dynamic", "fallback", "ensemble", "sbert_bm25"]:
        print(f"Error: Unknown strategy '{strategy}', falling back to 'single' strategy")
        strategy = "single"

    # Initialize result format for MRR evaluation
    result = {query_id: {}}

    # Start retrieval process with full error handling
    try:
        # Initialize model list based on strategy
        if strategy in ["single", "fallback"]:
            model_list = [primary_model]
        else:
            model_list = [primary_model, fallback_model]

        # Encode query for each model
        query_embeddings = {}
        for model_name in model_list:
            if model_name not in models:
                print(f"Error: Model {model_name} not loaded")
                continue

            query_emb = models[model_name].encode([query])
            faiss.normalize_L2(query_emb)
            query_embeddings[model_name] = query_emb

        # If no query embeddings could be generated, return empty results
        if not query_embeddings:
            print("Error: No query embeddings could be generated")
            return {query_id: {}}

        # Process based on strategy
        if strategy == "single":
            # Just use primary model
            if primary_model not in all_indexes:
                print(f"Error: No indexes for {primary_model}")
                return {query_id: {}}

            index = all_indexes[primary_model]["hnsw"]
            D, I = index.search(query_embeddings[primary_model], top_k)

            # Build results in required format
            for i in range(min(top_k, len(I[0]))):
                if I[0][i] >= 0 and I[0][i] < len(doc_ids):  # Ensure valid index
                    doc_id = doc_ids[I[0][i]]
                    score = float(D[0][i])
                    result[query_id][doc_id] = score

        elif strategy == "dynamic":
            # Use query classifier to determine best model
            features = query_classifier.extract_features(query)
            features = np.array([features])

            # Predict which model to use
            model_idx = query_classifier.predict(features)[0]
            model_to_use = primary_model if model_idx == 0 else fallback_model

            if debug:
                print(f"Dynamic strategy selected model: {model_to_use}")

            if model_to_use not in all_indexes or model_to_use not in query_embeddings:
                # Fall back to primary model if selected model is not available
                print(f"Warning: Selected model {model_to_use} not available, using {primary_model} instead")
                model_to_use = primary_model

            # Use the selected model
            index = all_indexes[model_to_use]["hnsw"]
            D, I = index.search(query_embeddings[model_to_use], top_k)

            # Build results in required format
            for i in range(min(top_k, len(I[0]))):
                if I[0][i] >= 0 and I[0][i] < len(doc_ids):
                    doc_id = doc_ids[I[0][i]]
                    score = float(D[0][i])
                    result[query_id][doc_id] = score

        elif strategy == "fallback":
            # First try with primary model
            primary_index = all_indexes[primary_model]["hnsw"]
            D_primary, I_primary = primary_index.search(query_embeddings[primary_model], top_k)

            # Check confidence (average similarity score)
            confidence = np.mean(D_primary[0]) if len(D_primary[0]) > 0 else 0
            threshold = 0.3  # Confidence threshold

            if debug:
                print(f"Primary model confidence: {confidence:.4f} (threshold: {threshold})")

            # Decide which results to use
            if confidence > threshold:
                # Use primary model results
                D, I = D_primary, I_primary
                if debug:
                    print("Using primary model results")
            else:
                # Switch to fallback model
                if debug:
                    print(f"Switching to fallback model ({fallback_model})")

                fallback_index = all_indexes[fallback_model]["hnsw"]
                D, I = fallback_index.search(query_embeddings[fallback_model], top_k)

            # Build results in required format
            for i in range(min(top_k, len(I[0]))):
                if I[0][i] >= 0 and I[0][i] < len(doc_ids):
                    doc_id = doc_ids[I[0][i]]
                    score = float(D[0][i])
                    result[query_id][doc_id] = score

        elif strategy == "ensemble":
            # Get results from each model
            all_results = {}

            for model_name in model_list:
                if model_name not in all_indexes or model_name not in query_embeddings:
                    continue

                model_index = all_indexes[model_name]["hnsw"]
                D, I = model_index.search(query_embeddings[model_name], top_k * 2)  # Get more candidates

                # Save score for each document ID
                for j in range(len(I[0])):
                    idx = int(I[0][j])
                    if idx < 0 or idx >= len(doc_ids):  # Skip invalid indices
                        continue

                    doc_id = doc_ids[idx]
                    score = float(D[0][j])

                    if doc_id not in all_results:
                        all_results[doc_id] = {}

                    all_results[doc_id][model_name] = score

            # Compute combined scores using weights
            weights = {primary_model: alpha, fallback_model: 1.0-alpha}
            final_scores = {}

            for doc_id in all_results:
                final_scores[doc_id] = 0
                for model_name, weight in weights.items():
                    if model_name in all_results[doc_id]:
                        final_scores[doc_id] += all_results[doc_id][model_name] * weight

            # Sort and select top_k results
            sorted_results = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]

            # Build results in required format
            for doc_id, score in sorted_results:
                result[query_id][doc_id] = float(score)

        elif strategy == "sbert_bm25":
            # SBERT + BM25 hybrid approach - most thorough option
            if debug:
                print(f"Running sbert_bm25 strategy with alpha={alpha:.2f}")

            # 1. Get SBERT results
            sbert_index = all_indexes[primary_model]["hnsw"]
            D_sbert, I_sbert = sbert_index.search(query_embeddings[primary_model], top_k*2)

            # Get SBERT results in dictionary format
            sbert_results = {}
            for j in range(len(I_sbert[0])):
                idx = int(I_sbert[0][j])
                if idx < 0 or idx >= len(doc_ids):
                    continue

                doc_id = doc_ids[idx]
                score = float(D_sbert[0][j])
                sbert_results[doc_id] = score

            # 2. Get BM25 results
            bm25_doc_ids, bm25_scores = bm25_retrieve(query, bm25, doc_ids, k=top_k*2)

            # Get BM25 results in dictionary format
            bm25_results = {}
            for doc_id, score in zip(bm25_doc_ids, bm25_scores):
                bm25_results[doc_id] = float(score)

            if debug:
                print(f"SBERT found {len(sbert_results)} results")
                print(f"BM25 found {len(bm25_results)} results")

            # 3. Check if we have results
            if not sbert_results and not bm25_results:
                print("Warning: No results found from either SBERT or BM25")
                return {query_id: {}}

            # 4. Normalize scores
            if sbert_results and len(sbert_results) > 0:
                sbert_docs = list(sbert_results.keys())
                sbert_scores = np.array(list(sbert_results.values()))
                sbert_scores_norm = normalize_scores(sbert_scores)
                sbert_results = dict(zip(sbert_docs, sbert_scores_norm))

            if bm25_results and len(bm25_results) > 0:
                bm25_docs = list(bm25_results.keys())
                bm25_scores = np.array(list(bm25_results.values()))
                bm25_scores_norm = normalize_scores(bm25_scores)
                bm25_results = dict(zip(bm25_docs, bm25_scores_norm))

            # 5. Combine unique candidates
            all_candidates = set(sbert_results.keys()) | set(bm25_results.keys())

            if debug:
                print(f"Total unique candidates: {len(all_candidates)}")

            # 6. Calculate combined scores
            combined_scores = {}
            for doc_id in all_candidates:
                sbert_score = sbert_results.get(doc_id, 0.0)
                bm25_score = bm25_results.get(doc_id, 0.0)
                # Apply scaling factor and combine
                combined_scores[doc_id] = alpha * sbert_score + (1-alpha) * bm25_score

            # 7. Sort and take top k
            ranked_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]

            # 8. Build results in required format
            for doc_id, score in ranked_results:
                result[query_id][doc_id] = float(score)

    except Exception as e:
        print(f"Error in hybrid retrieval: {str(e)}")
        # Return empty results in case of error

    return result

In [59]:
# 9. Enhanced Evaluation Functions - 修改版

def compute_mrr_at_k(run: Dict[str, Dict[str, float]],
                     qrels: Dict[str, Dict[str, int]],
                     k: int = 100,
                     verbose: bool = False) -> float:
    """Compute MRR@k with improved error handling

    Args:
        run: Search results {query_id: {doc_id: score}}
        qrels: Ground truth {query_id: {doc_id: relevance}}
        k: Cutoff value
        verbose: Whether to print detailed information

    Returns:
        MRR value
    """
    total_rr = 0.0
    num_queries = 0
    queries_with_relevant_docs = 0
    queries_with_relevant_retrieved = 0

    for qid, relevant_docs in qrels.items():
        # Skip if no relevant docs
        if not relevant_docs:
            continue

        # Count this as a query with relevant docs
        queries_with_relevant_docs += 1

        # Skip if query not in run
        if qid not in run:
            if verbose:
                print(f"Query {qid} not in run")
            continue

        # Skip if run has no results for this query
        if not run[qid]:
            if verbose:
                print(f"No results for query {qid}")
            continue

        # Sort results by score in descending order, take top k
        try:
            sorted_docs = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)[:k]
        except Exception as e:
            print(f"Error sorting results for query {qid}: {str(e)}")
            print(f"Run[qid] type: {type(run[qid])}, value: {run[qid]}")
            continue

        rr = 0.0  # Reciprocal rank for current query
        for rank, (doc_id, score) in enumerate(sorted_docs, start=1):
            # If document is relevant
            if doc_id in relevant_docs and relevant_docs[doc_id] > 0:
                rr = 1.0 / rank
                queries_with_relevant_retrieved += 1
                break  # Only consider first relevant document

        total_rr += rr
        num_queries += 1

    if verbose:
        print(f"Queries processed: {num_queries}")
        print(f"Queries with relevant docs: {queries_with_relevant_docs}")
        print(f"Queries with relevant docs retrieved: {queries_with_relevant_retrieved}")

    if num_queries == 0:
        print("Warning: No queries with results were evaluated")
        return 0.0

    return total_rr / num_queries


def compute_recall_at_k(run: Dict[str, Dict[str, float]],
                       qrels: Dict[str, Dict[str, int]],
                       k: int = 100,
                       verbose: bool = False) -> float:
    """Compute Recall@K with improved error handling

    Args:
        run: Search results {query_id: {doc_id: score}}
        qrels: Ground truth {query_id: {doc_id: relevance}}
        k: Cutoff value
        verbose: Whether to print detailed information

    Returns:
        Recall value
    """
    total_recall = 0.0
    num_queries_with_rels = 0  # Only count queries with relevant documents
    recall_values = []

    for qid, rel_docs in qrels.items():
        # Get relevant document set
        relevant_docs = {doc_id for doc_id, rel in rel_docs.items() if rel > 0}
        if not relevant_docs:
            # Skip queries without relevant documents
            continue

        # Count this query
        num_queries_with_rels += 1

        # Skip if query not in run
        if qid not in run:
            recall_values.append(0.0)
            continue

        # Skip if run has no results for this query
        if not run[qid]:
            recall_values.append(0.0)
            continue

        try:
            # Get top k documents by score
            top_docs = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)[:k]
            top_docs_ids = {doc_id for doc_id, score in top_docs}

            # Compute recall: hits / total relevant
            hit_count = len(relevant_docs & top_docs_ids)
            recall_q = hit_count / len(relevant_docs)

            recall_values.append(recall_q)
            total_recall += recall_q

            if verbose and recall_q == 1.0:
                print(f"Query {qid} has perfect recall")

        except Exception as e:
            print(f"Error computing recall for query {qid}: {str(e)}")
            recall_values.append(0.0)

    if verbose:
        print(f"Queries with relevant docs: {num_queries_with_rels}")
        if recall_values:
            print(f"Min recall: {min(recall_values):.4f}")
            print(f"Max recall: {max(recall_values):.4f}")
            print(f"Median recall: {np.median(recall_values):.4f}")

    if num_queries_with_rels == 0:
        print("Warning: No queries with relevant documents were evaluated")
        return 0.0

    return total_recall / num_queries_with_rels


# Check if pytrec_eval is available and define NDCG function
try:
    import pytrec_eval

    def compute_ndcg_at_k(run, qrels, k=100, verbose=False):
        """Compute NDCG@k using pytrec_eval"""
        try:
            # Create evaluator
            evaluator = pytrec_eval.RelevanceEvaluator(qrels, {'ndcg'})

            # Evaluate
            results = evaluator.evaluate(run)

            # Extract NDCG scores
            ndcg_values = [results[qid].get('ndcg', 0.0) for qid in results]

            if verbose:
                print(f"Queries evaluated for NDCG: {len(ndcg_values)}")
                if ndcg_values:
                    print(f"Min NDCG: {min(ndcg_values):.4f}")
                    print(f"Max NDCG: {max(ndcg_values):.4f}")
                    print(f"Median NDCG: {np.median(ndcg_values):.4f}")

            # Calculate mean NDCG
            if not ndcg_values:
                return 0.0

            return sum(ndcg_values) / len(ndcg_values)
        except Exception as e:
            print(f"Error computing NDCG: {str(e)}")
            return 0.0

    has_pytrec_eval = True
    print("pytrec_eval loaded successfully for NDCG calculation")

except ImportError:
    has_pytrec_eval = False
    print("pytrec_eval not available, skipping NDCG calculation")

pytrec_eval loaded successfully for NDCG calculation


In [60]:
# 测试检索功能：运行这个单元格来获取输出结果
# 测试单个查询示例
print("Testing a sample query...")
sample_result = test_sample_query(strategy="sbert_bm25", alpha=0.0, top_k=5)

# 测试BM25检索功能
print("\n--- Testing BM25 Retrieval ---")
sample_qid = random.choice(list(queries.keys()))
sample_query = queries[sample_qid]
print(f"Query: {sample_query}")

bm25_doc_ids, bm25_scores = bm25_retrieve(sample_query, bm25, doc_ids, k=5)
print("\nBM25 retrieval results:")
for i, (doc_id, score) in enumerate(zip(bm25_doc_ids, bm25_scores)):
    print(f"Result {i+1}: {doc_id} (Score: {score:.4f})")
    print(f"Text: {corpus[doc_id][:100]}...")

# 验证MRR计算
print("\n--- Verifying MRR Calculation ---")
# 创建一个小型测试集
test_qid = sample_qid
test_query = sample_query
test_run = {}
test_run[test_qid] = {}

# 添加一些检索结果
test_results = hybrid_retrieve_documents(
    test_query, test_qid, top_k=10, strategy="sbert_bm25", alpha=0.0
)

print(f"Run format for MRR calculation: {type(test_results)}")
print(f"Example run entry: {list(test_results.items())[0]}")

# 计算MRR
if test_qid in qrels:
    test_mrr = compute_mrr_at_k({test_qid: test_results[test_qid]}, {test_qid: qrels[test_qid]}, k=10, verbose=True)
    print(f"Test MRR@10: {test_mrr:.4f}")
else:
    print(f"Query {test_qid} has no relevance judgments")

    # 找一个有相关文档的查询
    for alt_qid in qrels:
        if qrels[alt_qid]:
            print(f"Using alternative query {alt_qid} with {len(qrels[alt_qid])} relevant docs")
            alt_query = queries[alt_qid]
            alt_results = hybrid_retrieve_documents(
                alt_query, alt_qid, top_k=10, strategy="sbert_bm25", alpha=0.0
            )
            alt_mrr = compute_mrr_at_k(alt_results, {alt_qid: qrels[alt_qid]}, k=10, verbose=True)
            print(f"Alternative test MRR@10: {alt_mrr:.4f}")
            break

Testing a sample query...
Sample query ID: 17999
Sample query: does starving yourself cause your belly to bloat
Running sbert_bm25 strategy with alpha=0.00
SBERT found 10 results
BM25 found 10 results
Total unique candidates: 17

Retrieved documents using sbert_bm25 strategy (alpha=0.0):

Rank 1 - Doc ID: 17999_6 - Score: 1.0000 - ✗ NOT RELEVANT
Preview: No, you’re not imagining things—you have the dreaded belly bloat! And even if you’re going for daily runs and eating right, many factors in your day can contribute to sudden and unexpected belly bloat...

Rank 2 - Doc ID: 17999_8 - Score: 0.9144 - ✗ NOT RELEVANT
Preview: Bloat from carbonated water. For one to three hours after drinking carbonated water, you may feel as though your belly has expanded. The carbonation can make your stomach look distended and cause clot...

Rank 3 - Doc ID: 17999_1 - Score: 0.8984 - ✗ NOT RELEVANT
Preview: The starvation belly stands out in painful relief against emaciated arms, legs, and face, and will 

In [61]:
# 评估不同检索策略（小样本）
print("Evaluating retrieval strategies on a small sample...")
evaluation_results = evaluate_strategies(
    max_queries=5,  # 使用较小的样本量便于快速运行
    k_values=[10, 50, 100],
    strategies=["single", "sbert_bm25"]  # 只评估两种最重要的策略
)

# 显示结果摘要
print("\nResults summary:")
for strategy, results in evaluation_results.items():
    print(f"\nStrategy: {strategy}")
    for k, metrics in results.items():
        print(f"  k={k}: MRR={metrics['mrr']:.4f}, Recall={metrics['recall']:.4f}" +
              (f", NDCG={metrics['ndcg']:.4f}" if 'ndcg' in metrics and metrics['ndcg'] > 0 else ""))

Evaluating retrieval strategies on a small sample...
Evaluating 5 queries across 2 strategies
Queries with relevance judgments: 5/5

Evaluating strategy: single


Processing queries with single: 100%|██████████| 5/5 [00:00<00:00, 78.31it/s]


  k=10: MRR=0.0000, Recall=0.0000, NDCG=0.1588
  k=50: MRR=0.0000, Recall=0.0000, NDCG=0.1588
  k=100: MRR=0.0107, Recall=1.0000, NDCG=0.1588

Evaluating strategy: sbert_bm25


Processing queries with sbert_bm25: 100%|██████████| 5/5 [00:00<00:00, 14.54it/s]

  k=10: MRR=0.0000, Recall=0.0000, NDCG=0.0000
  k=50: MRR=0.0000, Recall=0.0000, NDCG=0.0000
  k=100: MRR=0.0000, Recall=0.0000, NDCG=0.0000

Summary of Results:

single:
  k=10: mrr=0.0000, recall=0.0000, ndcg=0.1588
  k=50: mrr=0.0000, recall=0.0000, ndcg=0.1588
  k=100: mrr=0.0107, recall=1.0000, ndcg=0.1588

sbert_bm25:
  k=10: mrr=0.0000, recall=0.0000, ndcg=0.0000
  k=50: mrr=0.0000, recall=0.0000, ndcg=0.0000
  k=100: mrr=0.0000, recall=0.0000, ndcg=0.0000

Results summary:

Strategy: single
  k=10: MRR=0.0000, Recall=0.0000, NDCG=0.1588
  k=50: MRR=0.0000, Recall=0.0000, NDCG=0.1588
  k=100: MRR=0.0107, Recall=1.0000, NDCG=0.1588

Strategy: sbert_bm25
  k=10: MRR=0.0000, Recall=0.0000
  k=50: MRR=0.0000, Recall=0.0000
  k=100: MRR=0.0000, Recall=0.0000





In [62]:
# 测试alpha参数调优（使用非常小的样本以便快速运行）
print("Testing alpha parameter tuning on a very small sample...")
mini_alpha_results = tune_alpha_parameter(sample_size=3)

Testing alpha parameter tuning on a very small sample...
Testing sbert_bm25 strategy with different alpha values...

Alpha = 0.0:


Processing queries with alpha=0.0: 100%|██████████| 3/3 [00:00<00:00, 11.96it/s]


  MRR@100=0.1944, Recall@100=1.0000

Alpha = 0.2:


Processing queries with alpha=0.2: 100%|██████████| 3/3 [00:00<00:00, 11.68it/s]


  MRR@100=0.1784, Recall@100=1.0000

Alpha = 0.4:


Processing queries with alpha=0.4: 100%|██████████| 3/3 [00:00<00:00, 13.02it/s]


  MRR@100=0.0998, Recall@100=1.0000

Alpha = 0.5:


Processing queries with alpha=0.5: 100%|██████████| 3/3 [00:00<00:00, 11.23it/s]


  MRR@100=0.0078, Recall@100=0.3333

Alpha = 0.6:


Processing queries with alpha=0.6: 100%|██████████| 3/3 [00:00<00:00, 13.89it/s]


  MRR@100=0.0000, Recall@100=0.0000

Alpha = 0.8:


Processing queries with alpha=0.8: 100%|██████████| 3/3 [00:00<00:00, 13.97it/s]


  MRR@100=0.0000, Recall@100=0.0000

Alpha = 1.0:


Processing queries with alpha=1.0: 100%|██████████| 3/3 [00:00<00:00, 13.93it/s]

  MRR@100=0.0000, Recall@100=0.0000

Summary of Alpha Results for sbert_bm25 strategy:
Alpha	MRR@100	Recall@100
0.0	0.1944	1.0000
0.2	0.1784	1.0000
0.4	0.0998	1.0000
0.5	0.0078	0.3333
0.6	0.0000	0.0000
0.8	0.0000	0.0000
1.0	0.0000	0.0000

Best alpha value for MRR: 0.0 (MRR@100=0.1944)
Best alpha value for Recall: 0.0 (Recall@100=1.0000)



