# Healthcare RAG
Dataset: small subset of Hugging Face
MedRAG (Wikipedia + PubMed)
Retriever: FAISS + sentence-transformers
Queries: 5–10 healthcare questions
Evaluator: manual gold (Recall@k)

In [1]:
#insatll packages
!pip install -q datasets faiss-cpu sentence-transformers pandas pyarrow

In [2]:
# Use this only if jupyter notebook gives you tqdm error or "LookupError: <ContextVar name='shell_parent' at 0x000001DAE336D4E0>""
# import os, warnings
# os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "1"
# os.environ["DISABLE_TQDM"] = "1"
# from datasets.utils.logging import disable_progress_bar, set_verbosity_error
# disable_progress_bar(); set_verbosity_error()
# import tqdm
# def _no_tqdm(iterable=None, *a, **k): return iterable if iterable is not None else []
# tqdm.tqdm = _no_tqdm
# try:
#     import tqdm.auto as tauto; tauto.tqdm = _no_tqdm
# except Exception:
#     pass


In [4]:
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm

# Fetch the data
wiki = load_dataset("MedRAG/wikipedia", split="train", streaming=True)
pubm = load_dataset("MedRAG/pubmed", split="train", streaming=True)

# first 3000 total records
wiki_sample = [next(iter(wiki)) for _ in tqdm(range(3000), desc="Fetching Wikipedia")]
pubm_sample = [next(iter(pubm)) for _ in tqdm(range(3000), desc="Fetching PubMed")]

df_w = pd.DataFrame(wiki_sample)
df_p = pd.DataFrame(pubm_sample)
print("Sample size:", len(df_w))

# Normalize text column names
if "text" not in df_w.columns: df_w["text"] = ""
if "text" not in df_p.columns:
    if "abstract" in df_p.columns: df_p = df_p.rename(columns={"abstract":"text"})
    else: df_p["text"] = ""

# # MED-QA (non-streaming)
# mediqa = load_dataset("bigbio/med_qa", split="train[:500]")  # small slice
# df_mq = pd.DataFrame(mediqa)
# # pick a reasonable context field or fallback to Q+A
# df_mq["text"] = (
#     df_mq.get("context")
#     or df_mq.get("passage")
#     or (df_mq.get("question", "").astype(str) + " " + df_mq.get("answer", "").astype(str))
# )
# df_mq = df_mq[["text"]].copy()

# PubMedQA (non-streaming)
pmqa = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train[:500]")
df_pmqa = pd.DataFrame(pmqa)
# map context/abstract into 'text', fallback to Q + long_answer
if "context" in df_pmqa.columns:
    df_pmqa = df_pmqa.rename(columns={"context": "text"})
elif "abstract" in df_pmqa.columns:
    df_pmqa = df_pmqa.rename(columns={"abstract": "text"})
else:
    df_pmqa["text"] = (
        df_pmqa.get("question", "").astype(str) + " " + df_pmqa.get("long_answer", "").astype(str)
    )
df_pmqa = df_pmqa[["text"]].copy()


# Add 'source' column to track where text came from
df_w["source"] = "MedRAG/wikipedia"
df_p["source"] = "MedRAG/pubmed"
# df_mq["source"] = "MEDQA"
df_pmqa["source"] = "PubMedQA"

# Merge, clean, quick stats
# df = pd.concat(
#     [df_w[["text","source"]], df_p[["text","source"]], df_mq[["text","source"]], df_pmqa[["text","source"]]],
#     ignore_index=True
# )

df = pd.concat(
    [df_w[["text","source"]], df_p[["text","source"]], df_pmqa[["text","source"]]],
    ignore_index=True
)

df["text"] = df["text"].astype(str)
df = df[df["text"].str.len() > 0].drop_duplicates(subset=["text", "source"]).reset_index(drop=True)

print("Docs loaded:", len(df))
print(df["source"].value_counts())
df.head(3)



Resolving data files:   0%|          | 0/646 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1166 [00:00<?, ?it/s]

Fetching Wikipedia: 100%|██████████| 3000/3000 [13:19<00:00,  3.75it/s]
Fetching PubMed: 100%|██████████| 3000/3000 [13:17<00:00,  3.76it/s]


Sample size: 3000


pqa_labeled/train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Docs loaded: 500
source
PubMedQA    500
Name: count, dtype: int64


Unnamed: 0,text,source
0,{'contexts': ['Programmed cell death (PCD) is ...,PubMedQA
1,{'contexts': ['Assessment of visual acuity dep...,PubMedQA
2,{'contexts': ['Apparent life-threatening event...,PubMedQA


In [7]:

from sentence_transformers import SentenceTransformer
import numpy as np, faiss, textwrap, json, pathlib

texts = df["text"].tolist()

# Load a small pre-trained model from Hugging Face
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
X = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)

# Create a FAISS index
index = faiss.IndexFlatIP(X.shape[1])
index.add(X)

#search helper function
def retrieve(q, topk=5):
    qv = model.encode([q], normalize_embeddings=True, convert_to_numpy=True)
    D, I = index.search(qv, topk) # D=scores, I=indices
    out = []
    for rank, (idx, score) in enumerate(zip(I[0], D[0]), start=1):
        out.append({"rank":rank,
                    "idx":int(idx),
                    "source":df.iloc[idx]["source"],
                    "preview":textwrap.shorten(df.iloc[idx]["text"].replace("\n"," "), width=220),
                    "score":float(score)})
    return out


# Define 15 medical test questions
queries = [
    ("Q01","First-line antibiotic for community-acquired pneumonia in adults"),
    ("Q02","Imaging test of choice to confirm deep vein thrombosis"),
    ("Q03","Diagnostic criteria for type 2 diabetes mellitus"),
    ("Q04","Standard therapy for Helicobacter pylori infection"),
    ("Q05","Contraindications of ACE inhibitors"),
    ("Q06", "Management of hypertensive emergency"),
    ("Q07", "Screening test for colorectal cancer"),
    ("Q08", "Treatment of anaphylaxis"),
    ("Q09", "Complications of untreated hypothyroidism"),
    ("Q10", "Adverse effects of corticosteroids")
]



# Retrieve results for above questions
all_runs = {}
for qid, q in queries:
    hits = retrieve(q, topk=5); all_runs[qid] = hits
    print("\n=== ", qid, ":", q, " ===")
    for h in hits: print(f"[{h['rank']}] {h['source']} | score={h['score']:.3f} | {h['preview']}")

runs_dir = pathlib.Path("runs_faiss_min"); runs_dir.mkdir(exist_ok=True)
for qid, hits in all_runs.items():
    with open(runs_dir / f"{qid}.json","w") as f: json.dump({"qid":qid,"results":hits}, f, indent=2)
print("\nSaved per-QID JSON to", runs_dir.resolve())


Batches:   0%|          | 0/16 [00:00<?, ?it/s]


===  Q01 : First-line antibiotic for community-acquired pneumonia in adults  ===
[1] PubMedQA | score=0.478 | {'contexts': ['We examined whether invasive lung-specimen collection-to-treatment times for intensive care unit patients with suspected ventilator-associated pneumonia (VAP) differ with to the work shift during [...]
[2] PubMedQA | score=0.474 | {'contexts': ['To determine the effect of the 2008 English public antibiotic campaigns.', 'English and Scottish (acting as controls) adults aged>or = 15 years were questioned face to face about their attitudes to [...]
[3] PubMedQA | score=0.424 | {'contexts': ["Little is known about the validity and reliability of expert assessments of the quality of antimicrobial prescribing, despite their importance in antimicrobial stewardship. We investigated how [...]
[4] PubMedQA | score=0.382 | {'contexts': ['Current guidelines for the treatment of uncomplicated urinary tract infection (UTI) in women recommend empiric therapy with antibiotics f

In [6]:

# simple "gold standard" where we manually mark which docs are correct
import json

#template
#fill gold_idxs (the doc indices that truly answer each question)
gold = [
    {"qid":"Q01","gold_idxs":[],"note":""},
    {"qid":"Q02","gold_idxs":[],"note":""},
    {"qid":"Q03","gold_idxs":[],"note":""},
    {"qid":"Q04","gold_idxs":[],"note":""},
    {"qid":"Q05","gold_idxs":[],"note":""}
]

#save template file
with open("gold_min.jsonl","w") as f:
    for row in gold: f.write(json.dumps(row)+"\n")
print("Gold template written to gold_min.jsonl")


# Functions to load gold + compute Recall@5 and MRR
def load_gold(path="gold_min.jsonl"):
    g={}
    for line in open(path):
        j=json.loads(line); g[j["qid"]]=set(j["gold_idxs"])
    return g

def recall_at_k(runs, gold, k=5):
    num=den=0
    for qid, gold_set in gold.items():
        if not gold_set: continue
        den+=1
        topk={h["idx"] for h in runs[qid][:k]}
        if gold_set & topk: num+=1
    return num/max(den,1)

def mrr(runs, gold):
    total=n=0
    for qid, gold_set in gold.items():
        if not gold_set: continue
        n+=1; rr=0.0
        for i,h in enumerate(runs[qid], start=1):
            if h["idx"] in gold_set: rr=1.0/i; break
        total+=rr
    return total/max(n,1)

print("After you fill gold_min.jsonl with idx values, run:")
print("g = load_gold(); print('Recall@5', recall_at_k(all_runs,g,5)); print('MRR', mrr(all_runs,g))")


Gold template written to gold_min.jsonl
After you fill gold_min.jsonl with idx values, run:
g = load_gold(); print('Recall@5', recall_at_k(all_runs,g,5)); print('MRR', mrr(all_runs,g))
