<a href="https://colab.research.google.com/github/johnycodestone/Optimized-Urdu-RAG-COVID-19/blob/main/Dense_and_Sparse_Retriever_Models_NLP_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1: Install required libraries (run this cell first and one by one all required libraries will be installed)
# - transformers: model + generation
# - sentence-transformers: dense embeddings / fine-tuning helpers
# - faiss-cpu (or faiss-gpu if GPU available)
# - rank_bm25: BM25 baseline
# - datasets: convenient JSONL loading
# - evaluate / sacrebleu: BLEU/chrF metrics
# - tqdm: progress bars
# - accelerate (optional) for distributed/faster training
!pip install -q transformers sentence-transformers faiss-cpu rank_bm25 datasets evaluate sacrebleu tqdm accelerate


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m114.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.8/100.8 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# Mount Google Drive at /content/drive so later cells can read the corpus files
# and the fine-tuned dense retriever stored under MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Optional cell: upload files from the local machine through Colab's upload prompt.
# The return value of files.upload() is kept in `uploaded`; it is not referenced by
# any other visible cell.
# NOTE(review): Cell 3 reads its data from Google Drive (DATA_DIR), so files uploaded
# here are presumably NOT picked up unless DATA_DIR is changed — confirm intended use.
from google.colab import files
uploaded = files.upload()

In [None]:
# Optional: list installed packages to check that the Cell 1 installs succeeded.
!pip list

In [3]:
# Cell 2: Imports and GPU check: Run this cell after the first cell
import os, json, time, math
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

# Transformers / sentence-transformers
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
import sentence_transformers # Import the package itself to access __version__

# FAISS and BM25
import faiss
from rank_bm25 import BM25Okapi

# Datasets and metrics
from datasets import load_dataset, Dataset
import evaluate
import sacrebleu

# Report the library versions in use and whether torch can see a CUDA device.
print(f"transformers: {transformers.__version__}")
print(f"sentence-transformers: {sentence_transformers.__version__}")
try:
    import torch
    print(f"torch: {torch.__version__} cuda: {torch.cuda.is_available()}")
except Exception as e:
    print("torch not available:", e)


transformers: 4.57.3
sentence-transformers: 5.2.0
torch: 2.9.0+cu126 cuda: True


In [4]:
# Cell 3: Load JSONL/TSV files into Python structures
# The data files are read from Google Drive: upload all of them to MyDrive/data
# (Drive is mounted at /content/drive by the cell above) and then run this cell.

# Absolute path, so loading does not depend on the kernel's current working directory.
DATA_DIR = Path("/content/drive/MyDrive/data")  # change if files are elsewhere

# Fail fast with a clear message instead of creating an empty folder: the files must
# already exist there, and creating folders under /content/drive before Drive is
# mounted can leave the mount point non-empty (which may make a later mount fail).
if not DATA_DIR.is_dir():
    raise FileNotFoundError(
        f"Data directory {DATA_DIR} not found. Mount Google Drive and upload the data files to MyDrive/data."
    )

def load_jsonl(path):
    """Read a JSON Lines file into a list of decoded JSON values.

    Blank (whitespace-only) lines are skipped. The file is decoded with "utf-8-sig",
    which reads plain UTF-8 unchanged and also drops a leading byte-order mark
    (a BOM would otherwise make json.loads reject the first line).

    Args:
        path: str or path-like pointing at the .jsonl file.

    Returns:
        list of the decoded values, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the message is
            prefixed with "<path>:<line number>:" to locate the bad record.
    """
    items = []
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                items.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Re-raise the same exception type, adding file/line context.
                raise json.JSONDecodeError(f"{path}:{lineno}: {e.msg}", e.doc, e.pos) from e
    return items

corpus_clean = load_jsonl(DATA_DIR / "urdu_covid_corpus_clean.jsonl")
passages_min = load_jsonl(DATA_DIR / "urdu_covid_passages_min.jsonl")


def _read_passages_tsv(tsv_path):
    """Parse "<id><whitespace><text>" lines into {"id": ..., "text": ...} dicts.

    The id/text split happens at the first run of any whitespace (not strictly a tab),
    so files whose tab delimiter was converted to spaces still load. Blank lines are
    ignored; lines with no text after the id are reported and skipped.
    """
    rows = []
    with open(tsv_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            if not raw.strip():
                continue
            fields = raw.rstrip("\n").split(None, 1)
            if len(fields) != 2:
                print(f"Skipping malformed line in urdu_covid_passages.tsv: {raw.strip()}")
                continue
            rows.append({"id": fields[0], "text": fields[1]})
    return rows


# TSV -> list of dicts
passages_tsv = _read_passages_tsv(DATA_DIR / "urdu_covid_passages.tsv")

eval_queries = load_jsonl(DATA_DIR / "eval_queries.jsonl")
synthetic_pairs = load_jsonl(DATA_DIR / "synthetic_qa_pairs.jsonl")
hard_negatives = load_jsonl(DATA_DIR / "hard_negatives.jsonl")

print("Loaded:", len(corpus_clean), "corpus_clean; ", len(passages_min), "passages_min; ", len(eval_queries), "eval queries")


Loaded: 60 corpus_clean;  60 passages_min;  100 eval queries


In [5]:
# Cell 4: Validate IDs referenced in eval/synthetic/hard_negatives exist in corpus
# Run this after Cell 3.

def _ids_as_list(raw):
    """Return an id field as a list: None -> [], a single id string -> [id], else list(raw)."""
    if raw is None:
        return []
    if isinstance(raw, str):
        return [raw]
    return list(raw)

passage_ids = {p["id"] for p in passages_min}
missing = []
for q in eval_queries:
    # positive_ids may be a single id string (Cell 8 guards for the same case);
    # iterating a bare string would check it character by character.
    for pid in _ids_as_list(q.get("positive_ids")):
        if pid not in passage_ids:
            missing.append(("eval", q.get("query_id"), pid))
for s in synthetic_pairs:
    # .get() so a record without positive_id is reported instead of raising KeyError
    pos_id = s.get("positive_id")
    if pos_id not in passage_ids:
        missing.append(("synthetic", s.get("synthetic_id"), pos_id))
for h in hard_negatives:
    for pid in _ids_as_list(h.get("hard_negatives")):
        if pid not in passage_ids:
            missing.append(("hardneg", h.get("query_id"), pid))
print("Missing references (should be zero):", len(missing))
if missing:
    print(missing[:10])


Missing references (should be zero): 0


In [6]:
# Cell 5 (Run after Cell 4): BM25 baseline index. Tokenization uses NLTK's word_tokenize
# when it is available and falls back to plain whitespace splitting otherwise; either is
# acceptable for an Urdu baseline.
# We'll store tokenized corpus and BM25 object for retrieval.
try:
    # Importing inside the try means a missing nltk package also triggers the fallback.
    import nltk
    from nltk.tokenize import word_tokenize
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)  # needed to avoid a LookupError for 'punkt_tab'
    # nltk.download may report failure by returning False rather than raising; a trial
    # call surfaces a missing resource here (LookupError) instead of while tokenizing.
    word_tokenize("probe")
    tokenizer = word_tokenize
except Exception:
    tokenizer = lambda s: s.split()

corpus_texts = [p["text"] for p in passages_min]
corpus_ids = [p["id"] for p in passages_min]
tokenized_corpus = [tokenizer(t) for t in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)

# BM25 retrieval function
def bm25_retrieve(query, k=5):
    """Score every passage against `query` with the global BM25 index and keep the k best.

    Returns a list of (passage_id, passage_text, score) tuples, highest score first.
    """
    scores = bm25.get_scores(tokenizer(query))
    ranked = np.argsort(scores)[::-1]
    hits = []
    for i in ranked[:k]:
        hits.append((corpus_ids[i], corpus_texts[i], float(scores[i])))
    return hits

# Quick smoke test: BM25's top-3 passages for the first eval query
print(f"BM25 top-3 for sample: {bm25_retrieve(eval_queries[0]['query'], k=3)}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


BM25 top-3 for sample: [('p0001', 'کورونا وائرس مرض 2019 (COVID-19) ایک متعدی بیماری ہے جس کی عام علامات میں بخار، کھانسی اور سانس لینے میں دشواری شامل ہیں۔', 5.810063974702894), ('p0024', 'بچوں میں کووڈ-19 عام طور پر ہلکا ہوتا ہے مگر بعض نادر معاملات میں شدید علامات سامنے آ سکتی ہیں؛ بچوں کے لیے مخصوص رہنمائی مختلف ہو سکتی ہے۔', 5.103496839739362), ('p0002', 'کووڈ-19 کی تشخیص کے لیے rRT-PCR سویب ٹیسٹ عام طور پر استعمال ہوتے ہیں اور یہ وائرس کی موجودگی کی تصدیق کرتے ہیں۔', 4.589270107579207)]


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [7]:
# Cell 5b: BM25-only retriever evaluation tool (run after Cell 5)
# Purpose: standalone evaluation harness for the independent BM25 retriever (bm25_retrieve)
# Metrics included (applicable to a retriever-only evaluation):
#   - Recall@1, Recall@5
#   - MRR (Mean Reciprocal Rank)
#   - Precision@k (k=1,5)
#   - Average / median retrieval latency
#   - Relevance is judged by passage id ("positive_ids"); gold_answer is only recorded
# Output:
#   - Per-query JSONL saved to bm25_eval_results.jsonl
#   - Printed summary with all metrics
#
# Requirements (must be available in the session):
#   - bm25_retrieve(query, k) -> list of (passage_id, passage_text, score)
#   - eval_queries: list of dicts with:
#       * "query" (or "question" / "q") -- the query text
#       * "positive_ids" (list or single id; "positive_id" also accepted) -- relevant passage ids
#       * "gold_answer" or "answer" (optional) -- copied into the per-query records only
#
# Usage:
#   - Run this cell after you build the BM25 index (Cell 5).
#   - Optionally pass a different eval list or k values to evaluate subsets.

# Use this evaluator if your eval_queries items contain "positive_ids" and "gold_answer"
import json, time, re, statistics
from typing import List, Dict

OUT_JSONL = "bm25_eval_results.jsonl"
DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def normalize_text(s: str) -> str:
    """Strip `s` and collapse internal whitespace runs to single spaces; None -> ""."""
    if s is None:
        return ""
    s = str(s).strip()
    return re.sub(r"\s+", " ", s)

def get_query_text(item: Dict) -> str:
    """Return an eval item's query text ("query", then "question", then "q"), or ""."""
    return item.get("query") or item.get("question") or item.get("q") or ""

def evaluate_bm25_with_positive_ids(eval_items: List[Dict],
                                    out_jsonl: str = OUT_JSONL,
                                    k: int = DEFAULT_K,
                                    recall_ks = RECALL_KS,
                                    precision_ks = PRECISION_KS):
    """Evaluate `bm25_retrieve` on eval queries labelled with relevant passage ids.

    Relevant passages come from "positive_ids" (list or single id) or "positive_id".
    Relevance is judged by id only; "gold_answer"/"answer" is whitespace-normalized
    and copied into the per-query records but not used for scoring.

    Args:
        eval_items: list of eval dicts.
        out_jsonl: path the per-query records are written to, one JSON object per line.
        k: cutoff used for MRR.
        recall_ks / precision_ks: cutoffs for Recall@k / Precision@k. Retrieval depth is
            max(k, *recall_ks, *precision_ks), so cutoffs larger than k are scored on
            a full-length list instead of being silently truncated to k.

    Returns:
        (summary, per_query). summary has n_queries, MRR, Recall@*, Precision@*,
        latency_mean_s and latency_median_s; per_query is the list of records written
        to out_jsonl. With no eval items, n_queries is 0 and every metric is 0.0.
    """
    # Retrieve deep enough for every cutoff; MRR still only looks at the top k.
    depth = max([k, *recall_ks, *precision_ks])
    per_query = []
    latencies = []
    rr_list = []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}

    for item in eval_items:
        q = get_query_text(item)
        positive_ids = item.get("positive_ids") or item.get("positive_id") or []
        # normalize to list of strings (a bare string is a single id)
        if isinstance(positive_ids, str):
            positive_ids = [positive_ids]
        positive_ids = [str(x) for x in positive_ids]
        positive_set = set(positive_ids)

        gold_text = normalize_text(item.get("gold_answer") or item.get("answer") or "")

        t0 = time.perf_counter()  # monotonic, high-resolution timer for latency
        try:
            hits = bm25_retrieve(q, k=depth)   # (id, text, score)
        except Exception as e:
            # Best-effort: a failing query counts as a miss instead of aborting the run.
            hits = []
            print(f"[eval] bm25_retrieve error for query {q[:60]}... -> {e}")
        latency = time.perf_counter() - t0
        latencies.append(latency)

        retrieved_ids = [h[0] for h in hits]
        retrieved_texts = [h[1] for h in hits]

        # Reciprocal rank of the first relevant passage within the top k (0.0 if none)
        rr = next((1.0 / rank
                   for rank, pid in enumerate(retrieved_ids[:k], start=1)
                   if pid in positive_set), 0.0)
        rr_list.append(rr)

        # Recall@k (hit-rate) and Precision@k = (# positives in top-k) / k
        for rk in recall_ks:
            if any(pid in positive_set for pid in retrieved_ids[:rk]):
                recall_counts[rk] += 1
        for pk in precision_ks:
            num_pos_in_topk = sum(1 for pid in retrieved_ids[:pk] if pid in positive_set)
            precision_sums[pk] += num_pos_in_topk / pk

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": positive_ids,
            "gold_text": gold_text,
            "retrieved_ids": retrieved_ids,
            "retrieved_texts_preview": [t[:300] for t in retrieved_texts],
            "reciprocal_rank": rr,
            "latency": latency
        })

    total = len(per_query)
    n = max(total, 1)  # avoid division by zero; metrics stay 0.0 for an empty eval set

    summary = {
        "n_queries": total,
        "MRR": sum(rr_list) / n,
        **{f"Recall@{rk}": recall_counts[rk] / n for rk in recall_ks},
        **{f"Precision@{pk}": precision_sums[pk] / n for pk in precision_ks},
        "latency_mean_s": statistics.mean(latencies) if latencies else 0.0,
        "latency_median_s": statistics.median(latencies) if latencies else 0.0
    }

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for r in per_query:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    return summary, per_query

# Run it
if 'eval_queries' not in globals():
    # Fallback when Cell 3 did not finish: read from DATA_DIR when it is defined (the same
    # place Cell 3 loads from), else from the working directory. Blank lines are skipped.
    eval_path = DATA_DIR / "eval_queries.jsonl" if 'DATA_DIR' in globals() else "eval_queries.jsonl"
    eval_queries = []
    with open(eval_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                eval_queries.append(json.loads(line))

summary, records = evaluate_bm25_with_positive_ids(eval_queries, out_jsonl=OUT_JSONL, k=DEFAULT_K)
print("BM25 retrieval evaluation summary:")
# metric/value instead of k/v so the loop does not clobber a global `k`
for metric, value in summary.items():
    print(f"  {metric}: {value}")

# show a few examples where retrieval missed positives
misses = [r for r in records if r["reciprocal_rank"] == 0.0]
print(f"\nTotal misses: {len(misses)} / {len(records)}. Showing up to 5 misses:")
for r in misses[:5]:
    print("Query id:", r.get("query_id"), "Query:", r["query"][:80])
    print(" Positives:", r["positive_ids"])
    print(" Retrieved top ids:", r["retrieved_ids"][:8])
    print()


BM25 retrieval evaluation summary:
  n_queries: 100
  MRR: 0.8853333333333333
  Recall@1: 0.84
  Recall@5: 0.95
  Precision@1: 0.84
  Precision@5: 0.21599999999999964
  latency_mean_s: 0.00026865243911743163
  latency_median_s: 0.0002652406692504883

Total misses: 5 / 100. Showing up to 5 misses:
Query id: q007 Query: کووڈ-19 ویکسین کا بنیادی مقصد کیا ہے؟
 Positives: ['p0007']
 Retrieved top ids: ['p0028', 'p0050', 'p0051', 'p0027', 'p0039']

Query id: q019 Query: وینٹیلیشن وبا کے دوران کیوں اہم ہے؟
 Positives: ['p0020']
 Retrieved top ids: ['p0017', 'p0060', 'p0031', 'p0048', 'p0027']

Query id: q038 Query: ویکسین سائیڈ ایفیکٹس کی نگرانی کیسے کی جاتی ہے؟
 Positives: ['p0039']
 Retrieved top ids: ['p0058', 'p0040', 'p0032', 'p0051', 'p0011']

Query id: q065 Query: ویکسین کی سائیڈ ایفیکٹس کی رپورٹنگ کیسے ہوتی ہے؟
 Positives: ['p0039']
 Retrieved top ids: ['p0058', 'p0047', 'p0032', 'p0022', 'p0025']

Query id: q095 Query: وبا کے دوران معاشی بحالی کے لیے کون سے اقدامات کیے جا سکتے ہیں؟
 

The sparse (BM25) retriever performs strongly on this evaluation set (MRR ≈ 0.885, Recall@1 = 0.84, Recall@5 = 0.95), so we can move on to the dense retriever model.

In [8]:
# Cell 6: Dense embeddings with a multilingual model (use a compact model for Colab)
# paraphrase-multilingual-MiniLM-L12-v2 is a compact multilingual SBERT model that supports Urdu reasonably.
embed_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embedder = SentenceTransformer(embed_model_name)

# Encode every passage (encode() batches internally), then L2-normalize the vectors in
# place so that inner product equals cosine similarity.
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)
faiss.normalize_L2(passage_embeddings)

# Exact (flat) inner-product index over the normalized passage vectors.
d = passage_embeddings.shape[1]
index = faiss.IndexFlatIP(d)
index.add(passage_embeddings)

# Map FAISS row positions back to passages: vectors were added in corpus_texts order,
# which matches corpus_ids (both built from passages_min in Cell 5).
def dense_retrieve(query, k=5):
    """Return the top-k passages for `query` by cosine similarity.

    The query is embedded with the global `embedder`, L2-normalized, and searched in
    the global inner-product FAISS `index` (whose passage vectors are also normalized).

    Args:
        query: query text.
        k: number of passages wanted; clamped to the number of indexed passages.

    Returns:
        list of (passage_id, passage_text, score) tuples, best first. It may be shorter
        than k and is never padded with placeholder hits.
    """
    # FAISS pads results with label -1 when k exceeds the index size.
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for idx, score in zip(I[0], D[0]):
        # -1 means "no result"; corpus_ids[-1] would silently return the last passage
        if idx < 0:
            continue
        results.append((corpus_ids[idx], corpus_texts[idx], float(score)))
    return results

# Quick smoke test: dense retriever's top-3 passages for the first eval query
print(f"Dense top-3: {dense_retrieve(eval_queries[0]['query'], k=3)}")


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/645 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/471M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Dense top-3: [('p0024', 'بچوں میں کووڈ-19 عام طور پر ہلکا ہوتا ہے مگر بعض نادر معاملات میں شدید علامات سامنے آ سکتی ہیں؛ بچوں کے لیے مخصوص رہنمائی مختلف ہو سکتی ہے۔', 0.7648559808731079), ('p0021', 'کووڈ-19 کے بعد بعض افراد میں طویل مدتی علامات (Long COVID) جیسے تھکن، سانس کی تکلیف اور دماغی دھند برقرار رہ سکتی ہیں؛ ریہیب پروگرامز مدد دیتے ہیں۔', 0.6735702753067017), ('p0036', 'کووڈ-19 کے مریضوں میں خون جمنے کے مسائل اور دیگر پیچیدگیاں بعض اوقات سامنے آئیں، اس لیے طبی نگرانی اور مناسب علاج ضروری ہے۔', 0.6472983360290527)]


In [12]:
# Cell 6b: Evaluation of dense retriever (run after Cell 6)
# Purpose: measure Recall@1, Recall@5, MRR, Precision@k, latency for dense_retrieve
# Uses eval_queries with "positive_ids" and "gold_answer" fields

import json, time, re, statistics

OUT_JSONL_DENSE = "dense_eval_results.jsonl"
DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def normalize_text(s):
    """Strip `s` and collapse internal whitespace runs to single spaces; None -> ""."""
    if s is None:
        return ""
    return re.sub(r"\s+", " ", str(s).strip())

def get_query_text(item):
    """Return an eval item's query text ("query", then "question", then "q"), or ""."""
    return item.get("query") or item.get("question") or item.get("q") or ""

def evaluate_dense(eval_items, out_jsonl=OUT_JSONL_DENSE, k=DEFAULT_K,
                   recall_ks=RECALL_KS, precision_ks=PRECISION_KS):
    """Evaluate `dense_retrieve` on eval queries labelled with relevant passage ids.

    Accepts the same item fields as the BM25 harness: relevant ids from "positive_ids"
    (list or single id) or "positive_id"; "gold_answer"/"answer" is whitespace-normalized
    and recorded only, not used for scoring.

    Args:
        eval_items: list of eval dicts.
        out_jsonl: path the per-query records are written to, one JSON object per line.
        k: cutoff used for MRR.
        recall_ks / precision_ks: cutoffs for Recall@k / Precision@k. Retrieval depth is
            max(k, *recall_ks, *precision_ks) so larger cutoffs are not silently truncated.

    Returns:
        (summary, per_query): summary has n_queries, MRR, Recall@*, Precision@*,
        latency_mean_s and latency_median_s; per_query is the list of records written to
        out_jsonl. With no eval items, n_queries is 0 and every metric is 0.0.
    """
    # Retrieve deep enough for every cutoff; MRR still only looks at the top k.
    depth = max([k, *recall_ks, *precision_ks])
    per_query = []
    latencies, rr_list = [], []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}

    for item in eval_items:
        q = get_query_text(item)
        pos_ids = item.get("positive_ids") or item.get("positive_id") or []
        if isinstance(pos_ids, str):
            pos_ids = [pos_ids]  # a bare string is a single id
        pos_ids = [str(x) for x in pos_ids]
        pos_set = set(pos_ids)

        gold_text = normalize_text(item.get("gold_answer") or item.get("answer") or "")

        t0 = time.perf_counter()  # monotonic, high-resolution timer for latency
        hits = dense_retrieve(q, k=depth)  # (id, text, score)
        latency = time.perf_counter() - t0
        latencies.append(latency)

        retrieved_ids = [h[0] for h in hits]
        retrieved_texts = [h[1] for h in hits]

        # Reciprocal rank of the first relevant passage within the top k (0.0 if none)
        rr = next((1.0 / rank
                   for rank, pid in enumerate(retrieved_ids[:k], start=1)
                   if pid in pos_set), 0.0)
        rr_list.append(rr)

        # Recall@k (hit-rate) and Precision@k = (# positives in top-k) / k
        for rk in recall_ks:
            if any(pid in pos_set for pid in retrieved_ids[:rk]):
                recall_counts[rk] += 1
        for pk in precision_ks:
            num_pos_in_topk = sum(1 for pid in retrieved_ids[:pk] if pid in pos_set)
            precision_sums[pk] += num_pos_in_topk / pk

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": pos_ids,
            "gold_text": gold_text,
            "retrieved_ids": retrieved_ids,
            "retrieved_texts_preview": [txt[:300] for txt in retrieved_texts],
            "reciprocal_rank": rr,
            "latency": latency
        })

    total = len(per_query)
    n = max(total, 1)  # avoid division by zero; metrics stay 0.0 for an empty eval set
    summary = {
        "n_queries": total,
        "MRR": sum(rr_list)/n,
        **{f"Recall@{rk}": recall_counts[rk]/n for rk in recall_ks},
        **{f"Precision@{pk}": precision_sums[pk]/n for pk in precision_ks},
        "latency_mean_s": statistics.mean(latencies) if latencies else 0.0,
        "latency_median_s": statistics.median(latencies) if latencies else 0.0
    }

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for r in per_query:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    return summary, per_query

# Run evaluation
print("[dense_eval] Running dense retriever evaluation...")
summary_dense, records_dense = evaluate_dense(eval_queries, out_jsonl=OUT_JSONL_DENSE, k=DEFAULT_K)

print("\nDense retriever evaluation summary:")
for k, v in summary_dense.items():
    print(f"  {k}: {v}")

# Show the first few per-query records
print("\nExamples (first 5):")
for r in records_dense[:5]:
    print(f" - Query: {r['query'][:80]}")
    print(f"   Retrieved ids: {r['retrieved_ids'][:6]}")
    print(f"   Reciprocal rank: {r['reciprocal_rank']} Latency(s): {round(r['latency'], 4)}")
    print()


[dense_eval] Running dense retriever evaluation...

Dense retriever evaluation summary:
  n_queries: 100
  MRR: 0.8141666666666667
  Recall@1: 0.72
  Recall@5: 0.94
  Precision@1: 0.72
  Precision@5: 0.21199999999999963
  latency_mean_s: 0.009402217864990235
  latency_median_s: 0.009119868278503418

Examples (first 5):
 - Query: کووڈ-19 کی عام علامات کیا ہیں؟
   Retrieved ids: ['p0024', 'p0021', 'p0001', 'p0036', 'p0044']
   Reciprocal rank: 0.3333333333333333 Latency(s): 0.0113

 - Query: کووڈ-19 کی تشخیص کے لیے کون سا ٹیسٹ عام طور پر استعمال ہوتا ہے؟
   Retrieved ids: ['p0002', 'p0018', 'p0040', 'p0024', 'p0036']
   Reciprocal rank: 1.0 Latency(s): 0.0096

 - Query: ہاتھوں کی صفائی وبا کے دوران کیوں ضروری ہے؟
   Retrieved ids: ['p0030', 'p0003', 'p0042', 'p0041', 'p0020']
   Reciprocal rank: 0.5 Latency(s): 0.0088

 - Query: ماسک پہننے کے کیا فوائد ہیں؟
   Retrieved ids: ['p0004', 'p0029', 'p0042', 'p0022', 'p0046']
   Reciprocal rank: 1.0 Latency(s): 0.0094

 - Query: سماجی فاصلہ رک

In [10]:
# Cell 7: Prepare InputExamples for sentence-transformers fine-tuning i.e. of dense retriever model
# Now with an 80/20 train/validation split

from sentence_transformers import InputExample
import random

# Fixed seed so the random fallback negatives and the 80/20 split are reproducible across
# runs (previously the unseeded global `random` state was used).
SEED = 42
rng = random.Random(SEED)

pid2text = {p["id"]: p["text"] for p in passages_min}

# Index hard-negative records by query_id once instead of scanning the whole list for
# every synthetic pair; setdefault keeps the first record per id, like the original next().
hardneg_by_qid = {}
for h in hard_negatives:
    hardneg_by_qid.setdefault(h["query_id"], h)

examples = []
for s in synthetic_pairs:
    q = s["query"]
    pos = pid2text.get(s["positive_id"])
    neg = None
    # Find hard negatives if available (looked up by synthetic_id, else query_id)
    hn = hardneg_by_qid.get(s.get("synthetic_id", s.get("query_id")))
    if hn:
        # Take the first listed hard negative that is not the positive itself
        for nid in hn["hard_negatives"]:
            if nid != s["positive_id"]:
                neg = pid2text.get(nid)
                break
    if neg is None:
        # fallback: random negative
        neg_id = rng.choice([pid for pid in corpus_ids if pid != s["positive_id"]])
        neg = pid2text[neg_id]
    if pos and neg:
        examples.append(InputExample(texts=[q, pos, neg]))

print("Prepared", len(examples), "triplet examples.")

# --- Split into train/validation (80/20) ---
rng.shuffle(examples)
split_idx = int(0.8 * len(examples))
train_examples = examples[:split_idx]
# NOTE: val_examples is not consumed by Cell 8, which validates with an
# InformationRetrievalEvaluator built from eval_queries instead.
val_examples = examples[split_idx:]

print("Train examples:", len(train_examples))
print("Validation examples:", len(val_examples))


Prepared 500 triplet examples.
Train examples: 400
Validation examples: 100


In [11]:
# Cell 8 (use in-memory model; do NOT reload): Fine-tune SBERT with triplet loss and IR validation on passages_min
import os
# --- Disable Weights & Biases logging during fit() ---
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"
# -----------------------------------------------------

import faiss
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, evaluation

# Sanity checks
assert isinstance(train_examples, list) and len(train_examples) > 0, "train_examples must be a non-empty list"
assert 'passages_min' in globals(), "passages_min must be loaded"
assert 'eval_queries' in globals(), "eval_queries must be loaded"

# Build validation split against real corpus & labels
# (Use eval_queries_val if it exists, otherwise the last 20% of eval_queries.)
# NOTE: those queries are also part of the full set scored by Cell 6b, so the validation
# numbers printed during training are not independent of the Cell 6b results.
eval_val = eval_queries_val if 'eval_queries_val' in globals() else eval_queries[int(0.8*len(eval_queries)):]

val_queries_dict = {it["query_id"]: it["query"] for it in eval_val}
# Ensure positive_ids is a list (a bare string is a single id)
val_relevant_dict = {it["query_id"]: set(it["positive_ids"] if isinstance(it["positive_ids"], list) else [it["positive_ids"]]) for it in eval_val}
val_corpus_dict = {p["id"]: p["text"] for p in passages_min}

# Warn if labels reference missing ids
missing = []
for qid, rels in val_relevant_dict.items():
    for pid in rels:
        if pid not in val_corpus_dict:
            missing.append((qid, pid))
if missing:
    print(f"Warning: {len(missing)} relevant ids not found in corpus. Example:", missing[:3])

# Construct evaluator (defaults to cosine similarity)
retrieval_evaluator = evaluation.InformationRetrievalEvaluator(
    queries=val_queries_dict,
    corpus=val_corpus_dict,
    relevant_docs=val_relevant_dict,
    name="val_ir_passages"
)

# Use the GPU when one is attached, otherwise fall back to CPU instead of crashing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Start from baseline multilingual MiniLM
# We use the variable 'embedder' from Cell 6 to ensure we continue correctly
if 'embedder' not in globals():
    embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
embedder.to(DEVICE)

# Triplet loss with conservative settings
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.TripletLoss(
    model=embedder,
    distance_metric=losses.TripletDistanceMetric.COSINE,
    triplet_margin=0.3
)

num_epochs = 2
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)  # 10% of training steps
optimizer_params = {'lr': 2e-5}

print("Starting fine-tuning (WandB Disabled)...")

# Train with IR evaluator
embedder.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=retrieval_evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params=optimizer_params,
    show_progress_bar=True,
    output_path="fine_tuned_sbert_urdu_passages"
)

print("✅ Fine-tuning complete. Using in-memory fine-tuned 'embedder' (no reload).")

# Re-encode the corpus with the fine-tuned weights and rebuild the FAISS index.
# dense_retrieve() searches the global `index`; without this step, re-running Cell 6b
# would compare fine-tuned query vectors against passage vectors from the baseline model.
corpus_texts = [p["text"] for p in passages_min]
corpus_ids = [p["id"] for p in passages_min]
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)
faiss.normalize_L2(passage_embeddings)
index = faiss.IndexFlatIP(passage_embeddings.shape[1])
index.add(passage_embeddings)
print(f"🔄 FAISS index rebuilt with fine-tuned passage embeddings ({index.ntotal} passages).")

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting fine-tuning (WandB Disabled)...


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss,Val Ir Passages Cosine Accuracy@1,Val Ir Passages Cosine Accuracy@3,Val Ir Passages Cosine Accuracy@5,Val Ir Passages Cosine Accuracy@10,Val Ir Passages Cosine Precision@1,Val Ir Passages Cosine Precision@3,Val Ir Passages Cosine Precision@5,Val Ir Passages Cosine Precision@10,Val Ir Passages Cosine Recall@1,Val Ir Passages Cosine Recall@3,Val Ir Passages Cosine Recall@5,Val Ir Passages Cosine Recall@10,Val Ir Passages Cosine Ndcg@10,Val Ir Passages Cosine Mrr@10,Val Ir Passages Cosine Map@100
25,No log,No log,0.75,0.9,0.9,0.9,0.75,0.383333,0.24,0.125,0.525,0.75,0.775,0.8,0.74968,0.808333,0.712652
50,No log,No log,0.75,0.85,0.9,0.95,0.75,0.383333,0.24,0.13,0.525,0.75,0.775,0.825,0.757199,0.806667,0.713152


✅ Fine-tuning complete. Using in-memory fine-tuned 'embedder' (no reload).


In [None]:
# Cell 8b: Save the Fine-Tuned Model to Drive (Run ONLY if satisfied with accuracy)
import os

# Destination folder on the mounted Google Drive
MODEL_SAVE_PATH = "/content/drive/MyDrive/models/urdu_dense_retriever_best"
print(f"💾 Saving model to {MODEL_SAVE_PATH} ...")

# Make sure the folder exists (no error if it already does), then write the model into it
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
embedder.save(MODEL_SAVE_PATH)

print(f"✅ Model saved! You can now use Cell 8c in future sessions to skip training.")

In [None]:
# Cell 8c: FAST START - Load Model from Drive & Rebuild FAISS (Skips Training)
# Run this INSTEAD of Cells 6, 7, 8, 8b in future sessions.

import os
import faiss
from sentence_transformers import SentenceTransformer

MODEL_SAVE_PATH = "/content/drive/MyDrive/models/urdu_dense_retriever_best"

# Use the GPU when one is attached, otherwise fall back to CPU instead of crashing.
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load the Model
if os.path.exists(MODEL_SAVE_PATH):
    print(f"📂 Loading saved model from: {MODEL_SAVE_PATH}")
    embedder = SentenceTransformer(MODEL_SAVE_PATH, device=DEVICE)
    print("✅ Model loaded successfully.")
else:
    raise FileNotFoundError(f"❌ No saved model found at {MODEL_SAVE_PATH}. Please run Cell 8 & 8b once to create it!")

# 2. Rebuild FAISS Index (Critical Step)
# The index must hold embeddings from the model just loaded, so re-encode the whole corpus.
print("⏳ Generating embeddings for corpus...")
# Rebuild ids and texts together from passages_min so row i of the index always maps to
# corpus_ids[i] / corpus_texts[i], even if Cell 5 was not run in this session.
corpus_texts = [p["text"] for p in passages_min]
corpus_ids = [p["id"] for p in passages_min]

# Generate embeddings
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)

# Build FAISS (normalized vectors + inner product = cosine similarity)
faiss.normalize_L2(passage_embeddings)
index = faiss.IndexFlatIP(passage_embeddings.shape[1])
index.add(passage_embeddings)

# 3. Define the Retrieval Function
# (We must re-define this here because we skipped the previous cells that defined it)
def dense_retrieve(query, k=5):
    """Return the top-k passages for `query` by cosine similarity.

    The query is embedded with the global `embedder`, L2-normalized, and searched in the
    global inner-product FAISS `index`; row i of the index maps to corpus_ids[i] /
    corpus_texts[i].

    Args:
        query: query text.
        k: number of passages wanted; clamped to the number of indexed passages.

    Returns:
        list of (passage_id, passage_text, score) tuples, best first. It may be shorter
        than k and is never padded with placeholder hits.
    """
    # FAISS pads results with label -1 when k exceeds the index size.
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for idx, score in zip(I[0], D[0]):
        # -1 means "no result"; corpus_ids[-1] would silently return the last passage
        if idx < 0:
            continue
        results.append((corpus_ids[idx], corpus_texts[idx], float(score)))
    return results

print("✅ Dense Retriever System Restored & Ready for Hybrid Fusion (Cell 9).")

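We can now run Cell 6b again to measure how much fine-tuning improved the dense retriever. Note that `dense_retrieve` searches the global FAISS `index`, so the passage embeddings must first be recomputed with the fine-tuned model (as Cell 8c does); otherwise the fine-tuned query vectors are compared against passage vectors produced by the baseline model.
The results of our dense retriever model are satisfactory, so we move on to the hybrid retriever model.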