In [None]:

!pip install -q datasets pandas

from datasets import load_dataset
import pandas as pd
import numpy as np


# Fetch the training split of the 4-option MedQA (USMLE) dataset from the
# Hugging Face Hub and convert it to a pandas DataFrame for the steps below.
ds = load_dataset("openlifescienceai/MedQA-USMLE-4-options-hf", split="train")
df = ds.to_pandas()

# BASIC CLEANING

# Every one of these fields is needed to build a complete multiple-choice
# item, so rows missing any of them are discarded.
required_columns = [
    "sent1", "sent2",
    "ending0", "ending1", "ending2", "ending3",
    "label",
]
df = df.dropna(subset=required_columns)

def clean_text(x):
    """Normalise whitespace in a string; return non-string values untouched.

    Leading/trailing whitespace is removed and every internal run of
    whitespace (spaces, tabs, newlines) is collapsed to a single space.
    """
    if isinstance(x, str):
        # str.split() with no separator splits on any whitespace run,
        # newlines included, so joining the pieces does all the work.
        return " ".join(x.split())
    return x

# Normalise whitespace in every free-text column.
for col in ["sent1", "sent2", "ending0", "ending1", "ending2", "ending3"]:
    df[col] = df[col].apply(clean_text)

# Keep only rows whose label is a valid option index. The explicit .copy()
# makes the filtered frame independent, so the column assignments below do
# not write into a slice (which can raise SettingWithCopyWarning).
df["label"] = df["label"].astype(int)
df = df[df["label"].isin([0, 1, 2, 3])].copy()

# Question stem = sent1 + sent2. The outer strip removes the dangling space
# that the join would otherwise leave when one of the two parts is empty.
df["question"] = (
    df["sent1"].str.strip() + " " + df["sent2"].str.strip()
).str.strip()


option_cols = ["ending0", "ending1", "ending2", "ending3"]
letters = ["A", "B", "C", "D"]

# Map the integer label to its option letter (0 -> "A", ..., 3 -> "D").
df["correct_letter"] = df["label"].map(dict(enumerate(letters)))

# For each row pick the option text in the column selected by its label:
# position k of the result is options[k, label[k]].
df["correct_text"] = df[option_cols].to_numpy()[
    np.arange(len(df)), df["label"].to_numpy()
]

print("Cleaned dataset size:", len(df))
df.head(3)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/735 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/5.12M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/648k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/667k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10178 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1272 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1273 [00:00<?, ? examples/s]

Cleaned dataset size: 10178


Unnamed: 0,id,sent1,sent2,ending0,ending1,ending2,ending3,label,question,correct_letter,correct_text
0,train-00000,A 23-year-old pregnant woman at 22 weeks gesta...,,Ampicillin,Ceftriaxone,Doxycycline,Nitrofurantoin,3,A 23-year-old pregnant woman at 22 weeks gesta...,D,Nitrofurantoin
1,train-00001,A 3-month-old baby died suddenly at night whil...,,Placing the infant in a supine position on a f...,Keeping the infant covered and maintaining a h...,Application of a device to maintain the sleepi...,Avoiding pacifier use during sleep,0,A 3-month-old baby died suddenly at night whil...,A,Placing the infant in a supine position on a f...
2,train-00002,A mother brings her 3-week-old infant to the p...,,Abnormal migration of ventral pancreatic bud,Complete failure of proximal duodenum to recan...,Abnormal hypertrophy of the pylorus,Failure of lateral body folds to move ventrall...,0,A mother brings her 3-week-old infant to the p...,A,Abnormal migration of ventral pancreatic bud


In [None]:
# Authenticate with the Hugging Face Hub before any model files are downloaded.
# login() is called without arguments, so the token is requested interactively
# rather than being hard-coded in the notebook.
# NOTE(review): presumably required because the Gemma checkpoint loaded further
# down is access-gated on the Hub — confirm.
from huggingface_hub import login
login()

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model_name = "google/gemma-3-4b-it"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# `dtype` replaces the `torch_dtype` keyword, which the installed transformers
# version reports as deprecated when this cell runs.
if device == "cuda":
    # bfloat16 weights, with layer placement delegated to device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        dtype=torch.bfloat16,
    )
else:
    # Full precision on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.float32,
    ).to(device)

# Inference only: switch the model to evaluation mode.
model.eval()

print("Gemma-3 4B model loaded ")



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Using device: cuda


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Gemma-3 4B model loaded ✅


In [None]:
option_cols = ["ending0", "ending1", "ending2", "ending3"]
letters = ["A", "B", "C", "D"]

def build_mcq_prompt(row):
    """Render one dataset row as a four-option multiple-choice prompt.

    `row` must support item access for "question" and for each column named
    in `option_cols`. The returned text ends with the instruction to answer
    with a single letter, followed by a newline.
    """
    stem = row["question"]
    choices = [row[col] for col in option_cols]
    return (
        "You are a medical expert answering USMLE-style questions.\n"
        "\n"
        "Question:\n"
        f"{stem}\n"
        "\n"
        "Options:\n"
        f"A. {choices[0]}\n"
        f"B. {choices[1]}\n"
        f"C. {choices[2]}\n"
        f"D. {choices[3]}\n"
        "\n"
        "Give ONLY the single best option letter: A, B, C, or D.\n"
    )


In [None]:
import torch
import torch.nn.functional as F

letters = ["A", "B", "C", "D"]

def option_probs_from_model(prompt, letters=letters):
    """
    Return probabilities over the answer letters for the given prompt,
    using the model's next-token logits.

    Args:
        prompt: Full prompt text; the distribution is read at the position
            immediately after its last token.
        letters: Answer labels to score (defaults to ["A", "B", "C", "D"]).

    Returns:
        A NumPy float32 array with one probability per entry of `letters`.
        The softmax is taken over these letters only, so the values sum to 1
        regardless of how much mass the model puts on other tokens.

    Relies on the notebook-level `tokenizer`, `model` and `device`.
    """
    # Tokenize prompt and send to device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # First (only) sequence in the batch, last position: the scores for the
    # token that would follow the prompt.
    next_token_logits = outputs.logits[0, -1, :]

    # Token ids for " A", " B", " C", " D" (leading space, no special tokens).
    # NOTE(review): assumes each " X" string starts with a token that
    # identifies the letter — confirm for the tokenizer in use.
    letter_token_ids = [
        tokenizer(" " + letter, add_special_tokens=False)["input_ids"][0]
        for letter in letters
    ]

    # Restrict to the answer letters; cast to float32 before the softmax so
    # the result converts to NumPy even when the model runs in bfloat16.
    scores = next_token_logits[letter_token_ids].to(torch.float32)
    probs = F.softmax(scores, dim=0).cpu().numpy()

    return probs


In [None]:
import re, torch

def get_influential_words_lite(question_text, options, max_words=5):
    """
    Ask the model which words of the question drove its answer choice.

    Fast version:
      - uses QUESTION + OPTIONS context
      - short generation (greedy, at most 24 new tokens)
      - returns only words that appear in the QUESTION stem

    Args:
        question_text: The question stem.
        options: The four option texts in order [A, B, C, D].
        max_words: Upper bound on the number of words returned.

    Returns:
        Lower-cased, de-duplicated words in the order the model listed them,
        restricted to alphabetic words that occur in `question_text`.
        May be empty.

    Raises:
        AssertionError: if `options` does not contain exactly four entries.

    Relies on the notebook-level `tokenizer`, `model` and `device`.
    """
    assert len(options) == 4, "options must be a list of 4 strings [A,B,C,D]"

    opt_block = (
        f"A. {options[0]}\n"
        f"B. {options[1]}\n"
        f"C. {options[2]}\n"
        f"D. {options[3]}"
    )

    ask_prompt = (
        "You are a medical doctor.\n\n"
        f"Question:\n{question_text}\n\n"
        f"Options:\n{opt_block}\n\n"
        "List the single most important words from the QUESTION that influenced your MCQ choice.\n"
        "Output ONLY a comma-separated list of individual words, e.g.\n"
        "fever,cough,chest,pain\n"
        "No sentences. No explanations.\n"
        "Words:"
    )

    inputs = tokenizer(ask_prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=24,
            do_sample=False,
        )

    # The generated sequence starts with the prompt tokens. Dropping them by
    # position is exact, whereas comparing decoded text against the prompt
    # fails whenever decoding does not reproduce the prompt verbatim and then
    # silently loses every line of the completion but the last.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(
        out_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    cand_words = re.findall(r"[a-z]+", completion.lower())

    # Only words that literally occur in the question can be redacted later.
    question_vocab = set(re.findall(r"[a-z]+", question_text.lower()))
    words = []
    for w in cand_words:
        # Check the budget first so that max_words=0 yields an empty list.
        if len(words) >= max_words:
            break
        if w in question_vocab and w not in words:
            words.append(w)

    return words


In [None]:
def redact_words(text, words):
    """
    Mask every whole-word, case-insensitive occurrence of the given words.

    Each match is replaced with '____' and the result has its whitespace
    collapsed to single spaces. With an empty `words` the text is returned
    exactly as received.
    """
    if not words:
        return text
    # Escape each word so it is matched literally, then OR them together.
    alternatives = "|".join(map(re.escape, words))
    matcher = re.compile(r"\b(" + alternatives + r")\b", flags=re.IGNORECASE)
    masked = matcher.sub("____", text)
    return " ".join(masked.split())


In [None]:
import numpy as np
import re, torch

def run_attribution_lite(row, max_influential=5):
    """
    Measure how much the model's answer depends on its self-reported key words.

    For a single row:
      1) baseline probs
      2) get influential words (QUESTION + OPTIONS context)
      3) redact ALL influential words at once and recompute probs

    Args:
        row: One dataset row; must provide "question", "correct_letter" and
            every column listed in `option_cols`, and support `.copy()`.
        max_influential: Maximum number of influential words to redact.

    Returns:
        None when the model names no usable influential words; otherwise a
        dict with:
          - the question, the options and the correct option letter
          - influential words used for redaction (and their count)
          - full prob vectors over [A,B,C,D] before and after redaction
          - baseline and redacted P(correct), and their difference
          - baseline and redacted predictions
    """
    # 1) Baseline probabilities
    base_prompt = build_mcq_prompt(row)
    base_probs = option_probs_from_model(base_prompt, letters)  # shape (4,)
    correct_letter = row["correct_letter"]
    correct_idx = letters.index(correct_letter)

    base_p_correct = float(base_probs[correct_idx])
    base_pred_letter = letters[int(base_probs.argmax())]

    # 2) Extract influential words
    opts = [row[c] for c in option_cols]
    inf_source_words = get_influential_words_lite(
        question_text=row["question"],
        options=opts,
        max_words=max_influential
    )

    # Nothing to redact: no comparison can be made for this row.
    if len(inf_source_words) == 0:
        return None

    # 3) Redact ALL influential words at once
    red_q = redact_words(row["question"], inf_source_words)

    # Build the redacted prompt with the same helper as the baseline so the
    # two prompts differ only in the question text. Work on a copy to leave
    # the caller's row untouched.
    red_row = row.copy()
    red_row["question"] = red_q
    red_prompt = build_mcq_prompt(red_row)

    # Recompute probabilities after redaction
    red_probs = option_probs_from_model(red_prompt, letters)
    red_p_correct = float(red_probs[correct_idx])
    red_pred_letter = letters[int(red_probs.argmax())]

    return {
        "question": row["question"],
        "options": opts,
        "correct_letter": correct_letter,

        "influential_words": inf_source_words,
        "num_inf": len(inf_source_words),

        "baseline_probs": base_probs.tolist(),
        "redacted_probs": red_probs.tolist(),

        "baseline_p_correct": base_p_correct,
        "redacted_p_correct": red_p_correct,
        # Positive when redaction lowered the probability of the right answer.
        "delta_p_correct": base_p_correct - red_p_correct,

        "baseline_pred": base_pred_letter,
        "redacted_pred": red_pred_letter,
    }


In [None]:
import pandas as pd
import time

N_QUESTIONS = 6000   # upper bound on the number of questions to evaluate
MAX_INF = 5          # max influential words redacted per question

# Sampling without replacement raises if more rows are requested than exist,
# so cap the request at the size of the cleaned dataset.
n_sample = min(N_QUESTIONS, len(df))
subset = df.sample(n_sample, random_state=0).reset_index(drop=True)

rows_collected = []
start_time = time.time()

for i, row in subset.iterrows():
    res = run_attribution_lite(row, max_influential=MAX_INF)
    # Rows for which the model named no usable words come back as None.
    if res is not None:
        res["qid"] = i   # position within the sampled subset
        rows_collected.append(res)

    if (i + 1) % 10 == 0:
        print(f"Processed {i+1}/{n_sample}")

# Build the frame once from the collected records, then persist it.
lite_df = pd.DataFrame(rows_collected)
lite_df.to_csv("attribution_lite_results.csv", index=False)

elapsed = time.time() - start_time
print("\nSaved attribution_lite_results.csv")
print("Total valid rows:", len(lite_df))
print("Time elapsed:", round(elapsed, 2), "seconds")


In [None]:
import pandas as pd


# Load the saved results under their own name: reusing `df` here would
# overwrite the cleaned MedQA frame that the sampling cell draws from and
# break any re-run of that cell.
results_df = pd.read_csv("attribution_lite_results.csv")

# Number of questions whose predicted letter is the same before and after
# redaction, out of all rows in the results file.
same_pred_count = (results_df["baseline_pred"] == results_df["redacted_pred"]).sum()

total = len(results_df)

print(f"{same_pred_count}/{total}")