# Sparse Retrieval on Halubench (BM25 and SPLADE)

In [1]:
import os, json, math, subprocess
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from scipy.sparse import csr_matrix, save_npz, load_npz

# -----------------------
# Small helpers
# -----------------------

def ensure_dir(p):
    os.makedirs(p, exist_ok=True)

def dedupe_passages(df, text_col):
    """Return deduped dataframe of passages + mapping from original row -> dedup id (str)."""
    x = df.copy()
    x[text_col] = x[text_col].fillna("").astype(str)
    dedup = x[[text_col]].drop_duplicates().reset_index(drop=True)
    text_to_id = {dedup.loc[i, text_col]: str(i) for i in range(len(dedup))}
    rel_map = {}
    for i in range(len(df)):
        t = "" if pd.isna(df.loc[i, text_col]) else str(df.loc[i, text_col])
        rel_map[str(i)] = text_to_id.get(t, text_to_id.get("", "0"))
    return dedup, rel_map

def ap_at_k(ret_ids, rel_id, k):
    for rank, did in enumerate(ret_ids[:k], 1):
        if did == rel_id:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(ret_ids, rel_id, k):
    for rank, did in enumerate(ret_ids[:k], 1):
        if did == rel_id:
            return 1.0 / math.log2(rank + 1)
    return 0.0

# -----------------------
# BM25 (Pyserini/Lucene)
# -----------------------

class BM25Retriever:
    def __init__(self, index_dir="./indices/bm25", threads=8, java_mem="8g"):
        self.index_dir = index_dir
        self.threads = threads
        self.java_mem = java_mem

    def build_index(self, dedup_df, text_col, work_dir="./work/bm25"):
        ensure_dir(self.index_dir); ensure_dir(work_dir)
        corpus_dir = os.path.join(work_dir, "json_corpus"); ensure_dir(corpus_dir)
        docs_path = os.path.join(corpus_dir, "docs.jsonl")
        with open(docs_path, "w", encoding="utf-8") as f:
            for i in range(len(dedup_df)):
                text = "" if pd.isna(dedup_df.loc[i, text_col]) else str(dedup_df.loc[i, text_col])
                f.write(json.dumps({"id": str(i), "contents": text}, ensure_ascii=False) + "\n")
        cmd = [
            "python","-m","pyserini.index.lucene",
            "--collection","JsonCollection",
            "--input", corpus_dir,
            "--index", self.index_dir,
            "--generator","DefaultLuceneDocumentGenerator",
            "--threads", str(self.threads),
            "--storePositions","--storeDocvectors","--storeRaw"
        ]
        env = os.environ.copy()
        env["JAVA_TOOL_OPTIONS"] = f"-Xms{self.java_mem} -Xmx{self.java_mem}"
        res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env)
        if res.returncode != 0:
            print(res.stdout); print(res.stderr)
            raise RuntimeError("BM25 indexing failed (check Java 11+ and pyserini).")

    def retrieve(self, queries, topk=10):
        from pyserini.search.lucene import LuceneSearcher
        searcher = LuceneSearcher(self.index_dir)
        results = {}
        for qid, q in tqdm(queries.items(), desc="BM25 search"):
            hits = searcher.search(q, k=topk)
            results[qid] = [(h.docid, float(h.score)) for h in hits]
        return results

# -----------------------
# SPLADE (Transformers + CSR)
# -----------------------

class SPLADERetriever:
    def __init__(self, index_dir="./indices/splade", model_name="naver/splade-cocondenser-ensembledistil",
                 batch_size=8, max_length=256, min_weight=0.01):
        self.index_dir = index_dir
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.min_weight = min_weight
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = None
        self.model = None

    def _load(self):
        if self.tok is None:
            self.tok = AutoTokenizer.from_pretrained(self.model_name)
        if self.model is None:
            self.model = AutoModelForMaskedLM.from_pretrained(self.model_name).to(self.device).eval()

    @torch.no_grad()
    def _encode_texts(self, texts):
        self._load()
        V = self.model.config.vocab_size
        data, indices, indptr = [], [], [0]
        for i in tqdm(range(0, len(texts), self.batch_size), desc="SPLADE encode"):
            batch = texts[i:i+self.batch_size]
            toks = self.tok(batch, return_tensors="pt", padding=True, truncation=True,
                            max_length=self.max_length).to(self.device)
            logits = self.model(**toks).logits  # [B,L,V]
            activ = torch.log1p(torch.relu(logits))
            weights = activ.max(dim=1).values.cpu().numpy()  # [B,V]
            for row in weights:
                nz = np.where(row >= self.min_weight)[0]
                indices.extend(nz.tolist())
                data.extend(row[nz].astype(np.float32).tolist())
                indptr.append(len(indices))
        return csr_matrix((np.array(data, np.float32),
                           np.array(indices, np.int32),
                           np.array(indptr, np.int32)),
                          shape=(len(texts), V), dtype=np.float32)

    def build_index(self, dedup_df, text_col):
        ensure_dir(self.index_dir)
        texts = ["" if pd.isna(dedup_df.loc[i, text_col]) else str(dedup_df.loc[i, text_col])
                 for i in range(len(dedup_df))]
        mat = self._encode_texts(texts)
        save_npz(os.path.join(self.index_dir, "docs.npz"), mat)
        with open(os.path.join(self.index_dir, "doc_ids.json"), "w") as f:
            json.dump([str(i) for i in range(len(texts))], f)

    @torch.no_grad()
    def _encode_query(self, q):
        self._load()
        toks = self.tok([q], return_tensors="pt", padding=True, truncation=True,
                        max_length=self.max_length).to(self.device)
        logits = self.model(**toks).logits
        activ = torch.log1p(torch.relu(logits))
        w = activ.max(dim=1).values.squeeze(0).cpu().numpy()
        nz = np.where(w >= self.min_weight)[0]
        return csr_matrix((w[nz].astype(np.float32), nz.astype(np.int32), np.array([0,len(nz)], np.int32)),
                          shape=(1, self.model.config.vocab_size), dtype=np.float32)

    def retrieve(self, queries, topk=10):
        docs = load_npz(os.path.join(self.index_dir, "docs.npz"))
        with open(os.path.join(self.index_dir, "doc_ids.json"), "r") as f:
            doc_ids = json.load(f)
        out = {}
        for qid, q in tqdm(queries.items(), desc="SPLADE search"):
            qv = self._encode_query(q)  # [1,V]
            scores = (docs @ qv.T).toarray().ravel()
            if topk >= len(doc_ids):
                idx = np.argsort(-scores)
            else:
                idx = np.argpartition(scores, -topk)[-topk:]
                idx = idx[np.argsort(-scores[idx])]
            out[qid] = [(doc_ids[i], float(scores[i])) for i in idx[:topk]]
        return out

# -----------------------
# Pipeline
# -----------------------

def run_pipeline(df,
                 output_dir="./outputs",
                 bm25_index_dir="./indices/bm25",
                 splade_index_dir="./indices/splade",
                 work_dir="./work"):
    # Basic column handling
    if "question" not in df.columns:
        raise ValueError("Missing 'question' column.")
    text_col = "passage" if "passage" in df.columns else ("context" if "context" in df.columns else None)
    if text_col is None:
        raise ValueError("Missing 'passage' or 'context' column.")
    if "answer" not in df.columns:
        df["answer"] = ""
    df = df.reset_index(drop=True)

    # Deduplicate passages
    dedup_df, rel_map = dedupe_passages(df, text_col)
    corpus_lookup = {str(i): ("" if pd.isna(dedup_df.loc[i, text_col]) else str(dedup_df.loc[i, text_col]))
                     for i in range(len(dedup_df))}
    queries = {str(i): str(df.loc[i, "question"]) for i in range(len(df))}

    ensure_dir(output_dir); ensure_dir(os.path.dirname(bm25_index_dir)); ensure_dir(os.path.dirname(splade_index_dir)); ensure_dir(work_dir)

    # BM25
    bm25 = BM25Retriever(index_dir=bm25_index_dir)
    bm25.build_index(dedup_df, text_col, work_dir=os.path.join(work_dir, "bm25"))
    bm25_res = bm25.retrieve(queries, topk=10)
    save_results("bm25", df, text_col, bm25_res, rel_map, corpus_lookup, os.path.join(output_dir, "bm25_results.csv"))

    # SPLADE
    splade = SPLADERetriever(index_dir=splade_index_dir)
    splade.build_index(dedup_df, text_col)
    splade_res = splade.retrieve(queries, topk=10)
    save_results("splade", df, text_col, splade_res, rel_map, corpus_lookup, os.path.join(output_dir, "splade_results.csv"))

# def save_results(name, df, text_col, ret, rel_map, corpus_lookup, out_csv):
#     rows = []
#     ks = [3,5,10]
#     for i in range(len(df)):
#         qid = str(i)
#         retrieved = ret.get(qid, [])
#         ret_ids = [d for d,_ in retrieved]
#         rel = rel_map[qid]
#         metrics = {}
#         for k in ks:
#             metrics[f"MAP@{k}"] = ap_at_k(ret_ids, rel, k)
#             metrics[f"NDCG@{k}"] = ndcg_at_k(ret_ids, rel, k)
#         pack = [{"doc_id": d, "score": round(float(s),4), "snippet": corpus_lookup.get(d,"")[:200].replace("\n"," ")}
#                 for d,s in retrieved]
#         rows.append({
#             "question": str(df.loc[i, "question"]),
#             "answer": "" if "answer" not in df.columns or pd.isna(df.loc[i, "answer"]) else str(df.loc[i, "answer"]),
#             text_col: "" if pd.isna(df.loc[i, text_col]) else str(df.loc[i, text_col]),
#             f"{name}_ret_docs": json.dumps(pack, ensure_ascii=False),
#             **metrics
#         })
#     pd.DataFrame(rows).to_csv(out_csv, index=False)
#     print(f"{name} results saved to {out_csv}")

def save_results(name, df, text_col, ret, rel_map, corpus_lookup, out_csv):
    rows = []
    ks = [3, 5, 10]
    for i in range(len(df)):
        qid = str(i)
        retrieved = ret.get(qid, [])
        ret_ids = [d for d, _ in retrieved]
        rel = rel_map[qid]

        metrics = {}
        for k in ks:
            metrics[f"MAP@{k}"] = ap_at_k(ret_ids, rel, k)
            metrics[f"NDCG@{k}"] = ndcg_at_k(ret_ids, rel, k)

        # Store complete doc text + short preview for readability
        packed = []
        for d, s in retrieved:
            full = corpus_lookup.get(d, "")
            preview = full[:200].replace("\n", " ")
            packed.append({
                "doc_id": d,
                "score": round(float(s), 4),
                # keep snippet for backward compatibility (optional)
                "snippet": preview,
                # new field with the full text (no truncation)
                "full_text": full,
                # explicit preview field (optional)
                "preview_snippet": preview
            })

        rows.append({
            "question": "" if pd.isna(df.loc[i, "question"]) else str(df.loc[i, "question"]),
            "answer": "" if "answer" not in df.columns or pd.isna(df.loc[i, "answer"]) else str(df.loc[i, "answer"]),
            text_col: "" if pd.isna(df.loc[i, text_col]) else str(df.loc[i, text_col]),
            f"{name}_ret_docs": json.dumps(packed, ensure_ascii=False),
            **metrics
        })

    pd.DataFrame(rows).to_csv(out_csv, index=False)
    print(f"{name} results saved to {out_csv}")

# Example usage
if __name__ == "__main__":
    
    df = pd.read_parquet("hf://datasets/PatronusAI/HaluBench/data/test-00000-of-00001.parquet")
    run_pipeline(df)

  from .autonotebook import tqdm as notebook_tqdm
Oct 16, 2025 5:57:33 PM org.apache.lucene.store.MemorySegmentIndexInputProvider <init>
INFO: Using MemorySegmentIndexInput with Java 21; to disable start with -Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false
BM25 search: 100%|██████████| 14900/14900 [00:12<00:00, 1226.23it/s]


bm25 results saved to ./outputs/bm25_results.csv


BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
SPLADE encode: 100%|██████████| 1553/1553 [00:38<00:00, 40.21it/s]
SPLADE search: 100%|██████████| 14900/14900 [03:20<00:00, 74.45it/s] 


splade results saved to ./outputs/splade_results.csv


In [2]:
for name, path in [("BM25","./outputs/bm25_results.csv"), ("SPLADE","./outputs/splade_results.csv")]:
    df = pd.read_csv(path)
    print(name, "Results:")
    for k in (3,5,10):
        print(f"  MAP@{k}: {pd.to_numeric(df[f'MAP@{k}'], errors='coerce').mean():.4f}, "
              f"NDCG@{k}: {pd.to_numeric(df[f'NDCG@{k}'], errors='coerce').mean():.4f}")
    print()

BM25 Results:
  MAP@3: 0.8234, NDCG@3: 0.8319
  MAP@5: 0.8283, NDCG@5: 0.8407
  MAP@10: 0.8319, NDCG@10: 0.8495

SPLADE Results:
  MAP@3: 0.8305, NDCG@3: 0.8379
  MAP@5: 0.8346, NDCG@5: 0.8452
  MAP@10: 0.8381, NDCG@10: 0.8537



## Testing Only BM25- Halubench

In [2]:
import os
import json
import math
import subprocess
from typing import Dict, List, Tuple

import pandas as pd
from tqdm import tqdm
from pyserini.search.lucene import LuceneSearcher   # <-- updated import


def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)

def dedupe_passages(df: pd.DataFrame, text_col: str) -> Tuple[pd.DataFrame, Dict[str, str]]:
    dedup = df[[text_col]].fillna("").drop_duplicates().reset_index(drop=True)
    text_to_docid = {dedup.loc[i, text_col]: str(i) for i in range(len(dedup))}
    relevant = {}
    for i, row in df.iterrows():
        passage = "" if pd.isna(row[text_col]) else str(row[text_col])
        relevant[str(i)] = text_to_docid[passage]
    return dedup, relevant

def ap_at_k(retrieved_ids: List[str], relevant_id: str, k: int) -> float:
    for rank, doc_id in enumerate(retrieved_ids[:k], start=1):
        if doc_id == relevant_id:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(retrieved_ids: List[str], relevant_id: str, k: int) -> float:
    for rank, doc_id in enumerate(retrieved_ids[:k], start=1):
        if doc_id == relevant_id:
            return 1.0 / math.log2(rank + 1)
    return 0.0

def run_bm25_pipeline(
    df: pd.DataFrame,
    text_col: str = "passage",
    question_col: str = "question",
    answer_col: str = "answer",
    index_dir: str = "./bm25_index",
    output_csv: str = "./bm25_results.csv",
    tmp_dir: str = "./bm25_tmp",
    topk: int = 10,
):
    if question_col not in df.columns:
        raise ValueError(f"Missing '{question_col}' column")
    if text_col not in df.columns:
        raise ValueError(f"Missing '{text_col}' column")
    if answer_col not in df.columns:
        df[answer_col] = ""

    df = df.reset_index(drop=True)

    corpus_df, relevant = dedupe_passages(df, text_col)
    doc_lookup = {str(i): corpus_df.loc[i, text_col] for i in range(len(corpus_df))}

    ensure_dir(tmp_dir)
    corpus_path = os.path.join(tmp_dir, "docs.jsonl")
    with open(corpus_path, "w", encoding="utf-8") as f:
        for i, row in corpus_df.iterrows():
            f.write(json.dumps({"id": str(i), "contents": row[text_col]}, ensure_ascii=False) + "\n")

    ensure_dir(index_dir)
    cmd = [
        "python", "-m", "pyserini.index.lucene",
        "--collection", "JsonCollection",
        "--input", tmp_dir,
        "--index", index_dir,
        "--generator", "DefaultLuceneDocumentGenerator",
        "--threads", "4",
        "--storePositions", "--storeDocvectors", "--storeRaw"
    ]
    res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if res.returncode != 0:
        print(res.stdout)
        print(res.stderr)
        raise RuntimeError("Indexing failed (check Java + Pyserini installation).")

    searcher = LuceneSearcher(index_dir)   # <-- updated class
    searcher.set_bm25()

    queries = {str(i): str(df.loc[i, question_col]) for i in range(len(df))}
    all_hits = {}
    for qid, query in tqdm(queries.items(), desc="BM25 search"):
        hits = searcher.search(query, k=topk)
        all_hits[qid] = [(hit.docid, float(hit.score)) for hit in hits]

    rows = []
    for i, row in df.iterrows():
        qid = str(i)
        gold_doc = relevant[qid]
        ret = all_hits.get(qid, [])
        ret_ids = [doc_id for doc_id, _ in ret]

        metrics = {}
        for k in (3, 5, 10):
            metrics[f"MAP@{k}"] = ap_at_k(ret_ids, gold_doc, k)
            metrics[f"NDCG@{k}"] = ndcg_at_k(ret_ids, gold_doc, k)

        rows.append({
            question_col: str(row[question_col]),
            answer_col: "" if pd.isna(row[answer_col]) else str(row[answer_col]),
            text_col: "" if pd.isna(row[text_col]) else str(row[text_col]),
            "bm25_retrieved": json.dumps(
                [
                    {
                        "doc_id": doc_id,
                        "score": round(score, 4),
                        "passage": doc_lookup.get(doc_id, "")[:200].replace("\n", " ")
                    }
                    for doc_id, score in ret
                ],
                ensure_ascii=False
            ),
            **metrics
        })

    pd.DataFrame(rows).to_csv(output_csv, index=False)
    print(f"Done. Results saved to {output_csv}")

def print_bm25_overall(csv_path: str, ks=(3, 5, 10)):
    df = pd.read_csv(csv_path)

    print("BM25 Results:")
    for k in ks:
        map_mean = pd.to_numeric(df[f"MAP@{k}"], errors="coerce").mean()
        ndcg_mean = pd.to_numeric(df[f"NDCG@{k}"], errors="coerce").mean()
        print(f"  MAP@{k}: {map_mean:.4f}, NDCG@{k}: {ndcg_mean:.4f}")
    print()

if __name__ == "__main__":
    # Example:

    from datasets import load_dataset
    ds = load_dataset("PatronusAI/HaluBench")
    df = ds["test"].to_pandas()
    run_bm25_pipeline(df, text_col="passage")
    print_bm25_overall("./bm25_results.csv")
  
    pass

  from .autonotebook import tqdm as notebook_tqdm
Oct 13, 2025 8:52:00 AM org.apache.lucene.store.MemorySegmentIndexInputProvider <init>
INFO: Using MemorySegmentIndexInput with Java 21; to disable start with -Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false
BM25 search: 100%|██████████| 14900/14900 [00:10<00:00, 1417.17it/s]


Done. Results saved to ./bm25_results.csv


In [3]:
def print_bm25_overall(csv_path: str, ks=(3, 5, 10)):
    df = pd.read_csv(csv_path)

    print("BM25 Results:")
    for k in ks:
        map_mean = pd.to_numeric(df[f"MAP@{k}"], errors="coerce").mean()
        ndcg_mean = pd.to_numeric(df[f"NDCG@{k}"], errors="coerce").mean()
        print(f"  MAP@{k}: {map_mean:.4f}, NDCG@{k}: {ndcg_mean:.4f}")
    print()

if __name__ == "__main__":
    print_bm25_overall("./bm25_results.csv")

BM25 Results:
  MAP@3: 0.8234, NDCG@3: 0.8319
  MAP@5: 0.8283, NDCG@5: 0.8407
  MAP@10: 0.8319, NDCG@10: 0.8495



# Dense Retrieval on HaluBench (ST and Contriever)

In [1]:
'''
Treat each passage as a single document (no split). Deduplicate passages globally before indexing.
Use two dense models:
all-mpnet-base-v2 (sentence-level encoder)
facebook/contriever (unsupervised dense retriever)
Use FAISS to build vector indices.
Evaluate by comparing retrieved docs to the original passage of each question (after mapping to dedup doc IDs).
Save results to CSV (one per model) with the same pattern and report MAP/NDCG@3/5/10.
Design choices and reasoning:

Embedding models:
MPNet (sentence-transformers/all-mpnet-base-v2) is designed for sentence-level semantic similarity. We will load via AutoModel and use standard mean pooling over the last hidden states with attention mask, then L2-normalize.
Contriever (facebook/contriever) is explicitly built for dense retrieval. We’ll also use mean pooling + L2 normalization for compatibility and cosine-like search.
We are not using AutoModelForMaskedLM here because MLM heads are for token prediction, not sentence embeddings.
Text normalization:
We normalize passages and queries consistently: strip, collapse whitespace, lowercase. This helps dedup and matching.
Relevance is determined by exact match to the dedupbed passage string, which avoids fuzzy alignment errors.
FAISS index:
Use IndexFlatIP with L2-normalized embeddings, equivalent to cosine similarity search. It’s simplest and accurate for these encoders.
'''

''

In [3]:
import os
import json
import math
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel

import faiss


# -----------------------
# Config
# -----------------------

MPNET_NAME = "sentence-transformers/all-mpnet-base-v2"
CONTRIEVER_NAME = "facebook/contriever"
TOPK = 10


# -----------------------
# Utils
# -----------------------

def ensure_dir(p):
    os.makedirs(p, exist_ok=True)

def normalize_text(s: str) -> str:
    if s is None or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s)
    s = s.replace("\u00a0", " ")
    s = " ".join(s.split())  # collapse whitespace
    return s.strip().lower()

def build_passage_dedup(df: pd.DataFrame, passage_col: str = "passage"):
    """
    Build a deduplicated corpus of passages (each passage is one document).
    Returns:
      - corpus: list[str] (unique normalized passages)
      - docid_lookup: dict[int->str] mapping doc_id to passage text
      - rel_map: dict[row_idx_str -> doc_id_str] mapping each question to its relevant doc id
    """
    # Normalize passage strings
    passages = [normalize_text(p) for p in df[passage_col].tolist()]
    # Deduplicate with order preserved
    seen = {}
    corpus = []
    for p in passages:
        if p not in seen:
            seen[p] = len(corpus)
            corpus.append(p)
    # Build relevance map: each row maps to the doc_id of its (normalized) passage
    rel_map = {}
    for i, p in enumerate(passages):
        rel_map[str(i)] = str(seen[p])
    docid_lookup = {str(i): corpus[i] for i in range(len(corpus))}
    return corpus, docid_lookup, rel_map

def mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: [B, L, H], attention_mask: [B, L]
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B,L,1]
    sum_embeddings = (last_hidden_state * mask).sum(dim=1)          # [B,H]
    lengths = mask.sum(dim=1).clamp(min=1e-9)                       # [B,1]
    return sum_embeddings / lengths

@torch.no_grad()
def encode_texts(texts, tokenizer, model, device="cuda" if torch.cuda.is_available() else "cpu", batch_size=64, max_length=256):
    model.eval().to(device)
    all_vecs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Encoding"):
        batch = texts[i:i+batch_size]
        toks = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        toks = {k: v.to(device) for k, v in toks.items()}
        outputs = model(**toks)
        # Use last_hidden_state and mean pooling with attention mask
        emb = mean_pool(outputs.last_hidden_state, toks["attention_mask"])
        # L2 normalize
        emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        all_vecs.append(emb.cpu())
    return torch.cat(all_vecs, dim=0).numpy().astype("float32")  # [N, D]

def build_faiss_index(embs: np.ndarray, use_gpu=False):
    """
    Build a FAISS IndexFlatIP and add embeddings.
    """
    d = embs.shape[1]
    index = faiss.IndexFlatIP(d)
    # If GPU desired and available:
    if use_gpu and faiss.get_num_gpus() > 0:
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(embs)  # add vectors
    return index

def search_faiss(index, query_vecs: np.ndarray, topk: int = 10):
    # FAISS expects float32
    q = query_vecs.astype("float32")
    scores, idxs = index.search(q, topk)  # [Q, topk]
    return scores, idxs

def ap_at_k(ret_ids, rel_id, k):
    for rank, did in enumerate(ret_ids[:k], 1):
        if did == rel_id:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(ret_ids, rel_id, k):
    for rank, did in enumerate(ret_ids[:k], 1):
        if did == rel_id:
            return 1.0 / math.log2(rank + 1)
    return 0.0


# -----------------------
# Model wrappers
# -----------------------

class DenseRetriever:
    def __init__(self, model_name, batch_size=64, max_length=256, device=None):
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tok = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def encode_corpus(self, corpus_texts):
        return encode_texts(corpus_texts, self.tok, self.model, self.device, self.batch_size, self.max_length)

    def encode_queries(self, queries):
        return encode_texts(queries, self.tok, self.model, self.device, self.batch_size, self.max_length)


# -----------------------
# Pipeline
# -----------------------

def run_pipeline_halubench(
    parquet_path="hf://datasets/PatronusAI/HaluBench/data/test-00000-of-00001.parquet",
    output_dir="./outputs_dense",
    topk=TOPK,
    batch_size=64,
    use_gpu_faiss=False
):
    ensure_dir(output_dir)

    # 1) Load data
    df = pd.read_parquet(parquet_path)
    # Expect columns: id, question, answer, passage
    for col in ["id", "question", "answer", "passage"]:
        if col not in df.columns:
            raise ValueError(f"Missing required column '{col}' in dataframe")

    # 2) Normalize and deduplicate passages (each passage is one document)
    corpus, docid_lookup, rel_map = build_passage_dedup(df, passage_col="passage")

    # 3) Prepare queries (normalize questions for consistency)
    queries = {str(i): normalize_text(df.loc[i, "question"]) for i in range(len(df))}
    query_list = [queries[str(i)] for i in range(len(df))]

    # 4) Create two dense retrievers: MPNet and Contriever
    retrievers = {
        "mpnet": DenseRetriever(MPNET_NAME, batch_size=batch_size),
        "contriever": DenseRetriever(CONTRIEVER_NAME, batch_size=batch_size),
    }

    for name, retr in retrievers.items():
        print(f"\n=== Building FAISS index for {name} ===")
        # 5) Encode corpus and build FAISS
        corpus_embs = retr.encode_corpus(corpus)  # [N, D], L2-normalized
        index = build_faiss_index(corpus_embs, use_gpu=use_gpu_faiss)

        # 6) Encode queries and retrieve
        query_embs = retr.encode_queries(query_list)  # [Q, D], L2-normalized
        scores, idxs = search_faiss(index, query_embs, topk=topk)

        # 7) Build retrieval results in the same structure as before
        # Map FAISS ids to our doc_id strings
        out_res = {}
        for i in range(len(df)):
            ret = []
            for r in range(topk):
                did_int = int(idxs[i, r])
                if did_int < 0 or did_int >= len(corpus):
                    continue
                did = str(did_int)
                sc = float(scores[i, r])
                ret.append((did, sc))
            out_res[str(i)] = ret

        # 8) Save results and compute MAP/NDCG against the single relevant doc per question
        save_results_dense(
            name=name,
            df=df,
            ret=out_res,
            rel_map=rel_map,           # single relevant doc id per row (string)
            corpus_lookup=docid_lookup,
            out_csv=os.path.join(output_dir, f"{name}_results.csv")
        )

        print(f"{name}: results saved to {os.path.join(output_dir, f'{name}_results.csv')}")


def save_results_dense(name, df, ret, rel_map, corpus_lookup, out_csv):
    rows = []
    ks = [3, 5, 10]
    for i in range(len(df)):
        qid = str(i)
        retrieved = ret.get(qid, [])
        ret_ids = [d for d, _ in retrieved]
        rel = rel_map[qid]  # single relevant doc id for this question

        metrics = {}
        for k in ks:
            metrics[f"MAP@{k}"] = ap_at_k(ret_ids, rel, k)
            metrics[f"NDCG@{k}"] = ndcg_at_k(ret_ids, rel, k)

        # Store full passage text + short preview for readability
        pack = []
        for d, s in retrieved:
            full = corpus_lookup.get(d, "")
            preview = full[:300].replace("\n", " ")
            pack.append({
                "doc_id": d,
                "score": round(float(s), 6),
                # full content for robust evaluation
                "full_text": full,
                # backward-compatible alias (optional)
                "snippet": preview,
                # # explicit preview field (optional)
                # "passage": preview
            })

        rows.append({
            "id": df.loc[i, "id"],
            "question": normalize_text(df.loc[i, "question"]),
            "answer": normalize_text(df.loc[i, "answer"]),
            # Store the original (normalized) passage used as the single ground truth
            "passage": normalize_text(df.loc[i, "passage"]),
            f"{name}_ret_docs": json.dumps(pack, ensure_ascii=False),
            **metrics
        })

    pd.DataFrame(rows).to_csv(out_csv, index=False)


# -----------------------
# Run
# -----------------------

if __name__ == "__main__":
    # This loads HaluBench and runs both dense retrievers with FAISS
    run_pipeline_halubench(
        parquet_path="hf://datasets/PatronusAI/HaluBench/data/test-00000-of-00001.parquet",
        output_dir="./outputs_dense",
        topk=10,
        batch_size=64,
        use_gpu_faiss=False  # set True if you have faiss-gpu installed and a GPU available
    )

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



=== Building FAISS index for mpnet ===


Encoding: 100%|██████████| 195/195 [00:36<00:00,  5.39it/s]
Encoding: 100%|██████████| 233/233 [00:14<00:00, 15.60it/s]


mpnet: results saved to ./outputs_dense/mpnet_results.csv

=== Building FAISS index for contriever ===


Encoding: 100%|██████████| 195/195 [00:32<00:00,  6.08it/s]
Encoding: 100%|██████████| 233/233 [00:14<00:00, 16.11it/s]


contriever: results saved to ./outputs_dense/contriever_results.csv


In [4]:
import pandas as pd, os
for name in ["mpnet", "contriever"]:
    df = pd.read_csv(os.path.join("outputs_dense", f"{name}_results.csv"))
    print(name.capitalize(), "Results:")
    for k in (3,5,10):
        print(f"  MAP@{k}: {pd.to_numeric(df[f'MAP@{k}'], errors='coerce').mean():.4f}, "
              f"NDCG@{k}: {pd.to_numeric(df[f'NDCG@{k}'], errors='coerce').mean():.4f}")
    print()

Mpnet Results:
  MAP@3: 0.7672, NDCG@3: 0.7779
  MAP@5: 0.7740, NDCG@5: 0.7903
  MAP@10: 0.7793, NDCG@10: 0.8028

Contriever Results:
  MAP@3: 0.7477, NDCG@3: 0.7570
  MAP@5: 0.7524, NDCG@5: 0.7656
  MAP@10: 0.7562, NDCG@10: 0.7746

