In [10]:
import h5py
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict
from tqdm import tqdm

@dataclass
class QuerySet:
    qid_to_query: Dict[str, str]
    qid_to_embedding: Dict[str, np.ndarray]
    qid_to_relevance: Dict[str, Dict[str, int]]

def load_embeddings(file_path: str, id_key='id', embedding_key='embedding'):
    """Load embeddings from H5 file."""
    with h5py.File(file_path, 'r') as f:
        # Print available keys in file for debugging
        print(f"Available keys in {file_path}: {list(f.keys())}")
        ids = np.array(f[id_key]).astype(str)
        embeddings = np.array(f[embedding_key]).astype(np.float32)
    print(f"Loaded {len(ids)} embeddings from {file_path}")
    return ids, embeddings

class MSMARCOData:
    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.passages = None
        self.passage_embeddings = None
        self.eval_queries = None
        self.dev_queries = None
    
    def load_all(self):
        print("Loading MSMARCO data...")
        with tqdm(total=4, desc="Loading components") as pbar:
            self.load_passages()
            pbar.update(1)
            self.load_passage_embeddings()
            pbar.update(1)
            self.load_eval_queries()
            pbar.update(1)
            self.load_dev_queries()
            pbar.update(1)
    
    def load_passages(self):
        self.passages = pd.read_csv(
            f"{self.data_dir}/data.tsv",
            sep='\t', 
            names=['docid', 'text'],
            dtype={'docid': str}
        )
    
    def load_passage_embeddings(self):
        try:
            ids, embeddings = load_embeddings(
                f"{self.data_dir}/embeddings.h5",
                'id',  # Changed from 'ids' to 'id'
                'embedding'  # Changed from 'embeddings' to 'embedding'
            )
            self.passage_embeddings = dict(zip(ids, embeddings))
        except Exception as e:
            print(f"Error loading passage embeddings: {e}")
            raise
    
    def load_query_embeddings(self) -> Dict[str, np.ndarray]:
        ids, embeddings = load_embeddings(
            f"{self.data_dir}/queries_dev_eval_embeddings.h5",
            'id',  # Changed from 'ids' to 'id'
            'embedding'  # Changed from 'embeddings' to 'embedding'
        )
        return dict(zip(ids, embeddings))
    
    def load_eval_queries(self):
        # Load queries and relevance judgments
        queries = pd.read_csv(
            f"{self.data_dir}/queries.eval.tsv",
            sep='\t', 
            names=['qid', 'query'],
            dtype={'qid': str}
        )
        
        qrels = pd.concat([
            pd.read_csv(f"{self.data_dir}/qrels.eval.one.tsv", 
                       sep='\t', names=['qid', '_', 'docid', 'relevance'],
                       dtype={'qid': str, 'docid': str}),
            pd.read_csv(f"{self.data_dir}/qrels.eval.two.tsv", 
                       sep='\t', names=['qid', '_', 'docid', 'relevance'],
                       dtype={'qid': str, 'docid': str})
        ])
        
        query_embeddings = self.load_query_embeddings()
        
        self.eval_queries = QuerySet(
            qid_to_query={row.qid: row.query for _, row in queries.iterrows()},
            qid_to_embedding=query_embeddings,
            qid_to_relevance={
                qid: {row.docid: row.relevance for _, row in group.iterrows()}
                for qid, group in qrels.groupby('qid')
            }
        )
    
    def load_dev_queries(self):
        queries = pd.read_csv(
            f"{self.data_dir}/queries.dev.tsv",
            sep='\t', 
            names=['qid', 'query'],
            dtype={'qid': str}
        )
        
        qrels = pd.read_csv(
            f"{self.data_dir}/qrels.dev.tsv",
            sep='\t', 
            names=['qid', '_', 'docid', 'relevance'],
            dtype={'qid': str, 'docid': str}
        )
        
        query_embeddings = self.load_query_embeddings()
        
        self.dev_queries = QuerySet(
            qid_to_query={row.qid: row.query for _, row in queries.iterrows()},
            qid_to_embedding=query_embeddings,
            qid_to_relevance={
                qid: {row.docid: row.relevance for _, row in group.iterrows()}
                for qid, group in qrels.groupby('qid')
            }
        )
    
    def get_passage(self, docid: str) -> str:
        try:
            return self.passages[self.passages['docid'] == docid]['text'].iloc[0]
        except IndexError:
            raise KeyError(f"Passage not found for docid: {docid}")
    
    def get_passage_embedding(self, docid: str) -> np.ndarray:
        try:
            return self.passage_embeddings[docid]
        except KeyError:
            raise KeyError(f"Embedding not found for docid: {docid}")

In [11]:
data = MSMARCOData("/Users/ad12/Documents/Develop/wse-hw-2/data")
data.load_all()

Loading MSMARCO data...


Loading components:  25%|██▌       | 1/4 [00:11<00:33, 11.05s/it]

Available keys in /Users/ad12/Documents/Develop/wse-hw-2/data/embeddings.h5: ['embedding', 'id']


Loading components:  50%|█████     | 2/4 [00:17<00:16,  8.47s/it]

Loaded 1000000 embeddings from /Users/ad12/Documents/Develop/wse-hw-2/data/embeddings.h5
Available keys in /Users/ad12/Documents/Develop/wse-hw-2/data/queries_dev_eval_embeddings.h5: ['embedding', 'id']
Loaded 202185 embeddings from /Users/ad12/Documents/Develop/wse-hw-2/data/queries_dev_eval_embeddings.h5


Loading components:  75%|███████▌  | 3/4 [00:24<00:07,  7.75s/it]

Available keys in /Users/ad12/Documents/Develop/wse-hw-2/data/queries_dev_eval_embeddings.h5: ['embedding', 'id']
Loaded 202185 embeddings from /Users/ad12/Documents/Develop/wse-hw-2/data/queries_dev_eval_embeddings.h5


Loading components: 100%|██████████| 4/4 [00:30<00:00,  7.53s/it]


In [None]:
from rank_bm25 import BM25Okapi
import numpy as np
from typing import List, Dict, Tuple
from dataclasses import dataclass
import nltk
from nltk.tokenize import word_tokenize
import faiss

# Download required NLTK data
nltk.download('punkt')

@dataclass
class SearchResult:
    docid: str
    score: float
    rank: int

class PurePythonSearchSystem:
    def __init__(self, msmarco_data: MSMARCOData):
        self.data = msmarco_data
        self.bm25 = None
        self.docids = None
        self.hnsw_index = None
        self.tokenized_corpus = None
        
    def build_bm25_index(self):
        """Build BM25 index using pure Python implementation"""
        print("Building BM25 index...")
        
        # Store docids in order
        self.docids = self.data.passages['docid'].values
        
        # Tokenize corpus
        self.tokenized_corpus = [
            word_tokenize(text.lower())
            for text in self.data.passages['text']
        ]
        
        # Create BM25 index
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        print("BM25 index built successfully")
        
    def build_hnsw_index(self, m: int = 8, ef_construction: int = 100):
        """Build HNSW index for vector search"""
        print("Building HNSW index...")
        
        # Get embeddings in same order as passages
        embeddings = np.array([
            self.data.passage_embeddings[did] 
            for did in self.docids
        ])
        
        # Normalize embeddings
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        
        # Initialize and build HNSW index
        dimension = embeddings.shape[1]
        self.hnsw_index = faiss.IndexHNSWFlat(dimension, m)
        self.hnsw_index.hnsw.efConstruction = ef_construction
        self.hnsw_index.hnsw.efSearch = ef_construction
        
        self.hnsw_index.add(embeddings)
        print("HNSW index built successfully")
        
    def bm25_search(self, query: str, k: int = 100) -> List[SearchResult]:
        """Search using BM25"""
        if self.bm25 is None:
            raise RuntimeError("BM25 index not built. Call build_bm25_index() first.")
            
        # Tokenize query
        tokenized_query = word_tokenize(query.lower())
        
        # Get scores
        scores = self.bm25.get_scores(tokenized_query)
        
        # Get top k indices
        top_indices = np.argsort(-scores)[:k]
        
        return [
            SearchResult(
                docid=self.docids[idx],
                score=float(scores[idx]),
                rank=rank + 1
            )
            for rank, idx in enumerate(top_indices)
        ]
        
    def vector_search(self, query_embedding: np.ndarray, k: int = 100) -> List[SearchResult]:
        """Search using HNSW index"""
        if self.hnsw_index is None:
            raise RuntimeError("HNSW index not built. Call build_hnsw_index() first.")
            
        # Normalize query embedding
        query_embedding = query_embedding / np.linalg.norm(query_embedding)
        
        # Search
        distances, indices = self.hnsw_index.search(
            query_embedding.reshape(1, -1), k=k
        )
        
        return [
            SearchResult(
                docid=self.docids[int(idx)],
                score=float(1 - dist),
                rank=rank + 1
            )
            for rank, (idx, dist) in enumerate(zip(indices[0], distances[0]))
        ]
        
    def hybrid_search(self, query: str, query_embedding: np.ndarray, 
                     k: int = 100, alpha: float = 0.5) -> List[SearchResult]:
        """Hybrid search combining BM25 and dense retrieval"""
        # Get BM25 results
        bm25_results = self.bm25_search(query, k=1000)
        
        # Get embeddings for candidates
        candidate_embeddings = np.array([
            self.data.get_passage_embedding(res.docid) 
            for res in bm25_results
        ])
        
        # Normalize embeddings
        candidate_embeddings = candidate_embeddings / np.linalg.norm(
            candidate_embeddings, axis=1, keepdims=True
        )
        query_embedding = query_embedding / np.linalg.norm(query_embedding)
        
        # Compute dense scores
        dense_scores = np.dot(candidate_embeddings, query_embedding)
        
        # Normalize BM25 scores
        bm25_scores = np.array([res.score for res in bm25_results])
        bm25_scores = (bm25_scores - bm25_scores.min()) / (
            bm25_scores.max() - bm25_scores.min()
        )
        
        # Combine scores
        combined_scores = alpha * bm25_scores + (1 - alpha) * dense_scores
        
        # Sort and return top k
        top_indices = np.argsort(-combined_scores)[:k]
        return [
            SearchResult(
                docid=bm25_results[idx].docid,
                score=float(combined_scores[idx]),
                rank=rank + 1
            )
            for rank, idx in enumerate(top_indices)
        ]

ModuleNotFoundError: No module named 'pyserini'

In [None]:
search_system = PurePythonSearchSystem(data)
search_system.build_bm25_index()
search_system.build_hnsw_index()