<a href="https://colab.research.google.com/github/johnycodestone/Optimized-Urdu-RAG-COVID-19/blob/main/Sparse_Retriever_Model_NLP_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1: Install required libraries (run this cell first and one by one all required libraries will be installed)
# - transformers: model + generation
# - sentence-transformers: dense embeddings / fine-tuning helpers
# - faiss-cpu (or faiss-gpu if GPU available)
# - rank_bm25: BM25 baseline
# - datasets: convenient JSONL loading
# - evaluate / sacrebleu: BLEU/chrF metrics
# - tqdm: progress bars
# - accelerate (optional) for distributed/faster training
!pip install -q transformers sentence-transformers faiss-cpu rank_bm25 datasets evaluate sacrebleu tqdm accelerate


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m114.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.8/100.8 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# Connect to Google Drive if it is not already connected (Colab only).
# The drive is mounted at /content/drive; Cell 3 reads the corpus and the
# evaluation files from a data folder on this drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Optional cell: upload the required data files from your computer through the
# browser instead of reading them from Google Drive.
# NOTE(review): uploaded files presumably land in the current working directory,
# not in DATA_DIR (set in Cell 3) — confirm, and move them if needed.
from google.colab import files
uploaded = files.upload()

In [None]:
!pip list # Optional: list the installed libraries/packages to check that the Cell 1 installs succeeded

In [3]:
# Cell 2: Imports and GPU check: Run this cell after the first cell
import os, json, time, math
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

# Transformers / sentence-transformers
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
import sentence_transformers # Import the package itself to access __version__

# FAISS and BM25
import faiss
from rank_bm25 import BM25Okapi

# Datasets and metrics
from datasets import load_dataset, Dataset
import evaluate
import sacrebleu

# Report the library versions in use and whether a CUDA device is visible.
for _label, _pkg in (("transformers:", transformers),
                     ("sentence-transformers:", sentence_transformers)):
    print(_label, _pkg.__version__)

# torch is probed defensively so a broken install is reported instead of
# stopping the cell.
try:
    import torch
    print("torch:", torch.__version__, "cuda:", torch.cuda.is_available())
except Exception as err:
    print("torch not available:", err)


transformers: 4.57.3
sentence-transformers: 5.2.0
torch: 2.9.0+cu126 cuda: True


In [4]:
# Cell 3: Load the JSONL/TSV data files into Python structures.
# The files are read from DATA_DIR, a path relative to the working directory.
# With Drive mounted at /content/drive and the notebook running from /content,
# it points at the "data" folder in My Drive. Put all data files there (or edit
# DATA_DIR) before running this cell.

DATA_DIR = Path("drive/MyDrive/data")  # change if files are elsewhere

# Make sure the data directory exists (no-op when it is already there).
DATA_DIR.mkdir(parents=True, exist_ok=True)

def load_jsonl(path):
    """Read a UTF-8 JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``; lines that contain only
    whitespace are skipped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]

def _read_passages_tsv(path):
    """Parse ``<id><whitespace><text>`` lines into ``{"id", "text"}`` dicts.

    Lines are split on the first run of any whitespace, so a space-delimited
    file is accepted as well as a tab-delimited one. Blank lines are ignored
    and lines without two fields are reported and skipped.
    """
    rows = []
    with open(path, "r", encoding="utf-8") as tsv_file:
        for line in tsv_file:
            if not line.strip():
                continue
            fields = line.rstrip("\n").split(None, 1)
            if len(fields) != 2:
                print(f"Skipping malformed line in urdu_covid_passages.tsv: {line.strip()}")
                continue
            rows.append({"id": fields[0], "text": fields[1]})
    return rows


# Corpus and passages
corpus_clean = load_jsonl(DATA_DIR / "urdu_covid_corpus_clean.jsonl")
passages_min = load_jsonl(DATA_DIR / "urdu_covid_passages_min.jsonl")
passages_tsv = _read_passages_tsv(DATA_DIR / "urdu_covid_passages.tsv")

# Evaluation queries and training-side resources
eval_queries = load_jsonl(DATA_DIR / "eval_queries.jsonl")
synthetic_pairs = load_jsonl(DATA_DIR / "synthetic_qa_pairs.jsonl")
hard_negatives = load_jsonl(DATA_DIR / "hard_negatives.jsonl")

print("Loaded:", len(corpus_clean), "corpus_clean; ", len(passages_min), "passages_min; ", len(eval_queries), "eval queries")


Loaded: 60 corpus_clean;  60 passages_min;  100 eval queries


In [5]:
# Cell 4: Check that every passage id referenced by the eval queries, the
# synthetic pairs and the hard negatives exists in passages_min.
# Run this after Cell 3.
passage_ids = {p["id"] for p in passages_min}

# Each entry is (source, id of the referencing item, unknown passage id).
missing = [
    ("eval", q["query_id"], pid)
    for q in eval_queries
    for pid in q.get("positive_ids", [])
    if pid not in passage_ids
]
missing += [
    ("synthetic", s["synthetic_id"], s["positive_id"])
    for s in synthetic_pairs
    if s["positive_id"] not in passage_ids
]
missing += [
    ("hardneg", h["query_id"], pid)
    for h in hard_negatives
    for pid in h["hard_negatives"]
    if pid not in passage_ids
]

print("Missing references (should be zero):", len(missing))
if missing:
    print(missing[:10])


Missing references (should be zero): 0


In [6]:
# Cell 5 (Run after Cell 4): BM25 baseline index.
# Tokenization uses NLTK's word_tokenize when it is usable and falls back to a
# plain whitespace split otherwise (good enough for an Urdu baseline).
# The tokenized corpus and the BM25 object are kept for retrieval.
def _whitespace_tokenize(s):
    """Fallback tokenizer: split on runs of whitespace."""
    return s.split()


try:
    # Both imports live inside the try so that a missing nltk install selects
    # the fallback instead of aborting the cell.
    import nltk
    from nltk.tokenize import word_tokenize
    nltk.download('punkt')
    nltk.download('punkt_tab')  # resolves the LookupError for 'punkt_tab'
    # nltk.download() signals failure through its return value rather than an
    # exception, so tokenize once here to surface missing data immediately.
    word_tokenize("probe")
    tokenizer = word_tokenize
except Exception as err:
    print(f"NLTK tokenizer unavailable ({err!r}); using whitespace split instead")
    tokenizer = _whitespace_tokenize

# Parallel lists: corpus_ids[i] is the id of corpus_texts[i]
corpus_texts = [p["text"] for p in passages_min]
corpus_ids = [p["id"] for p in passages_min]
tokenized_corpus = [tokenizer(t) for t in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)

# BM25 retrieval over the index built above
def bm25_retrieve(query, k=5):
    """Return the ``k`` best-scoring passages for ``query``.

    The query is tokenized with the module-level ``tokenizer`` and scored by
    the module-level ``bm25`` index. The result is a list of
    ``(passage_id, passage_text, score)`` tuples, highest score first, with the
    score converted to a plain ``float``.
    """
    scores = bm25.get_scores(tokenizer(query))
    ranked = np.argsort(scores)[::-1][:k]
    results = []
    for idx in ranked:
        results.append((corpus_ids[idx], corpus_texts[idx], float(scores[idx])))
    return results

# Smoke test: show the top-3 BM25 hits for the first evaluation query
sample_query = eval_queries[0]["query"]
print("BM25 top-3 for sample:", bm25_retrieve(sample_query, k=3))


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


BM25 top-3 for sample: [('p0001', 'کورونا وائرس مرض 2019 (COVID-19) ایک متعدی بیماری ہے جس کی عام علامات میں بخار، کھانسی اور سانس لینے میں دشواری شامل ہیں۔', 5.810063974702894), ('p0024', 'بچوں میں کووڈ-19 عام طور پر ہلکا ہوتا ہے مگر بعض نادر معاملات میں شدید علامات سامنے آ سکتی ہیں؛ بچوں کے لیے مخصوص رہنمائی مختلف ہو سکتی ہے۔', 5.103496839739362), ('p0002', 'کووڈ-19 کی تشخیص کے لیے rRT-PCR سویب ٹیسٹ عام طور پر استعمال ہوتے ہیں اور یہ وائرس کی موجودگی کی تصدیق کرتے ہیں۔', 4.589270107579207)]


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [7]:
# Cell 5b: BM25-only retriever evaluation tool (run after Cell 5)
# Purpose: standalone evaluation harness for the independent BM25 retriever (bm25_retrieve)
# Metrics included (applicable to a retriever-only evaluation):
#   - Recall@1, Recall@5 (share of queries with at least one positive passage in the top k)
#   - MRR (Mean Reciprocal Rank)
#   - Precision@k (k=1,5)
#   - Average / median retrieval latency
# Matching: a retrieved passage counts as relevant when its id is one of the
#   query's positive ids. The gold answer text is only copied into the
#   per-query record; it is not used for matching.
# Output:
#   - Per-query JSONL saved to bm25_eval_results.jsonl
#   - Printed summary with all metrics
#
# Requirements (must be available in the session):
#   - bm25_retrieve(query, k) -> list of (passage_id, passage_text, score)
#   - eval_queries: list of dicts with:
#       * "query" or "question" or "q"  (the query text)
#       * "positive_ids" (list of ids) or "positive_id" (single id): the gold passage id(s)
#       * "gold_answer" or "answer" (optional gold text, recorded only)
#
# Usage:
#   - Run this cell after you build the BM25 index (Cell 5).
#   - Optionally pass a different eval list or k values to evaluate subsets.

# Use this evaluator if your eval_queries items contain "positive_ids" and "gold_answer"
import json, time, re, statistics
from typing import List, Dict

# Defaults for the evaluator defined below.
OUT_JSONL = "bm25_eval_results.jsonl"  # per-query results file
DEFAULT_K = 5                          # number of passages requested per query
RECALL_KS = [1, 5]                     # cutoffs reported as Recall@k
PRECISION_KS = [1, 5]                  # cutoffs reported as Precision@k

def normalize_text(s: str) -> str:
    """Return ``s`` as a string with surrounding whitespace removed and every
    internal run of whitespace collapsed to a single space.

    ``None`` yields an empty string; other non-string values are converted
    with ``str`` first.
    """
    return "" if s is None else re.sub(r"\s+", " ", str(s).strip())

def get_query_text(item: Dict) -> str:
    """Return the query text of an evaluation item.

    The keys "query", "question" and "q" are tried in that order and the first
    non-empty value wins; an empty string is returned when none is set.
    """
    for key in ("query", "question", "q"):
        value = item.get(key)
        if value:
            return value
    return ""

def evaluate_bm25_with_positive_ids(eval_items: List[Dict],
                                    out_jsonl: str = OUT_JSONL,
                                    k: int = DEFAULT_K,
                                    recall_ks = RECALL_KS,
                                    precision_ks = PRECISION_KS):
    """Score ``bm25_retrieve`` against the gold passage ids of each query.

    Args:
        eval_items: Query dicts. The query text is read with
            ``get_query_text``; gold ids come from ``"positive_ids"`` (or
            ``"positive_id"``, which may be a single id). ``"gold_answer"`` /
            ``"answer"`` is whitespace-normalized and copied into the record;
            it is not used for matching.
        out_jsonl: Path of the per-query JSONL file, overwritten on each call.
        k: Number of hits used for the reciprocal rank and stored per query.
        recall_ks: Cutoffs for Recall@k, computed here as the share of queries
            with at least one positive among the top k hits.
        precision_ks: Cutoffs for Precision@k: positives among the top k hits
            divided by k, averaged over all queries.

    Returns:
        ``(summary, per_query)``. ``summary`` maps metric names to values
        (``n_queries``, ``MRR``, ``Recall@k``, ``Precision@k``, mean and median
        latency in seconds). ``per_query`` is the list of records that was
        written to ``out_jsonl``.

    Raises:
        ValueError: If a recall or precision cutoff is not positive.
    """
    recall_ks = list(recall_ks)
    precision_ks = list(precision_ks)
    if any(cutoff <= 0 for cutoff in recall_ks + precision_ks):
        raise ValueError("recall_ks and precision_ks must contain positive cutoffs")

    # Ask the retriever for enough hits to cover the deepest cutoff; otherwise
    # e.g. Recall@5 with k=3 would silently be computed over three hits.
    # Assumes the top-k hits are a prefix of any deeper result list.
    depth = max([k] + recall_ks + precision_ks)

    per_query = []
    latencies = []
    rr_list = []
    # Integer counters: dividing once at the end avoids float accumulation drift.
    recall_hits = {rk: 0 for rk in recall_ks}
    precision_hits = {pk: 0 for pk in precision_ks}

    for item in eval_items:
        q = get_query_text(item)
        positive_ids = item.get("positive_ids") or item.get("positive_id") or []
        # A single id (string or number) becomes a one-element list.
        if not isinstance(positive_ids, (list, tuple, set)):
            positive_ids = [positive_ids]
        positive_ids = [str(x) for x in positive_ids]
        positive_set = set(positive_ids)

        gold_text = normalize_text(item.get("gold_answer") or item.get("answer") or "")

        # perf_counter is the clock intended for timing short intervals.
        t0 = time.perf_counter()
        try:
            hits = bm25_retrieve(q, k=depth)   # (id, text, score)
        except Exception as e:
            # Best effort: a failing query is reported and counted as a miss.
            hits = []
            print(f"[eval] bm25_retrieve error for query {q[:60]}... -> {e}")
        latency = time.perf_counter() - t0
        latencies.append(latency)

        # Compare ids as strings on both sides, matching the str() coercion
        # applied to the gold ids above.
        ranked_keys = [str(h[0]) for h in hits]
        top_hits = hits[:k]
        retrieved_ids = [h[0] for h in top_hits]
        retrieved_texts = [h[1] for h in top_hits]

        # Reciprocal rank of the first positive within the top k
        rr = 0.0
        for rank, pid in enumerate(ranked_keys[:k], start=1):
            if pid in positive_set:
                rr = 1.0 / rank
                break
        rr_list.append(rr)

        for rk in recall_ks:
            if any(pid in positive_set for pid in ranked_keys[:rk]):
                recall_hits[rk] += 1
        for pk in precision_ks:
            precision_hits[pk] += sum(1 for pid in ranked_keys[:pk] if pid in positive_set)

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": positive_ids,
            "gold_text": gold_text,
            "retrieved_ids": retrieved_ids,
            "retrieved_texts_preview": [t[:300] for t in retrieved_texts],
            "reciprocal_rank": rr,
            "latency": latency
        })

    total = len(per_query)
    denom = total or 1  # keeps the averages defined for an empty evaluation set
    mrr = sum(rr_list) / denom
    recall_at = {rk: recall_hits[rk] / denom for rk in recall_ks}
    precision_at = {pk: precision_hits[pk] / (pk * denom) for pk in precision_ks}
    latency_mean = statistics.mean(latencies) if latencies else 0.0
    latency_median = statistics.median(latencies) if latencies else 0.0

    summary = {
        "n_queries": total,  # the real count, 0 when nothing was evaluated
        "MRR": mrr,
        **{f"Recall@{rk}": recall_at[rk] for rk in recall_ks},
        **{f"Precision@{pk}": precision_at[pk] for pk in precision_ks},
        "latency_mean_s": latency_mean,
        "latency_median_s": latency_median
    }

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for r in per_query:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    return summary, per_query

# Run it
if 'eval_queries' not in globals():
    # Fallback when Cell 3 has not been run in this session: read the queries
    # from eval_queries.jsonl in the working directory.
    eval_queries = []
    with open("eval_queries.jsonl","r",encoding="utf-8") as f:
        for line in f:
            if line.strip():  # skip blank lines, as load_jsonl does
                eval_queries.append(json.loads(line))

summary, records = evaluate_bm25_with_positive_ids(eval_queries, out_jsonl=OUT_JSONL, k=DEFAULT_K)
print("BM25 retrieval evaluation summary:")
for metric_name, metric_value in summary.items():
    print(f"  {metric_name}: {metric_value}")

# Show a few queries for which no positive passage was retrieved
# (a reciprocal rank of 0.0 marks a miss).
misses = [r for r in records if r["reciprocal_rank"] == 0.0]
print(f"\nTotal misses: {len(misses)} / {len(records)}. Showing up to 5 misses:")
for r in misses[:5]:
    print("Query id:", r.get("query_id"), "Query:", r["query"][:80])
    print(" Positives:", r["positive_ids"])
    print(" Retrieved top ids:", r["retrieved_ids"][:8])
    print()


BM25 retrieval evaluation summary:
  n_queries: 100
  MRR: 0.8853333333333333
  Recall@1: 0.84
  Recall@5: 0.95
  Precision@1: 0.84
  Precision@5: 0.21599999999999964
  latency_mean_s: 0.00026865243911743163
  latency_median_s: 0.0002652406692504883

Total misses: 5 / 100. Showing up to 5 misses:
Query id: q007 Query: کووڈ-19 ویکسین کا بنیادی مقصد کیا ہے؟
 Positives: ['p0007']
 Retrieved top ids: ['p0028', 'p0050', 'p0051', 'p0027', 'p0039']

Query id: q019 Query: وینٹیلیشن وبا کے دوران کیوں اہم ہے؟
 Positives: ['p0020']
 Retrieved top ids: ['p0017', 'p0060', 'p0031', 'p0048', 'p0027']

Query id: q038 Query: ویکسین سائیڈ ایفیکٹس کی نگرانی کیسے کی جاتی ہے؟
 Positives: ['p0039']
 Retrieved top ids: ['p0058', 'p0040', 'p0032', 'p0051', 'p0011']

Query id: q065 Query: ویکسین کی سائیڈ ایفیکٹس کی رپورٹنگ کیسے ہوتی ہے؟
 Positives: ['p0039']
 Retrieved top ids: ['p0058', 'p0047', 'p0032', 'p0022', 'p0025']

Query id: q095 Query: وبا کے دوران معاشی بحالی کے لیے کون سے اقدامات کیے جا سکتے ہیں؟
 

The sparse (BM25) retriever performs strongly on the 100 evaluation queries (Recall@1 = 0.84, Recall@5 = 0.95, MRR ≈ 0.885), which makes it a solid baseline, so we can safely move on to the dense retriever model.