# Healthcare RAG
- **Dataset:** small subset of the Hugging Face MedRAG corpora (Wikipedia + PubMed)
- **Retriever:** FAISS + sentence-transformers
- **Queries:** 5–10 healthcare questions
- **Evaluator:** manually annotated gold documents (Recall@k and MRR)

In [1]:
#insatll packages
!pip install -q datasets faiss-cpu sentence-transformers pandas pyarrow

In [2]:
# Use this only if Jupyter Notebook raises a tqdm error or "LookupError: <ContextVar name='shell_parent' at 0x000001DAE336D4E0>"
# import os, warnings
# os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "1"
# os.environ["DISABLE_TQDM"] = "1"
# from datasets.utils.logging import disable_progress_bar, set_verbosity_error
# disable_progress_bar(); set_verbosity_error()
# import tqdm
# def _no_tqdm(iterable=None, *a, **k): return iterable if iterable is not None else []
# tqdm.tqdm = _no_tqdm
# try:
#     import tqdm.auto as tauto; tauto.tqdm = _no_tqdm
# except Exception:
#     pass


In [None]:
from datasets import load_dataset
import pandas as pd

# Fetch only the first N_PER_SOURCE entries from each source.
N_PER_SOURCE = 100

# Candidate passage columns, in priority order.
# NOTE(review): MedRAG corpora appear to store the passage as "content" and
# title + passage as "contents" (neither is called "text"). The previous code
# then fell back to df["text"] = "", so every document was an empty string.
# Verify against the dataset cards if the schema changes.
TEXT_COLUMNS = ("text", "contents", "content", "abstract")


def _load_source(name, n=N_PER_SOURCE):
    """Load the first `n` rows of a Hugging Face corpus as a DataFrame.

    Adds a string "text" column, copied from the first column of TEXT_COLUMNS
    that exists, and a "source" column holding `name`.

    Streaming is used because `split="train[:n]"` still downloads and prepares
    the entire corpus before slicing it.

    Raises:
        ValueError: if none of TEXT_COLUMNS is present, instead of silently
            embedding empty strings.
    """
    rows = list(load_dataset(name, split="train", streaming=True).take(n))
    frame = pd.DataFrame(rows)
    text_col = next((c for c in TEXT_COLUMNS if c in frame.columns), None)
    if text_col is None:
        raise ValueError(
            f"{name}: no text column among {TEXT_COLUMNS}; found {list(frame.columns)}"
        )
    return frame.assign(text=frame[text_col].astype(str), source=name)


df_w = _load_source("MedRAG/wikipedia")
df_p = _load_source("MedRAG/pubmed")

# Merge both sources. ignore_index yields positional row ids 0..n-1, which later
# cells use (via df.iloc) to map FAISS ids back to documents.
df = pd.concat([df_w, df_p], ignore_index=True)
df["text"] = df["text"].astype(str); df["source"] = df["source"].astype(str)

print("Docs loaded:", len(df)); df.head(2)


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: aa936f1c-62dc-437f-b45f-627aae4c5818)')' thrown while requesting HEAD https://huggingface.co/datasets/MedRAG/wikipedia/resolve/d76b9ad82135e352235d17e75921c49b68fd07b2/chunk/wiki20220301en139.jsonl
Retrying in 1s [Retry 1/5].


In [None]:

from sentence_transformers import SentenceTransformer
import numpy as np, faiss, textwrap, json, pathlib

texts = df["text"].tolist()

# Load a small pre-trained sentence-embedding model from Hugging Face.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Normalized embeddings make inner product equal to cosine similarity, which is
# why an inner-product (IP) index is used below.
X = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
# FAISS only accepts float32, C-contiguous arrays; make that explicit.
X = np.ascontiguousarray(X, dtype="float32")
# FAISS ids are positional (0..n-1); retrieve() maps them back with df.iloc,
# so the embedding rows must line up one-to-one with df rows.
assert X.shape[0] == len(df), "embedding count must match df rows so FAISS ids map to df.iloc"

# Exact (brute-force) inner-product FAISS index over all documents.
index = faiss.IndexFlatIP(X.shape[1])
index.add(X)

#search helper function
def retrieve(q, topk=5):
    """Return up to `topk` documents most similar to the query `q`.

    Uses the notebook-level `model` (query encoder), `index` (FAISS index built
    over the rows of `df`) and `df` (whose row positions are the FAISS ids).

    Args:
        q: query string.
        topk: maximum number of hits; clamped to the number of indexed docs.

    Returns:
        List of dicts, best first, with keys:
        - rank: 1-based position.
        - idx: row position in df.
        - source
        - preview: text shortened to 220 characters.
        - score: inner product with the query embedding.
        The list is empty when the index is empty or topk <= 0.
    """
    k = min(int(topk), index.ntotal)
    if k <= 0:
        return []
    qv = model.encode([q], normalize_embeddings=True, convert_to_numpy=True)
    scores, ids = index.search(qv, k)  # each of shape (1, k)
    hits = []
    for idx, score in zip(ids[0], scores[0]):
        # FAISS pads missing results with id -1; df.iloc[-1] would silently
        # return the last document instead.
        if idx < 0:
            continue
        row = df.iloc[int(idx)]
        hits.append({
            "rank": len(hits) + 1,
            "idx": int(idx),
            "source": row["source"],
            "preview": textwrap.shorten(row["text"].replace("\n", " "), width=220),
            "score": float(score),
        })
    return hits


# Five healthcare test questions as (query id, query text) pairs. The ids are
# Q01..Q05 and must match the "qid" values used in the gold file.
queries = [
    (f"Q{number:02d}", question)
    for number, question in enumerate(
        [
            "First-line antibiotic for community-acquired pneumonia in adults",
            "Imaging test of choice to confirm deep vein thrombosis",
            "Diagnostic criteria for type 2 diabetes mellitus",
            "Standard therapy for Helicobacter pylori infection",
            "Contraindications of ACE inhibitors",
        ],
        start=1,
    )
]

# Run every query through the retriever, print its ranked hits, and collect
# them in all_runs (qid -> list of hit dicts) for the evaluation cell.
all_runs = {}
for query_id, query_text in queries:
    ranked = retrieve(query_text, topk=5)
    all_runs[query_id] = ranked
    print("\n=== ", query_id, ":", query_text, " ===")
    for hit in ranked:
        print(f"[{hit['rank']}] {hit['source']} | score={hit['score']:.3f} | {hit['preview']}")

# Persist one JSON file per query so runs can be inspected or re-scored later.
runs_dir = pathlib.Path("runs_faiss_min")
runs_dir.mkdir(exist_ok=True)
for query_id, ranked in all_runs.items():
    payload = {"qid": query_id, "results": ranked}
    with (runs_dir / f"{query_id}.json").open("w") as fh:
        json.dump(payload, fh, indent=2)
print("\nSaved per-QID JSON to", runs_dir.resolve())


In [None]:

# Simple "gold standard": manually mark which docs (row positions in df, which
# are also the FAISS ids) truly answer each question.
import json
import pathlib

GOLD_PATH = pathlib.Path("gold_min.jsonl")

# Template: fill gold_idxs with the doc indices that truly answer each question.
gold = [
    {"qid": "Q01", "gold_idxs": [], "note": ""},
    {"qid": "Q02", "gold_idxs": [], "note": ""},
    {"qid": "Q03", "gold_idxs": [], "note": ""},
    {"qid": "Q04", "gold_idxs": [], "note": ""},
    {"qid": "Q05", "gold_idxs": [], "note": ""},
]


def write_gold_template(path, rows, overwrite=False):
    """Write `rows` to `path` as JSON Lines, one object per line.

    An existing file is left untouched unless `overwrite` is true. Without this,
    re-running the notebook (e.g. Restart & Run All) would wipe the manually
    filled gold_idxs.

    Args:
        path: destination file path (str or Path).
        rows: iterable of JSON-serializable dicts.
        overwrite: replace an existing file when True.

    Returns:
        True if the file was written, False if an existing file was kept.
    """
    path = pathlib.Path(path)
    if path.exists() and not overwrite:
        return False
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")
    return True


if write_gold_template(GOLD_PATH, gold):
    print("Gold template written to gold_min.jsonl")
else:
    print("gold_min.jsonl already exists; keeping your annotations (pass overwrite=True to reset)")


# Functions to load gold + compute Recall@5 and MRR
def load_gold(path="gold_min.jsonl"):
    """Load manual relevance judgements from a JSON Lines file.

    Each non-blank line must be a JSON object with at least "qid" and
    "gold_idxs" keys, the shape the template cell writes.

    Args:
        path: path to the JSONL file.

    Returns:
        Dict mapping qid to the set of gold doc indices listed for it.
    """
    g = {}
    with open(path, encoding="utf-8") as f:  # closed even if a line fails to parse
        for line in f:
            if not line.strip():
                continue  # tolerate trailing or hand-inserted blank lines
            row = json.loads(line)
            g[row["qid"]] = set(row["gold_idxs"])
    return g

def recall_at_k(runs, gold, k=5):
    """Hit-rate style Recall@k over the annotated queries.

    A query counts as a hit when at least one of its gold doc indices appears
    among its first `k` results. This is the "success@k" form often used for
    RAG retrievers. It equals classic recall (the fraction of gold docs
    retrieved) only when each query has a single gold doc.

    Args:
        runs: dict mapping qid to a ranked list of hit dicts, each with an
            "idx" key (as produced by retrieve()).
        gold: dict mapping qid to a set of gold doc indices (as returned by
            load_gold()). Queries with an empty set are not yet annotated and
            are skipped.
        k: rank cutoff.

    Returns:
        Fraction of annotated queries that are hits, or 0.0 if none are
        annotated. An annotated query with no entry in `runs` counts as a miss.
    """
    judged = [qid for qid, gold_set in gold.items() if gold_set]
    if not judged:
        return 0.0
    hits = sum(
        1
        for qid in judged
        if gold[qid] & {h["idx"] for h in runs.get(qid, [])[:k]}
    )
    return hits / len(judged)

def mrr(runs, gold):
    """Mean reciprocal rank of the first gold doc, over the annotated queries.

    A query's reciprocal rank is 1/r, where r is the 1-based position of its
    first gold doc in runs[qid], or 0 if no gold doc was retrieved. Only the
    hits stored in `runs` are considered, so runs produced with topk=5 give
    MRR@5.

    Args:
        runs: dict mapping qid to a ranked list of hit dicts with an "idx" key.
        gold: dict mapping qid to a set of gold doc indices. Queries with an
            empty set are not yet annotated and are skipped.

    Returns:
        Mean reciprocal rank, or 0.0 if no query is annotated. An annotated
        query with no entry in `runs` contributes 0.
    """
    total = 0.0
    n_judged = 0
    for qid, gold_set in gold.items():
        if not gold_set:
            continue  # not annotated yet
        n_judged += 1
        ranked = runs.get(qid, [])  # missing run counts as reciprocal rank 0
        first = next(
            (pos for pos, h in enumerate(ranked, start=1) if h["idx"] in gold_set),
            None,
        )
        if first is not None:
            total += 1.0 / first
    return total / max(n_judged, 1)

# Usage hint for the reader: how to score the runs once gold_min.jsonl is annotated.
print(
    "After you fill gold_min.jsonl with idx values, run:",
    "g = load_gold(); print('Recall@5', recall_at_k(all_runs,g,5)); print('MRR', mrr(all_runs,g))",
    sep="\n",
)
