In [None]:
!pip install -q google-generativeai datasets pandas tqdm
import google.generativeai as genai
from datasets import load_dataset
import pandas as pd
import numpy as np
import re, time, json
from tqdm import tqdm
from collections import deque

# CONFIGURE GEMINI
# The API key is read from the environment (or typed in at a prompt) instead of
# being hard-coded, so it never ends up in the notebook file or its history.
import os
from getpass import getpass

genai.configure(api_key=os.environ.get("GOOGLE_API_KEY") or getpass("Gemini API key: "))

# Model used for every Gemini call in this notebook.
GEMINI_MODEL_NAME = "gemini-2.0-flash"
gemini_model = genai.GenerativeModel(GEMINI_MODEL_NAME)





In [None]:
# LOAD DATASET
# MedQA-USMLE (4 options) training split, pulled from the Hugging Face Hub.
ds = load_dataset("openlifescienceai/MedQA-USMLE-4-options-hf", split="train")

# Basic cleaning: convert to pandas and keep only rows where the question stem
# (sent1/sent2), all four options, and the label are present.
df = ds.to_pandas().dropna(
    subset=["sent1", "sent2", "ending0", "ending1", "ending2", "ending3", "label"]
)

def clean(x):
    """Collapse every run of whitespace (including newlines) in a string to a single space.

    Leading/trailing whitespace is removed. Non-string values are returned untouched.
    """
    if isinstance(x, str):
        # str.split() with no separator already treats "\n", "\t", etc. as whitespace.
        return " ".join(x.split())
    return x

# Single source of truth for the option columns and their answer letters.
option_cols = ["ending0", "ending1", "ending2", "ending3"]
letters = ["A", "B", "C", "D"]

# Normalise whitespace in the question stem and every option.
for col in ["sent1", "sent2", *option_cols]:
    df[col] = df[col].apply(clean)

# Keep only rows whose label indexes one of the options. `.copy()` makes the
# filtered frame independent, so the column assignments below never act on a
# slice of another frame (pandas' SettingWithCopy pattern).
df = df[df["label"].isin(list(range(len(letters))))].copy()
df["question"] = df["sent1"].str.strip() + " " + df["sent2"].str.strip()
# Map label 0..3 -> "A".."D".
df["correct_letter"] = df["label"].map(dict(enumerate(letters)))


In [None]:
def redact_words(text, words, mask="____"):
    """Replace whole-word, case-insensitive occurrences of each of `words` in `text` with `mask`.

    Args:
        text: the string to redact.
        words: iterable of words to mask. Entries are stripped, and blank entries are ignored.
        mask: replacement string, inserted literally (default "____").

    Returns:
        The redacted text with runs of whitespace collapsed to single spaces.
        If `words` is empty/None, or contains only blank entries, `text` is
        returned unchanged.
    """
    if not words:
        return text
    tokens = [w.strip() for w in words if w.strip()]
    # Without this guard the pattern would be r"\b()\b", which matches the empty
    # string at every word boundary and sprinkles `mask` all over the text.
    if not tokens:
        return text
    pattern = r"\b(?:" + "|".join(re.escape(w) for w in tokens) + r")\b"
    # A function replacement keeps `mask` literal (no backslash/group expansion).
    redacted = re.sub(pattern, lambda _m: mask, text, flags=re.IGNORECASE)
    return " ".join(redacted.split())

def build_mcq_block(question_text, options):
    """Render a question and its options as the plain-text MCQ block used in prompts.

    The first four entries of `options` are labelled A-D, one per line, and each
    line ends with a newline. Fewer than four options raises IndexError.
    """
    option_lines = "".join(
        f"{label}. {options[pos]}\n" for pos, label in enumerate("ABCD")
    )
    return f"Question:\n{question_text}\n\nOptions:\n" + option_lines


In [None]:
import numpy as np
import re, time

# time.monotonic() stamps of recent Gemini calls, oldest first (used for throttling).
_GEMINI_CALL_TIMES = deque()
# A non-negative number such as "0.7", "70" or ".3" (captured as group 1).
_PROB_NUM = r"(\d+(?:\.\d+)?|\.\d+)"


def _throttle_gemini(calls_per_min):
    """Sleep as needed so that at most `calls_per_min` calls start in any 60 s window.

    A falsy or non-positive `calls_per_min` disables throttling.
    """
    if not calls_per_min or calls_per_min <= 0:
        return
    now = time.monotonic()
    # Forget calls that have already left the 60 s window.
    while _GEMINI_CALL_TIMES and now - _GEMINI_CALL_TIMES[0] >= 60.0:
        _GEMINI_CALL_TIMES.popleft()
    # Window full: wait until the oldest call expires, then drop it.
    while len(_GEMINI_CALL_TIMES) >= calls_per_min:
        wait = 60.0 - (time.monotonic() - _GEMINI_CALL_TIMES[0])
        if wait > 0:
            time.sleep(wait)
        _GEMINI_CALL_TIMES.popleft()
    _GEMINI_CALL_TIMES.append(time.monotonic())


def _last_match(pattern, text):
    """Return the last case-insensitive match of `pattern` in `text`, or None."""
    matches = list(re.finditer(pattern, text, re.IGNORECASE))
    return matches[-1] if matches else None


def gemini_option_probs(question_text, options, max_words=8, calls_per_min=14):
    """
    Ask Gemini for self-reported probabilities over options A-D.

    Drop-in compatible with the original caller.

    Args:
        question_text: question stem shown to the model.
        options: option texts; the first four are shown as A-D (via build_mcq_block).
        max_words: maximum number of influential words to return (0 returns none).
        calls_per_min: client-side cap on Gemini calls per rolling minute. The cap
            is shared by all calls to this function; falsy or <= 0 disables it.

    Returns:
      1) probs: list of 4 floats over [A, B, C, D], normalised to sum to 1
         (uniform if the model reported all zeros);
      2) choice: the model's stated letter (A/B/C/D), or the argmax letter if
         no "Choice:" line was found;
      3) inf_words: up to `max_words` lower-cased words from the first line of
         the reply (the requested word list) that also occur in `question_text`.
      Returns (None, None, None) if the reply has no text or the A-D
      probabilities cannot be parsed.

    Errors raised by `gemini_model.generate_content` (network, quota, ...)
    propagate to the caller.
    """
    mcq_block = build_mcq_block(question_text, options)

    prompt = f"""
You are a medical expert answering USMLE-style MCQs.

{mcq_block}

1) First list QUESTION words that helped you decide (space separated, SINGLE word tokens)
2) Then assign probabilities for options A/B/C/D in this EXACT format:

A: <prob>
B: <prob>
C: <prob>
D: <prob>
Choice: <A/B/C/D>

No JSON. No explanations. Only follow format.
"""

    # Call Gemini (throttled to honour calls_per_min)
    _throttle_gemini(calls_per_min)
    resp = gemini_model.generate_content(prompt)

    try:
        # The SDK's `.text` accessor raises ValueError when the reply has no
        # text part (e.g. a blocked candidate); treat that as an unusable reply.
        text = resp.text.strip()
    except ValueError as exc:
        print("[warn] no text in Gemini response:", exc)
        return None, None, None

    prob_matches = []
    for letter in "ABCD":
        # The letter must not be preceded by another letter (so "Hb: 12" is not
        # read as "b: 12"). Optional markdown bold, a colon, then the number.
        # The last match wins, because the format puts probabilities at the end.
        m = _last_match(rf"(?<![A-Za-z]){letter}[\s*]*:[\s*]*{_PROB_NUM}", text)
        if m is None:
            print("[warn] parse failed for:", text[:200])
            return None, None, None
        prob_matches.append(m)

    vec = np.array([float(m.group(1)) for m in prob_matches], dtype=float)
    s = vec.sum()
    vec = vec / s if s > 0 else np.ones(4) / 4.0

    # The chosen letter must stand alone, so "Choice: Based on..." does not yield "B".
    choice_m = _last_match(r"(?<![A-Za-z])choice[\s*]*:[\s*]*\(?([ABCD])(?![A-Za-z])", text)
    choice = choice_m.group(1).upper() if choice_m else "ABCD"[int(vec.argmax())]

    # The requested word list is the first line, but only the part that comes
    # before the probabilities. Otherwise a reply opening with "A: 0.7" would
    # report the article "a" as influential.
    first_prob_at = min(m.start() for m in prob_matches)
    reason_line = text[:first_prob_at].split("\n")[0]
    question_vocab = set(re.findall(r"[A-Za-z]+", question_text.lower()))
    inf_words = []
    for token in re.findall(r"[A-Za-z]+", reason_line):
        # Check the cap before appending, so max_words=0 really returns no words.
        if len(inf_words) >= max_words:
            break
        lw = token.lower()
        if lw in question_vocab and lw not in inf_words:
            inf_words.append(lw)

    return vec.tolist(), choice, inf_words




In [None]:
import numpy as np

def run_attribution_lite_gemini(row, max_influential=5):
    """
    Word-redaction attribution for one MedQA row, using Gemini.

    Procedure:
      1) Query Gemini on the original question. This gives its self-reported
         A-D probabilities, its chosen letter, and up to `max_influential`
         question words it says were influential.
      2) Mask every one of those words in the question in a single pass.
      3) Query Gemini again on the masked question, with the options unchanged.

    Returns:
        A dict of per-row results: both probability vectors, the probability
        of the correct option before/after masking and their difference, and
        both predicted letters. Returns None when either Gemini reply is
        unusable, or when no influential words were returned.
    """
    def pick_letter(choice, probs):
        # Prefer the letter the model stated; otherwise take the most probable option.
        return choice if choice in letters else letters[int(probs.argmax())]

    question_text = row["question"]
    options = [row[col] for col in option_cols]
    correct_letter = row["correct_letter"]

    # Step 1: baseline probabilities and influential words come from one call.
    base_probs, base_choice, inf_words = gemini_option_probs(
        question_text, options, max_words=max_influential
    )
    if base_probs is None:
        return None
    base_probs = np.asarray(base_probs, dtype=float)
    correct_idx = letters.index(correct_letter)
    base_p_correct = float(base_probs[correct_idx])
    base_pred_letter = pick_letter(base_choice, base_probs)

    # Nothing to mask -> no attribution signal for this row.
    if not inf_words:
        return None

    # Steps 2-3: mask all influential words at once, then re-query.
    red_q = redact_words(question_text, inf_words)
    red_probs, red_choice, _ignored_words = gemini_option_probs(red_q, options, max_words=0)
    if red_probs is None:
        return None
    red_probs = np.asarray(red_probs, dtype=float)
    red_p_correct = float(red_probs[correct_idx])
    red_pred_letter = pick_letter(red_choice, red_probs)

    # Key order matters: it fixes the CSV column order downstream.
    return dict(
        question=question_text,
        options=options,
        correct_letter=correct_letter,
        influential_words=inf_words,
        num_inf=len(inf_words),
        baseline_probs=base_probs.tolist(),
        redacted_probs=red_probs.tolist(),
        baseline_p_correct=base_p_correct,
        redacted_p_correct=red_p_correct,
        delta_p_correct=base_p_correct - red_p_correct,
        baseline_pred=base_pred_letter,
        redacted_pred=red_pred_letter,
    )


In [None]:
import pandas as pd
import time
from tqdm import tqdm

# CONFIG
START_IDX = 400
BATCH_SIZE = 100
MAX_INF = 5
# Pause between rows to throttle Gemini requests (each row makes up to 2 calls).
SLEEP_BETWEEN_ROWS_S = 7

# Clamp to the frame length so the printed range describes the rows that exist.
END_IDX = max(START_IDX, min(START_IDX + BATCH_SIZE, len(df)))
subset = df.iloc[START_IDX:END_IDX]
print(f"Processing df rows from {START_IDX} to {END_IDX-1} (total {len(subset)})")

rows = []
last_done = START_IDX - 1  # positional index of the last fully processed row
t0 = time.time()

# try/finally: an interrupted (KeyboardInterrupt) or crashed run still writes
# the rows collected so far instead of losing them.
try:
    for n, (df_idx, row) in enumerate(tqdm(subset.iterrows(), total=len(subset))):
        try:
            res = run_attribution_lite_gemini(row, max_influential=MAX_INF)
        except Exception as e:
            # Best-effort: log and skip this row rather than aborting the batch.
            print(f"[error] df_idx {df_idx}: {e}")
            res = None

        if res is not None:
            res["df_idx"] = int(df_idx)
            rows.append(res)
        last_done = START_IDX + n

        if (n + 1) % 10 == 0:
            elapsed = time.time() - t0
            print(f"Progress: {n+1}/{len(subset)} in this batch  |  {elapsed:.1f}s elapsed")
        # No need to wait after the final row.
        if n + 1 < len(subset):
            time.sleep(SLEEP_BETWEEN_ROWS_S)
finally:
    # build DataFrame and save; the name reflects the rows actually completed
    lite_df = pd.DataFrame(rows)
    out_name = f"gemini_attribution_lite_rows_{START_IDX}_{last_done}.csv"
    lite_df.to_csv(out_name, index=False)

    total_time = time.time() - t0
    print("\nSaved", out_name)
    print("Valid rows:", len(lite_df))
    print("Total time:", round(total_time, 1), "seconds")
