In [1]:
import json
import os
from typing import Dict, List, Tuple

def load_triviaqa(dataset_path: str) -> Tuple[Dict[str, str], List[Dict]]:
    """
    Load TriviaQA dataset and return documents and QA pairs.
    
    Args:
        dataset_path: Path to the TriviaQA JSON file (e.g., 'unfiltered-dev.json')
        max_docs: Maximum number of documents to load
        max_qa_pairs: Maximum number of QA pairs to load
        
    Returns:
        Tuple of (documents dictionary, list of QA pairs)
    """
    print(f"Loading TriviaQA dataset from: {dataset_path}")
    
    # Load the TriviaQA dataset
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)['Data']
    
    # Extract documents and QA pairs
    docs = {}
    qa_pairs = []
    
    for i, item in enumerate(data):
        if len(docs)>100:
            break
        question = item['Question']
        answers = item['Answer']['NormalizedAliases']
        
        # Use the first answer as the gold answer
        gold_answer = answers if answers else list(item['Answer']['NormalizedValue'])
        
        qa_pairs.append({
            "question": question,
            "answer": gold_answer
        })
        
        # Extract documents from Wikipedia sources
        if 'EntityPages' in item:
            for j, doc_info in enumerate(item['EntityPages']):
                    
                doc_id = f"wiki_{doc_info['Title']}_{j}"
                
                # Check if file exists
                if 'Filename' in doc_info:
                    doc_path = os.path.join(os.path.dirname("data/wikipedia"), doc_info['Filename'])
                    if os.path.exists(doc_path):
                        try:
                            with open(doc_path, 'r', encoding='utf-8') as doc_file:
                                doc_content = doc_file.read()
                                docs[doc_id] = doc_content

                        except:
                            # If file can't be read, use the snippet
                            docs[doc_id] = doc_info.get('Snippet', '')
                    else:
                        # If file doesn't exist, use the snippet
                        docs[doc_id] = doc_info.get('Snippet', '')
                else:
                    # If no filename, use the snippet
                    docs[doc_id] = doc_info.get('Snippet', '')
                    
        # Add web documents if available
        if 'SearchResults' in item:
            for j, doc_info in enumerate(item['SearchResults']):
                    
                doc_id = f"web_{j}_{doc_info.get('Title', '')}"
                
                if 'Filename' in doc_info:
                    doc_path = os.path.join("data/web", doc_info['Filename'])
                    if os.path.exists(doc_path):
                        try:
                            with open(doc_path, 'r', encoding='utf-8') as doc_file:
                                doc_content = doc_file.read()
                                docs[doc_id] = doc_content
                        except:
                            docs[doc_id] = doc_info.get('Snippet', '')
                    else:
                        docs[doc_id] = doc_info.get('Snippet', '')
                else:
                    docs[doc_id] = doc_info.get('Snippet', '')
    
    print(f"Loaded {len(docs)} documents and {len(qa_pairs)} QA pairs from TriviaQA")
    return docs, qa_pairs

documents_dict, qa_pairs = load_triviaqa("triviaqa-unfiltered/unfiltered-web-dev.json")
print(documents_dict)

Loading TriviaQA dataset from: triviaqa-unfiltered/unfiltered-web-dev.json
Loaded 101 documents and 2 QA pairs from TriviaQA
{'web_0_The Man Behind the Mask (Chipmunk Version) - YouTube': '', 'web_1_Alvin & The Chipmunks - Behind The Voice Actors - Images ...': '', 'web_2_The Easter Chipmunk - Cast Images | Behind The Voice Actors': '', 'web_3_Alvin Seville - Alvin and the Chipmunks Wiki - Wikia': '', 'web_4_Alvin and the Chipmunks (2007) - IMDb': 'Alvin and the Chipmunks (2007) - IMDb\nIMDb\n17 January 2017 4:34 PM, UTC\nNEWS\nThere was an error trying to load your rating for this title.\nSome parts of this page won\'t work property. Please reload or try later.\nX Beta I\'m Watching This!\nKeep track of everything you watch; tell your friends.\nError\nAlvin and the Chipmunks\xa0( 2007 )\nPG |\nA struggling songwriter named Dave Seville finds success when he comes across a trio of singing chipmunks: mischievous leader Alvin, brainy Simon, and chubby, impressionable Theodore.\nDirector:

In [25]:
import time

def benchmark_retrieval(retriever, query, top_k):
    start = time.time()
    results = retriever.retrieve(query, top_k=top_k)
    end = time.time()
    latency = end - start
    return results, latency

def match_exists(retrieved_docs, answers):
    print()
    for i in range(len(answers)):
        found = any(any(ans.lower() in doc.lower() for ans in answers[i]) for doc in retrieved_docs[i][:])
        if found:
            return 1
    return 0

def benchmark_all(strategies, documents, qas, top_k=5):
    results = []
    overall_retriever_results = []


    for chunker_name, chunker in strategies["chunkers"].items():
        chunks = []

        #Chunking Latency
        chunk_start_time = time.time()
        for doc in documents:
            chunks.extend(chunker.chunk(doc))
        chunk_end_time = time.time()
        chunking_latency = chunk_end_time - chunk_start_time

        # Retriever Latency
        for retriever_name, retriever_class in strategies["retrievers"].items():
            retriever_initialize_start_time = time.time()
            retriever = retriever_class(chunks)
            retriever_initialize_end_time = time.time()
            retriever_initialize_latency = retriever_initialize_end_time - retriever_initialize_start_time
            total_retrieve_latency = 0

            for qa in qas:
                query, answers = qa["question"], qa["answer"]
                retrieved_docs, latency = benchmark_retrieval(retriever, query, top_k)
                does_match = match_exists(retrieved_docs, answers)
                overall_retriever_results.append({
                    "query": query,
                    "chunker": chunker_name,
                    "retriever": retriever_name,
                    "latency": latency,
                    "results": retrieved_docs,
                    "answers": answers,
                    "does_match": does_match
                })
                total_retrieve_latency += latency

            overall_retriever_results.append({
                "retriever": retriever_name,
                "chunker": chunker_name,
                "initialize_latency": retriever_initialize_latency,
                "retrieve_latency": total_retrieve_latency/len(qas),
                "chunking_latency": chunking_latency,
                "accuracy" : sum([1 for res in overall_retriever_results if res["does_match"]]) / len(qas),
            })


    return results, overall_retriever_results



In [None]:
from chunkers import FixedChunker, OverlappingChunker, SemanticChunker
from retrievers.bm25_retriever import BM25Retriever
from retrievers.dense_retriever import DPRRetriever
from retrievers.colbert_retriever import ColBERTRetriever
from retrievers.hybrid_retriever import HybridRetriever

documents = list(documents_dict.values())
top_k = 5

strategies = {
    "chunkers": {
        # "fixed": FixedChunker(chunk_size=20, drop_last=False),
        "overlapping": OverlappingChunker(chunk_size=24, overlap=8, drop_last=False),
        # "semantic": SemanticChunker(chunk_char_limit=120)
    },
    "retrievers": {
        # "bm25": BM25Retriever,
        "dpr": DPRRetriever,
        # "colbert": ColBERTRetriever,
        # "hybrid": HybridRetriever
    }
}
results = []
query, answers = "Who saved Barry's life", ["Vanessa"]
qa_pairs = []
qa_pairs.append({
            "question": query,
            "answer": answers
        })

results = benchmark_all(strategies, documents, qa_pairs, top_k)




# Print results
for res in results:
    print("Result", res)

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the





[{'question': 'Who was the man behind The Chipmunks?', 'answer': 'david seville'}]


In [1]:
# --- Example text ----------------------------------------------------------
doc = (
    "The sky appears blue because molecules in the air scatter blue light from "
    "the sun more than they scatter red light. This phenomenon is called Rayleigh "
    "scattering. When the sun is lower in the sky, the light has to pass through "
    "more atmosphere, so more blue and green light is scattered away, leaving the "
    "reds and oranges we see at sunrise and sunset."
)

# --- Import the chunkers ---------------------------------------------------
from chunkers import FixedChunker, OverlappingChunker, SemanticChunker

# ---------------------------------------------------------------------------
# 1. Fixed-length chunking: split every 20 tokens (drop tail if < 20 tokens)
fixed_chunker = FixedChunker(chunk_size=20, drop_last=True)
fixed_chunks = fixed_chunker.chunk(doc)

# 2. Overlapping windows: 24-token windows with 8-token overlap
overlap_chunker = OverlappingChunker(chunk_size=24, overlap=8, drop_last=False)
overlap_chunks = overlap_chunker.chunk(doc)

# 3. Semantic packing: keep whole sentences, ~120 characters per chunk
semantic_chunker = SemanticChunker(chunk_char_limit=120)
semantic_chunks = semantic_chunker.chunk(doc)

# --- Inspect the output ----------------------------------------------------
print("Fixed-length chunks (20 tokens each):")
for i, c in enumerate(fixed_chunks, 1):
    print(f"  {i:02d}. {c!r}")

print("\nOverlapping chunks (24 tokens, 8-token overlap):")
for i, c in enumerate(overlap_chunks, 1):
    print(f"  {i:02d}. {c!r}")

print("\nSemantic chunks (~120 chars, sentence-aware):")
for i, c in enumerate(semantic_chunks, 1):
    print(f"  {i:02d}. {c!r}")


Fixed-length chunks (20 tokens each):
  01. 'The sky appears blue because molecules in the air scatter blue light from the sun more than they scatter red'
  02. 'light. This phenomenon is called Rayleigh scattering. When the sun is lower in the sky, the light has to pass'
  03. 'through more atmosphere, so more blue and green light is scattered away, leaving the reds and oranges we see at'

Overlapping chunks (24 tokens, 8-token overlap):
  01. 'The sky appears blue because molecules in the air scatter blue light from the sun more than they scatter red light. This phenomenon is'
  02. 'than they scatter red light. This phenomenon is called Rayleigh scattering. When the sun is lower in the sky, the light has to pass'
  03. 'in the sky, the light has to pass through more atmosphere, so more blue and green light is scattered away, leaving the reds and'
  04. 'light is scattered away, leaving the reds and oranges we see at sunrise and sunset.'

Semantic chunks (~120 chars, sentence-aware):

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 0.  Imports & one-time setup
# ──────────────────────────────────────────────────────────────────────────────
from pathlib import Path
from typing import Dict, List
from collections import defaultdict
import json, time

# local modules
from chunkers import FixedChunker, OverlappingChunker, SemanticChunker
from retrievers import bm25_retriever, dense_retriever, hybrid_retriever
from retrievers.dense_retriever import DPRRetriever  # example alias
from utils.evaluation import exact_match, f1_score            # already in your repo
from utils.timing import Timer                                 # already in your repo

# external (install if missing)
#   pip install rank-bm25 sentence-transformers faiss-cpu
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

# ──────────────────────────────────────────────────────────────────────────────
# 1.  Load a toy corpus
# ──────────────────────────────────────────────────────────────────────────────
DOCS: Dict[str, str] = {
    "doc1": Path("data/doc1.txt").read_text(encoding="utf-8"),
    "doc2": Path("data/doc2.txt").read_text(encoding="utf-8"),
    "doc3": Path("data/doc3.txt").read_text(encoding="utf-8"),
}
print(f"Loaded {len(DOCS)} raw documents")

# ──────────────────────────────────────────────────────────────────────────────
# 2.  Choose **one** chunker recipe for this run
#     (Swap these objects in a loop if you’re doing a grid-search)
# ──────────────────────────────────────────────────────────────────────────────
chunker = FixedChunker(chunk_size=128, drop_last=False)
# chunker = OverlappingChunker(chunk_size=256, overlap=64)
# chunker = SemanticChunker(chunk_char_limit=1500)

# 2-b.  Chunk every document → corpus_chunks[id] = [chunk0, …]
corpus_chunks: Dict[str, List[str]] = {
    doc_id: chunker.chunk(txt) for doc_id, txt in DOCS.items()
}
flat_chunks = [c for chs in corpus_chunks.values() for c in chs]
print(
    f"Chunked into {len(flat_chunks):,} total passages "
    f"(avg {len(flat_chunks)/len(DOCS):.1f} per doc)"
)

# ──────────────────────────────────────────────────────────────────────────────
# 3.  Build the four retrieval back-ends on **exactly the same chunk set**
#     The dense & hybrid examples assume a SentenceTransformer-based DPR.
# ──────────────────────────────────────────────────────────────────────────────
# 3-a. BM25 (sparse)
bm25_index = BM25Okapi([c.split() for c in flat_chunks])

# 3-b. Dense Passage Retriever
dpr_model = SentenceTransformer("facebook-dpr-ctx_encoder-multiset-base")
dpr_index = dense_retriever.build_faiss_index(flat_chunks, dpr_model)  # your util

# 3-c. ColBERT (late interaction)  ← placeholder call
colbert_retriever = dense_retriever.ColBERTIndexer().fit(flat_chunks)

# 3-d. Hybrid (= BM25 + DPR scores)
hybrid = hybrid_retriever.Hybrid(bm25_index, dpr_index, alpha=0.4)

# ──────────────────────────────────────────────────────────────────────────────
# 4.  Query loop + simple QA evaluation
# ──────────────────────────────────────────────────────────────────────────────
qa_pairs = json.loads(Path("data/dev_qas.json").read_text())[:50]  # small dev slice
K = 5

results_by_system = defaultdict(list)

for q in qa_pairs:
    query, gold = q["question"], q["answer"]  # gold answer string

    with Timer() as t:
        top_chunks = bm25_index.get_top_n(query.split(), flat_chunks, n=K)
    results_by_system["bm25"].append(
        {"latency_ms": t.ms, "em": exact_match(gold, top_chunks)}
    )

    with Timer() as t:
        top_chunks = dense_retriever.search_faiss(query, dpr_index, dpr_model, top_k=K)
    results_by_system["dpr"].append(
        {"latency_ms": t.ms, "em": exact_match(gold, top_chunks)}
    )

    with Timer() as t:
        top_chunks = colbert_retriever.search(query, top_k=K)
    results_by_system["colbert"].append(
        {"latency_ms": t.ms, "em": exact_match(gold, top_chunks)}
    )

    with Timer() as t:
        top_chunks = hybrid.search(query, top_k=K)
    results_by_system["hybrid"].append(
        {"latency_ms": t.ms, "em": exact_match(gold, top_chunks)}
    )

# ──────────────────────────────────────────────────────────────────────────────
# 5.  Aggregate & display – how *this* chunker impacted each retriever
# ──────────────────────────────────────────────────────────────────────────────
def summarise(rows):
    return {
        "EM": sum(r["em"] for r in rows) / len(rows),
        "Latency (ms)": sum(r["latency_ms"] for r in rows) / len(rows),
    }

summary = {name: summarise(rows) for name, rows in results_by_system.items()}
print("\n===  Chunker:", chunker.__class__.__name__, " ===")
for system, metrics in summary.items():
    print(f"{system:7s} | EM={metrics['EM']:.3f} | Latency≈{metrics['Latency (ms)']:.1f} ms")


In [10]:
#!/usr/bin/env python
"""
Minimal Retrieval-only demo on TriviaQA-unfiltered.

Dependencies:
  pip install rank-bm25 sentence-transformers faiss-cpu tqdm
"""

from __future__ import annotations
import os, json, gzip, time
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
from tqdm import tqdm

# Retrieval libs
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import faiss                           # CPU version is OK

# Your chunkers (put the three modules in rag_pipeline/chunkers/)
from chunkers import FixedChunker, OverlappingChunker, SemanticChunker

# ────────────────────────────────────────────────────────────────────────────
# 1. TriviaQA loader (improved)
# ────────────────────────────────────────────────────────────────────────────
def load_triviaqa(
    dataset_path: str,
    max_docs: int = 300,
    max_qa_pairs: int = 100,
) -> Tuple[Dict[str, str], List[Dict]]:
    """
    Returns:
        docs   : {doc_id: raw_text}
        qa_pairs: [{"question": str, "answer": str}, …]
    """
    print(f"[load] TriviaQA from {dataset_path}")
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)["Data"]

    root = Path(dataset_path).parent
    docs: Dict[str, str] = {}
    qa_pairs: List[Dict] = []

    def read_evidence(file_path: Path) -> str:
        if not file_path.exists():
            return ""
        # Many evidence files are .gz
        try:
            if file_path.suffix == ".gz":
                with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as g:
                    return g.read()
            return file_path.read_text(encoding="utf-8", errors="ignore")
        except Exception:
            return ""

    def add_doc(doc_info, prefix: str, j: int):
        if len(docs) >= max_docs:
            return
        doc_id = f"{prefix}_{j}"
        txt = ""
        if doc_info.get("Filename"):
            txt = read_evidence(root / doc_info["Filename"])
        # Fallbacks
        txt = txt or doc_info.get("Snippet", "") or doc_info.get("Title", "") or doc_info.get("Url", "")
        docs[doc_id] = txt

    for i, item in enumerate(data):
        if i >= max_qa_pairs:
            break

        # QA
        question = item["Question"]
        aliases = item["Answer"].get("NormalizedAliases") or []
        gold = aliases[0] if aliases else item["Answer"]["NormalizedValue"]
        qa_pairs.append({"question": question, "answer": gold})

        # Wiki evidence
        for j, d in enumerate(item.get("EntityPages", [])):
            add_doc(d, f"wiki_{d.get('Title','wiki')}", j)
        # Web search evidence
        for j, d in enumerate(item.get("SearchResults", [])):
            add_doc(d, "web", j)

    # Drop empties
    docs = {k: v for k, v in docs.items() if v.strip()}
    print(f"[load] kept {len(docs)} non-empty docs, {len(qa_pairs)} QA pairs")
    return docs, qa_pairs


# ────────────────────────────────────────────────────────────────────────────
# 2. Exact-Match helper
# ────────────────────────────────────────────────────────────────────────────
def exact_match(gold: str, passages: List[str]) -> int:
    g = gold.lower()
    return int(any(g in p.lower() for p in passages))


# ────────────────────────────────────────────────────────────────────────────
# 3. Main
# ────────────────────────────────────────────────────────────────────────────
def main():
    # -------- paths & params ------------------------------------------------
    TQA_JSON = "data/triviaqa-unfiltered/unfiltered-web-dev.json"   # edit if needed
    MAX_DOCS = 300
    MAX_QA   = 100
    TOP_K    = 5

    # -------- load ----------------------------------------------------------
    docs, qa_pairs = load_triviaqa(TQA_JSON, MAX_DOCS, MAX_QA)
    if not docs:
        raise RuntimeError("No non-empty docs – check dataset path / permissions")

    # -------- choose chunker ------------------------------------------------
    chunker = FixedChunker(chunk_size=128, drop_last=False)
    # chunker = OverlappingChunker(chunk_size=256, overlap=64)
    # chunker = SemanticChunker(chunk_char_limit=1500)

    chunk_texts, chunk_to_doc = [], []
    for doc_id, txt in docs.items():
        for ch in chunker.chunk(txt):
            chunk_texts.append(ch)
            chunk_to_doc.append(doc_id)

    print(f"[chunk] {chunker.__class__.__name__}: {len(chunk_texts)} chunks "
          f"(≈{len(chunk_texts)/len(docs):.1f} per doc)")

    # -------- BM25 ----------------------------------------------------------
    bm25 = BM25Okapi([c.split() for c in chunk_texts])

    # -------- DPR -----------------------------------------------------------
    model = SentenceTransformer("facebook-dpr-ctx_encoder-multiset-base")
    print("[dpr] encoding chunks …")
    ctx_emb = model.encode(
        chunk_texts,
        batch_size=64,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    index = faiss.IndexFlatIP(ctx_emb.shape[1])
    index.add(ctx_emb)
    print(f"[dpr] Faiss index: {index.ntotal} vectors")

    # -------- retrieval loop ------------------------------------------------
    bm25_hits = dpr_hits = 0
    bm25_lat, dpr_lat = [], []

    for qa in tqdm(qa_pairs, desc="retrieving"):
        q, gold = qa["question"], qa["answer"]

        # BM25
        t0 = time.perf_counter()
        ids = bm25.get_top_n(q.split(), list(range(len(chunk_texts))), n=TOP_K)
        bm25_lat.append((time.perf_counter() - t0) * 1e3)
        bm25_hits += exact_match(gold, [chunk_texts[i] for i in ids])

        # DPR
        t0 = time.perf_counter()
        q_emb = model.encode([q], normalize_embeddings=True)
        _, idxs = index.search(q_emb, TOP_K)
        dpr_lat.append((time.perf_counter() - t0) * 1e3)
        dpr_hits += exact_match(gold, [chunk_texts[i] for i in idxs[0]])

    # -------- summary -------------------------------------------------------
    n = len(qa_pairs)
    print("\n=== RESULTS (Exact-Match evidence recall) ===")
    print(f"BM25 | EM@{TOP_K}: {bm25_hits/n:.3f} | avg latency {np.mean(bm25_lat):.1f} ms")
    print(f"DPR  | EM@{TOP_K}: {dpr_hits/n:.3f} | avg latency {np.mean(dpr_lat):.1f} ms")


if __name__ == "__main__":
    main()



[load] TriviaQA from data/triviaqa-unfiltered/unfiltered-web-dev.json
[load] kept 250 non-empty docs, 100 QA pairs
[chunk] FixedChunker: 250 chunks (≈1.0 per doc)
[dpr] encoding chunks …


Batches: 100%|██████████| 4/4 [00:01<00:00,  2.08it/s]


[dpr] Faiss index: 250 vectors


retrieving: 100%|██████████| 100/100 [00:04<00:00, 22.58it/s]


=== RESULTS (Exact-Match evidence recall) ===
BM25 | EM@5: 0.010 | avg latency 0.8 ms
DPR  | EM@5: 0.000 | avg latency 43.2 ms



