## **0. LOAD LIBRARIES**

In [None]:
# install the HuggingFace datasets library
!pip install datasets



In [None]:
# import libraries
from datasets import load_dataset
from collections import defaultdict
from typing import Dict, Any, List, Optional
import os
import json
import random


In [None]:
# Mount Google Drive (for Colab).
# `force_remount=True` makes this cell safe to re-run: Colab drops any existing
# mount and mounts again. This replaces the manual `fusermount -u` + `rm -rf`
# sequence, which is risky: if the unmount fails while Drive is still attached,
# `rm -rf /content/drive` would recurse into the mounted Drive contents.
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

fusermount: failed to unmount /content/drive: No such file or directory
Mounted at /content/drive


## **1. LOAD AND SAVE DATA TO DRIVE (MedQA, PubMedQA, MedQuAD, emrQA)**

In [None]:
def make_source_name(dataset_path: str, config: Optional[str]) -> str:
    """Build the local source name used for a dataset's JSONL file.

    Args:
        dataset_path: Hub-style dataset path such as ``"owner/name"``. Only the
            last ``/``-separated component is used.
        config: Optional configuration name. When truthy it is appended to the
            base name with an underscore.

    Returns:
        ``"<base>_<config>"`` when ``config`` is truthy, otherwise ``"<base>"``.
    """
    # Strip trailing slashes first so "owner/name/" still yields "name"
    # instead of an empty base (which would produce a file called ".jsonl").
    base = dataset_path.rstrip("/").split("/")[-1]
    return f"{base}_{config}" if config else base

In [None]:
# 1. Dataset sources as (hub path, optional config name) pairs
DATASETS = [
    # Med QA
    ("truehealth/medqa", None),
    ("Eladio/emrqa-msquad", None),
    ("qiaojin/PubMedQA", "pqa_labeled"),
    ("lavita/MedQuAD", None)
]

# 2. Destination folder on Google Drive for the raw dumps
BASE_DIR = "/content/drive/MyDrive/data_source/raw_data"
os.makedirs(BASE_DIR, exist_ok=True)

# 3. For every source, concatenate all of its splits into a single JSONL file
for dataset_path, config in DATASETS:
    config_tag = f"[{config}]" if config else ""
    print(f"Loading {dataset_path} {config_tag} ...")

    # Only pass the config positionally when one is given.
    load_args = (dataset_path, config) if config else (dataset_path,)
    ds = load_dataset(*load_args)

    source_name = make_source_name(dataset_path, config)
    save_path = os.path.join(BASE_DIR, source_name + ".jsonl")

    count = 0
    with open(save_path, "w", encoding="utf-8") as f:
        # Walk the splits (train, validation, test, ...) in the order the
        # loaded dataset exposes them; one JSON object per output line.
        for split, split_rows in ds.items():
            for ex in split_rows:
                f.write(json.dumps(ex, ensure_ascii=False) + "\n")
                count += 1

    print(f"Saved {source_name}: {count} examples -> {save_path}")

print("Done: one JSONL per source dataset, all splits merged.")

Loading truehealth/medqa  ...
Saved medqa: 12723 examples -> /content/drive/MyDrive/data_source/raw_data/medqa.jsonl
Loading Eladio/emrqa-msquad  ...
Saved emrqa-msquad: 163695 examples -> /content/drive/MyDrive/data_source/raw_data/emrqa-msquad.jsonl
Loading qiaojin/PubMedQA [pqa_labeled] ...
Saved PubMedQA_pqa_labeled: 1000 examples -> /content/drive/MyDrive/data_source/raw_data/PubMedQA_pqa_labeled.jsonl
Loading lavita/MedQuAD  ...
Saved MedQuAD: 47441 examples -> /content/drive/MyDrive/data_source/raw_data/MedQuAD.jsonl
Done: one JSONL per source dataset, all splits merged.


In [None]:
# 4. Print dataset sample
# Show up to 5 randomly chosen records from every raw JSONL dump.
BASE_DIR = "/content/drive/MyDrive/data_source/raw_data"
# Sort the listing so datasets are always shown in the same order
# (os.listdir returns entries in arbitrary order).
for filename in sorted(os.listdir(BASE_DIR)):
    if not filename.endswith(".jsonl"):
        continue
    file_path = os.path.join(BASE_DIR, filename)
    print(f"\n=== Dataset: {filename} ===")
    with open(file_path, "r", encoding="utf-8") as f:
        # Drop blank lines so json.loads() below never sees an empty string.
        lines = [raw_line for raw_line in f if raw_line.strip()]
    for line in random.sample(lines, min(5, len(lines))):
        sample = json.loads(line)
        print(json.dumps(sample, indent=5, ensure_ascii=False))


=== Dataset: emrqa-msquad.jsonl ===
{
     "context": "This is a 47-year-old female with a history of HIV, diabetes, questionable cerebral aneurysm, and seizure disorder who recently had two syncopal events without prodrome and without postictal state, who presented for evaluation of left arm paresthesias and chest pain, with associated diaphoresis, shortness of breath and nausea. Of note, the patient recently started Flexeril to treat chronic low back pain, was not receiving her Keppra for approximately a year, as her prescription had ran out, and was instead taking Ecotrin 81 mg daily, clonazepam 1 mg q.6 h. p.r.n., Imodium one to two tablets q.i.d. p.r.n. for diarrhea, and low-dose aspirin. The patient was started on low-dose beta-blocker and aspirin, metoprolol 12.5 b.i.d. with occasional bradycardia to the high 40's, and was treated with the Ryo Hospital Medical Center insulin protocol. The patient was restarted on Keppra 250 mg b.i.d. with a goal to increase to 500 mg b.i.d. aft

## **2. Standardized dataset (MedQA, PubMedQA, MedQuAD, emrQA)**

In [None]:
import os
import json
import random
from collections import defaultdict
from typing import List, Dict, Any, Optional

# CONFIG
# Paths and limits used by the normalization step in this cell.
IN_DIR  = "/content/drive/MyDrive/data_source/raw_data"        # Directory containing the original *.jsonl files
OUT_DIR = "/content/drive/MyDrive/data_source/unified_format"  # Directory to write the normalized files
# Seed passed to random.seed() below.
REPRO_SEED = 42
# Upper bound on the number of records written per dataset/split. The write
# step in this cell slices the first SAMPLE_RECORDS items in file order; it is
# a head-slice, not a random draw.
SAMPLE_RECORDS = 1500

# Make sure the output directory exists before anything is written to it.
os.makedirs(OUT_DIR, exist_ok=True)
# Seed Python's global `random` generator.
random.seed(REPRO_SEED)

# UTILS
def write_jsonl(path: str, items: List[Dict[str, Any]]) -> None:
    """Write dictionaries to ``path`` in JSON Lines format.

    Each item is serialized with ``json.dumps(..., ensure_ascii=False)`` and
    written on its own line, UTF-8 encoded. An existing file is overwritten.

    Args:
        path: Destination file. Missing parent directories are created.
        items: Dictionaries to serialize, in output order.
    """
    parent_dir = os.path.dirname(path)
    # os.path.dirname() returns "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for ex in items:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

def detect_dataset_name(fn: str) -> str:
    """Map a raw-data filename to the dataset key used by the normalizers.

    Matching is case-insensitive and substring-based.

    Args:
        fn: Filename (for example ``"MedQuAD.jsonl"``).

    Returns:
        One of ``"emrqa"``, ``"pubmedqa"``, ``"medquad"``, ``"medqa"``, or
        ``"unknown"`` when no supported marker is found.
    """
    f = fn.lower()
    if "emrqa-msquad" in f:
        return "emrqa"
    # "pubmedqa" contains the substring "medqa", so every PubMedQA file must be
    # settled here, before the generic "medqa" check. Only the pqa_labeled
    # export is supported; other PubMedQA files are reported as unknown rather
    # than being mislabelled as MedQA.
    if "pubmedqa" in f:
        return "pubmedqa" if "pubmedqa_pqa_labeled" in f else "unknown"
    if "medquad" in f:  # e.g. 'MedQuAD.jsonl'
        return "medquad"
    if "medqa" in f:  # e.g. 'medqa.jsonl'
        return "medqa"

    return "unknown"

# Uppercase labels "A" through "Z" as a list of single-character strings.
# NOTE(review): no code visible in this cell references LETTERS — confirm
# whether a later cell uses it before removing.
LETTERS = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

# NORMALIZERS
def norm_emrqa(ex: Dict[str,Any]) -> Dict[str,Any]:
    """
    Normalizes an emrqa-msquad record (Short Answer / Extractive).

    Args:
        ex: Raw record. Reads ``id``, ``context``, ``question`` and
            ``answers["text"]`` (a list of answer strings).

    Returns:
        A record in the unified schema:
        - 'answer' holds the first answer text ("" when there is none).
        - 'text' is a one-element list containing that same answer.
        - 'encode' is None because this is not an MCQ task.

    NOTE(review): when the record has no 'id', the fallback id is built from
    the first 20 characters of the context plus the first 20 characters of the
    question, so distinct records that share both prefixes get the same id.
    """
    # `.get(key, default)` only covers a missing key; an explicit null value
    # comes back as None, so fall back with `or` before indexing into it.
    answers = ex.get("answers") or {}
    answer_candidates = answers.get("text") or []
    answer_text = answer_candidates[0] if answer_candidates else ""

    context = ex.get("context", None)
    question = ex.get("question", "")
    # Treat null context/question as empty strings for id building only; the
    # returned 'context' and 'question' fields keep the original values.
    fallback_id = ((context or "")[:20] + (question or "")[:20]).replace(" ","_")

    return {
        "id": ex.get("id", fallback_id),
        "dataset": "emrqa",
        "split": "all",
        "context": context,
        "question": question,
        "text": [answer_text],  # A list with only one element
        "encode": None,         # No A,B,C options
        "answer": answer_text   # Save the original text for evaluation
    }

def norm_medqa(ex: Dict[str,Any]) -> Dict[str,Any]:
    """
    Convert a medqa record (multiple choice) into the unified schema.

    The 'options' mapping is read as label -> option text. Labels are sorted
    and stored in 'encode'; 'text' holds the option texts in the same order.
    'answer' is taken from 'answer_idx' (None when absent). When the record has
    no 'id', the first 40 characters of the question (spaces replaced by
    underscores) are used instead.
    """
    choices = ex.get("options", {})
    labels = sorted(choices)
    choice_texts = [choices[label] for label in labels]

    question = ex.get("question", "")
    default_id = ex.get("question", "")[:40].replace(" ","_")

    return dict(
        id=ex.get("id", default_id),
        dataset="medqa",
        split="all",
        context=None,
        question=question,
        text=choice_texts,
        encode=labels,
        answer=ex.get("answer_idx", None),
    )

def norm_pubmedqa(ex: Dict[str,Any]) -> Dict[str,Any]:
    """
    Convert a PubMedQA record (yes / no / maybe) into the unified schema.

    The task is expressed as a three-way multiple choice: A = yes, B = no,
    C = maybe. 'answer' is the letter matching 'final_decision', or None when
    the decision is missing or not one of those three values. The passages in
    ex['context']['contexts'] are joined with newlines to form 'context'.
    """
    letter_for_decision = {"yes": "A", "no": "B", "maybe": "C"}
    answer_letter = letter_for_decision.get(ex.get("final_decision"))

    # Used as the id only when the record carries no 'pubid'.
    question_slug = ex.get("question", "")[:40].replace(" ","_")
    record_id = str(ex.get("pubid", question_slug))

    passages = ex.get("context", {}).get("contexts", [])
    joined_context = "\n".join(passages)

    record = {}
    record["id"] = record_id
    record["dataset"] = "pubmedqa"
    record["split"] = "all"
    record["context"] = joined_context
    record["question"] = ex.get("question", "")
    record["text"] = ["yes", "no", "maybe"]
    record["encode"] = ["A", "B", "C"]
    record["answer"] = answer_letter
    return record

def norm_medquad(ex: Dict[str,Any]) -> Optional[Dict[str,Any]]:
    """
    Convert a MedQuAD record (Short Answer / Generative) into the unified schema.

    Returns None when the record's 'answer' is None, so the caller can drop it.
    Otherwise 'answer' holds the answer text, 'text' is a one-element list with
    the same value, and 'encode' is None (no choice labels). The id comes from
    'question_id', falling back to the first 40 characters of the question with
    spaces replaced by underscores.
    """
    gold_answer = ex.get("answer")
    if gold_answer is not None:
        question_slug = ex.get("question", "")[:40].replace(" ","_")
        return dict(
            id=ex.get("question_id", question_slug),
            dataset="medquad",
            split="all",
            context=None,
            question=ex.get("question", ""),
            text=[gold_answer],
            encode=None,
            answer=gold_answer,
        )

    return None

def normalize_line(raw: Dict[str,Any], dsname: str) -> Optional[Dict[str,Any]]:
    """Route a raw record to the normalizer that matches ``dsname``.

    Supported names are "emrqa", "medqa", "pubmedqa" and "medquad". Any other
    name yields None. The selected normalizer's result is returned unchanged
    (which may itself be None).
    """
    if dsname == "emrqa":
        normalized = norm_emrqa(raw)
    elif dsname == "medqa":
        normalized = norm_medqa(raw)
    elif dsname == "pubmedqa":
        normalized = norm_pubmedqa(raw)
    elif dsname == "medquad":
        normalized = norm_medquad(raw)
    else:
        # No normalizer is registered for this dataset name.
        normalized = None
    return normalized

# ================= MAIN SCRIPT =================

# LOAD → NORMALIZE → BUCKET (dataset/split)
# buckets[dataset][split] -> list of normalized records, in file order.
buckets = defaultdict(lambda: defaultdict(list))

for fn in sorted(os.listdir(IN_DIR)):
    if not fn.endswith(".jsonl"):
        continue
    ds = detect_dataset_name(fn)
    if ds == "unknown":
        print(f"Skipping unknown file: {fn}")
        continue

    src = os.path.join(IN_DIR, fn)
    print(f"Processing {src} as dataset '{ds}'...")

    with open(src, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            # Best-effort: a bad line is reported and skipped so that one
            # malformed record does not abort the whole file.
            try:
                raw = json.loads(line)
                ex  = normalize_line(raw, ds)
                if ex is None:
                    # The normalizer chose to drop this record.
                    continue

                # Fold the common validation-split aliases into "val".
                split = ex.get("split","all")
                if split in ("validation","dev"):
                    split = "val"
                ex["split"] = split
                buckets[ex["dataset"]][split].append(ex)
            except json.JSONDecodeError:
                print(f"  - Warning: Skipping invalid JSON on line {i+1} in {fn}")
            except Exception as e:
                print(f"  - Warning: An unexpected error occurred on line {i+1} in {fn}: {e}")

# WRITE OUT per dataset / per split, recording a manifest entry for each file
manifest = []
for ds, split_map in buckets.items():
    for split, items in split_map.items():
        out_path = os.path.join(OUT_DIR, ds, f"{split}.jsonl")
        # Only the first SAMPLE_RECORDS records (file order) are written.
        kept = items[0:SAMPLE_RECORDS]
        write_jsonl(out_path, kept)
        print(f"Wrote {len(kept):6d} → {out_path}")
        manifest.append({
            "dataset": ds,
            "split": split,
            # Number of records actually present in the file at `path`.
            "count": len(kept),
            # Number of normalized records available before truncation.
            "total": len(items),
            "path": out_path
        })

# Manifest for easy lookup
write_jsonl(os.path.join(OUT_DIR, "_manifest.jsonl"), manifest)
print("\nDone. Manifest at:", os.path.join(OUT_DIR, "_manifest.jsonl"))

Processing /content/drive/MyDrive/data_source/raw_data/MedQuAD.jsonl as dataset 'medquad'...
Processing /content/drive/MyDrive/data_source/raw_data/PubMedQA_pqa_labeled.jsonl as dataset 'pubmedqa'...
Processing /content/drive/MyDrive/data_source/raw_data/emrqa-msquad.jsonl as dataset 'emrqa'...
Processing /content/drive/MyDrive/data_source/raw_data/medqa.jsonl as dataset 'medqa'...
Wrote   1500 → /content/drive/MyDrive/data_source/unified_format/medquad/all.jsonl
Wrote   1000 → /content/drive/MyDrive/data_source/unified_format/pubmedqa/all.jsonl
Wrote   1500 → /content/drive/MyDrive/data_source/unified_format/emrqa/all.jsonl
Wrote   1500 → /content/drive/MyDrive/data_source/unified_format/medqa/all.jsonl

Done. Manifest at: /content/drive/MyDrive/data_source/unified_format/_manifest.jsonl


In [None]:
import os
import json

# --- CONFIG ---
UNIFIED_DIR = "/content/drive/MyDrive/data_source/unified_format"
DATASETS = ["medqa", "emrqa", "pubmedqa", "medquad"]
SAMPLE_COUNT = 1  # number of records to print per dataset

# --- PRINTOUT TEST DATA SAMPLE
# Print the first SAMPLE_COUNT normalized records of each dataset's all.jsonl.
for ds_name in DATASETS:
    print(f"\n{'='*20} SAMPLES FOR: {ds_name.upper()} {'='*20}")
    file_path = os.path.join(UNIFIED_DIR, ds_name, "all.jsonl")
    if not os.path.isfile(file_path):
        # Keep previewing the remaining datasets instead of aborting the cell.
        print(f"No file found at {file_path}; skipping.")
        continue
    with open(file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= SAMPLE_COUNT:
                break
            print(json.dumps(json.loads(line), indent=2, ensure_ascii=False))


{
  "id": "A_23-year-old_pregnant_woman_at_22_weeks",
  "dataset": "medqa",
  "split": "all",
  "context": null,
  "question": "A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?",
  "text": [
    "Ampicillin",
    "Ceftriaxone",
    "Ciprofloxacin",
    "Doxycycline",
    "Nitrofurantoin"
  ],
  "encode": [
    "A",
    "B",
    "C",
    "D",
    "E"
  ],
  "answer": "E"
}

{
  "id": "The_patient_was_admiWhat_is_her_current_",
  "dataset": "emrqa",
  "split": "