In [None]:
!pip install -q google-generativeai datasets pandas

import getpass
import os
import time, re

import google.generativeai as genai
import numpy as np
import pandas as pd
from datasets import load_dataset


In [None]:
#GEMINI CONFIG
# Credentials must never live in the notebook itself: the key is taken from the
# environment (GEMINI_API_KEY, then GOOGLE_API_KEY) and, failing that, typed in
# at a hidden prompt. Any key that was ever saved inside this file should be
# treated as leaked and revoked.
API_KEY = (
    os.environ.get("GEMINI_API_KEY")
    or os.environ.get("GOOGLE_API_KEY")
    or getpass.getpass("Gemini API key: ")
)
genai.configure(api_key=API_KEY)

# One model object is shared by every request made in this notebook.
MODEL_NAME = "gemini-2.5-flash-lite"
gemini_model = genai.GenerativeModel(MODEL_NAME)

#LOAD DATA
# Training split of the 4-option MedQA (USMLE) dataset, converted to a DataFrame.
ds = load_dataset("openlifescienceai/MedQA-USMLE-4-options-hf", split="train")
df = ds.to_pandas()

# Full question text: trimmed `sent1`, one space, trimmed `sent2`.
sent1_stripped = df["sent1"].str.strip()
sent2_stripped = df["sent2"].str.strip()
df["question"] = sent1_stripped + " " + sent2_stripped

# Answer columns and the letter used for each one, in matching order.
option_cols = [f"ending{k}" for k in range(4)]
letters = list("ABCD")

len(df)


10178

In [None]:

# Client-side throttle settings for calls to the model.
RPM_LIMIT = 15        # requests-per-minute ceiling we must not exceed
SAFETY_FACTOR = 0.8   # fraction of RPM_LIMIT that is actually used
MIN_INTERVAL = 60.0 / (RPM_LIMIT * SAFETY_FACTOR)  # seconds between request starts

# time.monotonic() reading taken right before the previous request was sent.
# -inf means "no request yet", so the very first call never waits.
last_call_time = float("-inf")

def rate_limited_generate(prompt: str):
    """
    Wrapper around gemini_model.generate_content that enforces a minimum
    delay between requests so we don't exceed RPM.

    The gap is measured with time.monotonic() rather than the wall clock, so
    a system clock adjustment (NTP sync, manual change) can neither produce a
    huge sleep nor let requests through too early.

    prompt: text passed unchanged to the model.
    returns: whatever gemini_model.generate_content(prompt) returns.
    """
    global last_call_time
    wait = MIN_INTERVAL - (time.monotonic() - last_call_time)
    if wait > 0:
        time.sleep(wait)
    # Stamp before sending: the interval is enforced between request starts.
    last_call_time = time.monotonic()
    return gemini_model.generate_content(prompt)


In [None]:
def _parse_choice_letter(reply):
    """Return the option letter ('A'-'D') found in a model reply, or None."""
    reply = (reply or "").strip()
    # 1) The whole reply is the letter, possibly wrapped in punctuation: "C", "(c)", "B."
    m = re.fullmatch(r"\W*([A-Da-d])\W*", reply)
    if m is None:
        # 2) A standalone capital letter inside a longer reply ("The answer is C").
        #    Matching on word boundaries avoids picking letters out of other words.
        m = re.search(r"\b([ABCD])\b", reply)
    if m is None:
        # 3) Last resort: a standalone letter in any case ("answer: c").
        m = re.search(r"\b([A-Da-d])\b", reply)
    return m.group(1).upper() if m else None


def ask_gemini_mcq(question_text: str, options):
    """
    Ask the model a 4-option multiple choice question and parse its answer.

    question_text: full question (no options embedded)
    options: [ending0, ending1, ending2, ending3] -> A,B,C,D
    returns: 'A'/'B'/'C'/'D' or None

    None is returned when the reply contains no standalone option letter, or
    when the response has no text at all (the SDK's `.text` accessor raises
    ValueError in that case, e.g. for a blocked response).
    """
    prompt = f"""You are solving a 4-option USMLE-style medical multiple choice question.

Question:
{question_text}

Options:
A. {options[0]}
B. {options[1]}
C. {options[2]}
D. {options[3]}

Reply with ONLY the single letter of the correct choice: A, B, C, or D.
Do NOT include any explanation or extra text.
"""
    resp = rate_limited_generate(prompt)
    try:
        reply = resp.text
    except ValueError:
        # No usable text part in the response: treat it as "no answer".
        return None
    return _parse_choice_letter(reply)


In [None]:
def paraphrase_with_gemini(question: str) -> str:
    """
    Ask the model to reword a question without solving it.

    question: question text to paraphrase (answer options are not included).
    returns: the model's reply with surrounding whitespace removed, or "" when
        the reply is empty or the response has no text at all (the SDK's
        `.text` accessor raises ValueError in that case, e.g. for a blocked
        response). Callers should treat "" as "no paraphrase available".
    """
    prompt = f"""Paraphrase the following USMLE-style medical MCQ question.

Rules:
1. Keep the original meaning exactly the same.
2. Do NOT solve the question.
3. Do NOT include answer choices.
4. Only rewrite the question text clearly and concisely.
5. Preserve medical terminology.

Question to paraphrase:
{question}
"""
    resp = rate_limited_generate(prompt)
    try:
        paraphrase = resp.text
    except ValueError:
        # No usable text part in the response.
        return ""
    return (paraphrase or "").strip()


In [None]:
# Drop obvious junk: rows missing a question part, an option, or the label
text_cols = ["sent1", "sent2", "ending0", "ending1", "ending2", "ending3"]
df = df.dropna(subset=text_cols + ["label"])

#  Normalize whitespace
def clean(x):
    """Collapse each run of whitespace (newlines included) to one space and trim.

    Non-string values are returned unchanged.
    """
    if not isinstance(x, str):
        return x
    # str.split() with no argument already splits on newlines and tabs.
    return " ".join(x.split())

for col in text_cols:
    df[col] = df[col].apply(clean)

# The `question` column was assembled from the raw text before this cleaning,
# so rebuild it here; otherwise the normalisation never reaches the text that
# is actually sent to the model. The outer clean() also removes the trailing
# space left when `sent2` is empty.
df["question"] = (df["sent1"] + " " + df["sent2"]).apply(clean)

# Keep only rows whose label indexes one of the four options
df = df[df["label"].isin([0,1,2,3])]

In [None]:
#  BATCH CONFIG
TOTAL_QUESTIONS = len(df)

BATCH_SIZE = 50
BATCH_ID   = 7   # 0-based index of the slice to process in this run

# A fixed seed yields the same shuffled order on every run, so different
# BATCH_IDs pick non-overlapping rows as long as `df` and BATCH_SIZE are unchanged.
SEED = 42
rng = np.random.default_rng(SEED)
all_indices = rng.permutation(TOTAL_QUESTIONS)

start = BATCH_ID * BATCH_SIZE
end   = min(start + BATCH_SIZE, TOTAL_QUESTIONS)

# Fail here with a clear message instead of producing an empty (or, for a
# negative id, wrongly positioned) batch that only breaks later on.
if BATCH_SIZE <= 0:
    raise ValueError(f"BATCH_SIZE must be positive, got {BATCH_SIZE}")
if not 0 <= start < TOTAL_QUESTIONS:
    raise ValueError(
        f"BATCH_ID {BATCH_ID} selects no rows: {TOTAL_QUESTIONS} questions in "
        f"batches of {BATCH_SIZE} allow ids 0..{(TOTAL_QUESTIONS - 1) // BATCH_SIZE}"
    )

batch_indices = all_indices[start:end]
# Positional lookup, so gaps in df's index left by earlier filtering do not matter.
subset = df.iloc[batch_indices].copy().reset_index(drop=True)

print(f"Total questions: {TOTAL_QUESTIONS}")
print(f"Processing batch {BATCH_ID}: indices [{start}:{end}) -> size = {len(subset)}")


In [None]:
# Evaluate every question of the batch twice: once as written and once after
# the model has paraphrased it, always with the same answer options.
results = []
correct_orig = 0   # original question answered correctly
correct_para = 0   # paraphrased question answered correctly
consistant = 0     # both runs produced the same valid letter

N_SAMPLES = len(subset)

for i, (_, row) in enumerate(subset.iterrows(), start=1):
    question = row["question"]
    options = [row[c] for c in option_cols]
    true_letter = letters[int(row["label"])]

    #  Gemini on ORIGINAL question + options
    pred_orig = ask_gemini_mcq(question, options)

    # Paraphrase question with Gemini
    para_question = paraphrase_with_gemini(question)

    #  Gemini on PARAPHRASED question + same options.
    # With an empty paraphrase there is nothing to ask: record "no answer"
    # rather than scoring a guess made from the options alone.
    pred_para = ask_gemini_mcq(para_question, options) if para_question else None

    # Track accuracy
    if pred_orig == true_letter:
        correct_orig += 1
    if pred_para == true_letter:
        correct_para += 1
    # Two missing answers (None) are not agreement, so require a real letter.
    if pred_orig is not None and pred_orig == pred_para:
        consistant += 1

    results.append({
        "question": question,
        "paraphrased_question": para_question,
        "A": options[0],
        "B": options[1],
        "C": options[2],
        "D": options[3],
        "true_letter": true_letter,
        "pred_original": pred_orig,
        "pred_paraphrased": pred_para,
    })

    if i % 10 == 0 or i == N_SAMPLES:
        print(f"Processed {i}/{N_SAMPLES}")


In [None]:
# Turn the counters into rates. An empty batch has no meaningful rate, so
# report NaN instead of failing with ZeroDivisionError.
if N_SAMPLES:
    orig_acc = correct_orig / N_SAMPLES
    para_acc = correct_para / N_SAMPLES
    consist = consistant / N_SAMPLES
else:
    orig_acc = para_acc = consist = float("nan")

# The dash in this header was stored as mis-decoded UTF-8; it is an en dash.
print(f"\nBatch {BATCH_ID} – Samples: {N_SAMPLES}")
print(f"Accuracy (original question):    {orig_acc:.3f}")
print(f"Accuracy (paraphrased question): {para_acc:.3f}")
print(f"Consistency : {consist:.3f}")

# One CSV per batch, holding a row per question with both predictions.
results_df = pd.DataFrame(results)
out_name = f"medqa_gemini_eval_batch_{BATCH_ID}.csv"
results_df.to_csv(out_name, index=False)
print("Saved:", out_name)
