In [None]:
# make_pairs.py
# Build DPO preference pairs (prompt, chosen, rejected) from medqa_50, medmcqa_50, pubmedqa_50
# Each sample = one prompt (same as evaluation template) + correct vs wrong answer
# Output: dpo_pairs.jsonl

import json, os, random
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from tqdm import tqdm
from transformers import AutoTokenizer

# ------------------------
# Configuration
# ------------------------
# Input files, one JSON dump per benchmark.
MEDQA_PATH, MEDMCQA_PATH, PUBMEDQA_PATH = (
    "medqa_50.json",
    "medmcqa_50.json",
    "pubmedqa_50.json",
)

# Model whose chat template renders the prompts (kept the same as evaluation).
MODEL_REPO = "meta-llama/Llama-3.1-8B-Instruct"
# When True, prompts are wrapped with the chat template; otherwise raw user text.
USE_CHAT = True
# Destination for the generated preference pairs.
OUTPUT_JSONL = "dpo_pairs.jsonl"

# Upper bound on rejected (wrong) answers paired with each question.
MAX_NEG_PER_ITEM = 2
SEED = 42
# Seed the global RNG at import time; the pair builders shuffle with it.
random.seed(SEED)

# ------------------------
# Tokenizer for chat template (same as evaluation)
# ------------------------
# Loaded eagerly at import time, even when USE_CHAT is False.
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, use_fast=True)

def apply_chat_template(user_msg: str, system_msg: str = "") -> str:
    """Render a single-turn conversation through the tokenizer's chat template.

    A system turn is prepended only when ``system_msg`` is non-empty. The
    result is an untokenized string (``tokenize=False``) with the generation
    prompt appended (``add_generation_prompt=True``), so the answer text can
    follow it directly.

    NOTE(review): the rendered string may already contain the template's
    BOS token; if a downstream trainer re-tokenizes it with special tokens
    enabled, BOS could be duplicated -- confirm against the DPO trainer.
    """
    conversation = [{"role": "system", "content": system_msg}] if system_msg else []
    conversation.append({"role": "user", "content": user_msg})
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

# ------------------------
# Data Structures
# ------------------------
@dataclass
class MCItem:
    """A multiple-choice question with lettered options and its gold letter."""

    # Question stem shown to the model.
    question: str
    # Option letter (e.g. "A") -> option text, in the order they are displayed.
    options: Dict[str, str]
    # Letter of the correct option; loaders keep an item only if it is a key of ``options``.
    answer_letter: str
    # Identifier from the source file (dict key or list index, stringified); optional.
    source_id: Union[str, None] = None

@dataclass
class YesNoMaybeItem:
    """A PubMedQA-style question with evidence passages and a yes/no/maybe label."""

    # Question text.
    question: str
    # Evidence passages; the prompt renders at most the first six.
    contexts: List[str]
    # Gold label; the loader keeps only "yes", "no" or "maybe" (lowercase).
    gold_label: str
    # Identifier from the source file (dict key or list index, stringified); optional.
    source_id: Union[str, None] = None

# ------------------------
# JSON Loader Helpers
# ------------------------
def _read_json_any(path: str) -> Union[dict, list]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

# ------------------------
# Dataset Loaders
# ------------------------
def _medqa_options(opts_in: object, valid: tuple) -> Dict[str, str]:
    """Normalize a MedQA ``options`` field to an ordered {letter: text} dict restricted to *valid* letters."""
    if isinstance(opts_in, dict):
        pairs = list(opts_in.items())
    elif isinstance(opts_in, list):
        # List form: either {"key": ..., "value": ...} entries or bare texts
        # that are lettered positionally (A, B, ...).
        pairs = []
        for pos, entry in enumerate(opts_in):
            if isinstance(entry, dict) and "key" in entry:
                pairs.append((entry["key"], entry.get("value", "")))
            elif pos < len(valid):
                pairs.append((valid[pos], entry))
    else:
        pairs = []
    out: Dict[str, str] = {}
    for k, v in pairs:
        letter = str(k).strip().upper()
        if letter in valid:
            out[letter] = str(v)
    return out


def load_medqa(path: str) -> List[MCItem]:
    """Load MedQA-style multiple-choice items from the file at *path*.

    Each record needs a ``question`` and ``options``. ``options`` is either a
    mapping of letters A-E (any case) to texts, or a list of texts or of
    ``{"key", "value"}`` dicts. The gold answer is ``answer_idx``. If that is
    missing or empty, ``answer`` is used instead. Either one may be a letter
    or the full option text, and the text is matched case-insensitively.

    Args:
        path: File whose top level is a dict keyed by id or a list of records.

    Returns:
        One ``MCItem`` per usable record. A record is skipped silently if any
        of these holds:
          - it is not a dict;
          - it has an empty question;
          - it has fewer than two options;
          - its answer cannot be mapped to an option letter.
    """
    valid = ("A", "B", "C", "D", "E")
    raw = _read_json_any(path)
    items: List[MCItem] = []
    iterator = raw.items() if isinstance(raw, dict) else enumerate(raw)
    for key, ex in tqdm(iterator, desc="Loading MedQA", ncols=80):
        if not isinstance(ex, dict):
            continue
        # `or ""` so a null question is treated as missing, not as "None".
        q = str(ex.get("question") or "").strip()
        opts = _medqa_options(ex.get("options"), valid)

        ans_raw = ex.get("answer_idx")
        if ans_raw is None or str(ans_raw).strip() == "":
            ans_raw = ex.get("answer")
        ans_text = "" if ans_raw is None else str(ans_raw).strip()
        ans = ans_text.upper()
        if ans not in opts:
            # The answer may be the option text. Match the original-case
            # input case-insensitively. Looking up the uppercased string
            # (the letter check above) would miss mixed-case texts.
            by_text = {v.strip().casefold(): k for k, v in opts.items()}
            ans = by_text.get(ans_text.casefold(), "") if ans_text else ""
        if not q or len(opts) < 2 or ans not in opts:
            continue
        items.append(MCItem(q, opts, ans, str(key)))
    return items

def load_medmcqa(path: str, cop_base: int = 1) -> List[MCItem]:
    """Load MedMCQA-style multiple-choice items from the file at *path*.

    Options come from an ``options`` mapping (letters A-E, any case) or, if
    that is absent or not a mapping, from the ``opa``..``ope`` fields.

    The gold answer is the first populated field among ``cop``,
    ``answer_idx``, ``answer`` and ``label`` (null/empty values are skipped).
    It may be any of:
      - an integer or digit-string index;
      - a letter;
      - the full option text, matched case-insensitively.

    Args:
        path: File whose top level is a dict keyed by id or a list of records.
        cop_base: Numeric answer that denotes option A. The default 1 maps
            1..5 -> A..E, which is the original behavior. Pass 0 if your export
            stores ``cop`` 0-based. NOTE(review): some public releases of
            MedMCQA are believed to use 0-3; verify against the source file.

    Returns:
        One ``MCItem`` per usable record. A record is skipped silently if any
        of these holds:
          - it is not a dict;
          - it has an empty question;
          - it has fewer than two options;
          - its gold answer is not one of the options.
    """
    letters = ("A", "B", "C", "D", "E")
    field_for = {"A": "opa", "B": "opb", "C": "opc", "D": "opd", "E": "ope"}

    def index_to_letter(idx: int) -> str:
        # Out-of-range indices map to "" so the item is dropped, as before.
        pos = idx - cop_base
        return letters[pos] if 0 <= pos < len(letters) else ""

    raw = _read_json_any(path)
    items: List[MCItem] = []
    iterator = raw.items() if isinstance(raw, dict) else enumerate(raw)
    for key, ex in tqdm(iterator, desc="Loading MedMCQA", ncols=80):
        if not isinstance(ex, dict):
            continue
        q = str(ex.get("question") or "").strip()

        opts: Dict[str, str] = {}
        if isinstance(ex.get("options"), dict):
            for k, v in ex["options"].items():
                kk = str(k).strip().upper()
                if kk in field_for:
                    opts[kk] = str(v)
        else:
            for letter, fld in field_for.items():
                if ex.get(fld) is not None:
                    opts[letter] = str(ex[fld])

        # First populated answer field, so a present-but-null `cop` no longer
        # hides a usable `answer`/`label`.
        gold_raw = next(
            (ex[k] for k in ("cop", "answer_idx", "answer", "label")
             if ex.get(k) not in (None, "")),
            "",
        )
        gold = ""
        if isinstance(gold_raw, float) and gold_raw.is_integer():
            gold_raw = int(gold_raw)
        # bool is a subclass of int; True must not be read as index 1 -> "A".
        if isinstance(gold_raw, int) and not isinstance(gold_raw, bool):
            gold = index_to_letter(gold_raw)
        else:
            s = str(gold_raw).strip()
            if s.isascii() and s.isdigit():
                gold = index_to_letter(int(s))
            if not gold and s.upper() in field_for:
                gold = s.upper()
            if not gold and s:
                by_text = {v.strip().casefold(): k for k, v in opts.items()}
                gold = by_text.get(s.casefold(), "")
        if not q or len(opts) < 2 or gold not in opts:
            continue
        items.append(MCItem(q, opts, gold, str(key)))
    return items

def load_pubmedqa(path: str) -> List[YesNoMaybeItem]:
    """Load PubMedQA yes/no/maybe items from the file at *path*.

    The question is read from ``QUESTION`` or ``question``. Evidence passages
    are read from ``CONTEXTS`` or ``contexts``; if both are missing, the
    nested layout ``context: {"contexts": [...]}`` (or a plain ``context``
    value) is used. The label is ``final_decision`` or, if that is missing
    or empty, ``answer``, normalized to lowercase.

    Args:
        path: File whose top level is a dict keyed by id or a list of records.

    Returns:
        One ``YesNoMaybeItem`` per usable record. A record is skipped silently
        if any of these holds:
          - it is not a dict;
          - it has an empty question;
          - its label is not yes/no/maybe.
    """
    labels = {"yes", "no", "maybe"}
    raw = _read_json_any(path)
    items: List[YesNoMaybeItem] = []
    iterator = raw.items() if isinstance(raw, dict) else enumerate(raw)
    for key, ex in tqdm(iterator, desc="Loading PubMedQA", ncols=80):
        if not isinstance(ex, dict):
            continue
        # `or` chains so null/empty fields fall through instead of becoming "None".
        q = str(ex.get("QUESTION") or ex.get("question") or "").strip()

        ctx = ex.get("CONTEXTS", ex.get("contexts"))
        if ctx is None:
            nested = ex.get("context")
            ctx = nested.get("contexts") if isinstance(nested, dict) else nested
        if ctx is None:
            ctx = []
        elif not isinstance(ctx, list):
            ctx = [ctx]

        gold = str(ex.get("final_decision") or ex.get("answer") or "").strip().lower()
        if not q or gold not in labels:
            continue
        passages = [str(c) for c in ctx if c is not None]
        items.append(YesNoMaybeItem(q, passages, gold, str(key)))
    return items

# ------------------------
# Prompt Templates (identical to evaluation)
# ------------------------
def mc_prompt(item: MCItem) -> str:
    """Render the multiple-choice prompt for *item*.

    Options are listed in the dict's own order, while the allowed-letter set
    in the final instruction is sorted. The text is wrapped by
    ``apply_chat_template`` when ``USE_CHAT`` is true and returned raw
    otherwise. The wording must stay identical to the evaluation template.
    """
    allowed_letters = "".join(sorted(item.options))
    option_block = "\n".join(f"{letter}. {text}" for letter, text in item.options.items())
    header = (
        "You are answering a multiple-choice medical question.\n"
        "Return ONLY one uppercase letter from the allowed set.\n\n"
    )
    body = f"Question:\n{item.question}\n\nOptions:\n{option_block}\n\n"
    footer = f"Answer with ONLY ONE LETTER from [{allowed_letters}].\nAnswer:"
    user_text = header + body + footer
    if not USE_CHAT:
        return user_text
    return apply_chat_template(user_text)

def pubmedqa_prompt(item: YesNoMaybeItem) -> str:
    """Render the yes/no/maybe prompt for *item*.

    At most the first six context passages are included, each as a "- "
    bullet. The text is wrapped by ``apply_chat_template`` when ``USE_CHAT``
    is true and returned raw otherwise. The wording must stay identical to the
    evaluation template.
    """
    bullets = [f"- {passage}" for passage in item.contexts[:6]]
    evidence = "\n".join(bullets)
    instructions = (
        "You are assessing a biomedical yes/no/maybe question.\n"
        "Return ONLY one token: yes, no, or maybe (lowercase).\n\n"
    )
    user_text = instructions + f"Question:\n{item.question}\n\nEvidence:\n{evidence}\n\nAnswer:"
    if USE_CHAT:
        return apply_chat_template(user_text)
    return user_text

# ------------------------
# Pair Construction
# ------------------------
def _make_pairs(
    prompt: str,
    correct: str,
    candidates: List[str],
    meta: dict,
    max_neg: Optional[int],
    rng: Optional[random.Random],
) -> List[dict]:
    """Pair *correct* with up to *max_neg* shuffled wrong candidates as DPO rows.

    ``max_neg`` defaults to ``MAX_NEG_PER_ITEM``, read at call time. ``rng``
    defaults to the global ``random`` module, so the seeded global sequence
    is consumed exactly as before.

    NOTE(review): chosen/rejected are bare answer strings with no end-of-turn
    token; confirm the DPO trainer appends one so the model learns to stop.
    """
    limit = MAX_NEG_PER_ITEM if max_neg is None else max_neg
    if limit < 0:
        # A negative slice bound would silently keep "all but the last" wrongs.
        raise ValueError(f"max_neg must be >= 0, got {limit}")
    wrongs = [c for c in candidates if c != correct]
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(wrongs)
    return [
        {"prompt": prompt, "chosen": correct, "rejected": neg, "meta": dict(meta)}
        for neg in wrongs[:limit]
    ]


def _build_mc_pairs(
    items: List[MCItem],
    dataset: str,
    max_neg: Optional[int],
    rng: Optional[random.Random],
) -> List[dict]:
    """Build DPO rows for multiple-choice items, tagging each with *dataset*."""
    rows: List[dict] = []
    for it in tqdm(items, desc=f"Building {dataset} pairs", ncols=80):
        rows += _make_pairs(
            mc_prompt(it),
            it.answer_letter,
            sorted(it.options),  # wrong letters are shuffled from sorted order
            {"dataset": dataset, "id": it.source_id, "type": "mcq"},
            max_neg,
            rng,
        )
    return rows


def build_pairs_medqa(
    items: List[MCItem],
    max_neg: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> List[dict]:
    """Return DPO rows for MedQA: the correct letter vs. up to ``max_neg`` random wrong letters.

    Each row has keys ``prompt``, ``chosen``, ``rejected`` and ``meta``.
    ``max_neg`` defaults to ``MAX_NEG_PER_ITEM``; ``rng`` defaults to the
    global ``random`` module.
    """
    return _build_mc_pairs(items, "MedQA", max_neg, rng)


def build_pairs_medmcqa(
    items: List[MCItem],
    max_neg: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> List[dict]:
    """Return DPO rows for MedMCQA: the correct letter vs. up to ``max_neg`` random wrong letters.

    Each row has keys ``prompt``, ``chosen``, ``rejected`` and ``meta``.
    ``max_neg`` defaults to ``MAX_NEG_PER_ITEM``; ``rng`` defaults to the
    global ``random`` module.
    """
    return _build_mc_pairs(items, "MedMCQA", max_neg, rng)


def build_pairs_pubmedqa(
    items: List[YesNoMaybeItem],
    max_neg: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> List[dict]:
    """Return DPO rows for PubMedQA: the gold label vs. up to ``max_neg`` other labels.

    Each row has keys ``prompt``, ``chosen``, ``rejected`` and ``meta``.
    ``max_neg`` defaults to ``MAX_NEG_PER_ITEM``; ``rng`` defaults to the
    global ``random`` module.
    """
    labels = ["yes", "no", "maybe"]
    rows: List[dict] = []
    for it in tqdm(items, desc="Building PubMedQA pairs", ncols=80):
        rows += _make_pairs(
            pubmedqa_prompt(it),
            it.gold_label,
            labels,
            {"dataset": "PubMedQA", "id": it.source_id, "type": "ynm"},
            max_neg,
            rng,
        )
    return rows

# ------------------------
# Write JSONL
# ------------------------
def write_jsonl(rows, path):
    """Write each row as one JSON line to *path*, overwriting the file.

    Output is UTF-8 with non-ASCII characters kept as-is (``ensure_ascii=False``).
    Every line, including the last, ends with a newline.
    """
    serialized = (json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(serialized)

# ------------------------
# Main Entry
# ------------------------
def main():
    """Build DPO preference pairs for all three datasets and write them to OUTPUT_JSONL.

    The global RNG is re-seeded from SEED first. This makes each call
    reproducible regardless of what consumed the global RNG beforehand, for
    example an earlier main() call in the same notebook session.
    """
    random.seed(SEED)

    medqa = load_medqa(MEDQA_PATH)
    medmcqa = load_medmcqa(MEDMCQA_PATH)
    pubmed = load_pubmedqa(PUBMEDQA_PATH)
    # The loaders drop unusable records silently; report how many survived.
    print(f"Loaded items: MedQA={len(medqa)}, MedMCQA={len(medmcqa)}, PubMedQA={len(pubmed)}")

    rows = []
    rows += build_pairs_medqa(medqa)
    rows += build_pairs_medmcqa(medmcqa)
    rows += build_pairs_pubmedqa(pubmed)

    # Interleave datasets so the file is not grouped by source.
    random.shuffle(rows)
    write_jsonl(rows, OUTPUT_JSONL)
    print(f"\n Finished: wrote {len(rows)} DPO pairs to {OUTPUT_JSONL}")

if __name__ == "__main__":
    main()

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Loading MedQA: 100%|████████████████████████| 50/50 [00:00<00:00, 220752.84it/s]
Loading MedMCQA: 100%|██████████████████████| 50/50 [00:00<00:00, 225258.00it/s]
Loading PubMedQA: 100%|█████████████████████| 50/50 [00:00<00:00, 282254.64it/s]
Building MedQA pairs: 100%|███████████████████| 50/50 [00:00<00:00, 2233.77it/s]
Building MedMCQA pairs: 100%|████████████████| 50/50 [00:00<00:00, 16636.14it/s]
Building PubMedQA pairs: 100%|███████████████| 50/50 [00:00<00:00, 17470.44it/s]


✅ Finished: wrote 300 DPO pairs to dpo_pairs.jsonl



