In [None]:
!pip -q install -U transformers accelerate bitsandbytes datasets pandas

import torch, re, time
import torch.nn.functional as F
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face model id of the instruction-tuned model used throughout.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

if device != "cuda":
    # CPU path: load the checkpoint without quantization and move it to `device`.
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model = model.to(device)
else:
    # GPU path: 4-bit NF4 weights with double quantization and float16 compute;
    # device placement is delegated to `device_map="auto"`.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
    )

# The notebook only runs inference, so put the model in evaluation mode.
model.eval()
print("Loaded:", MODEL_NAME)


In [None]:
# Number of examples sampled for the experiments below.
N = 200

# Train split of the 4-option MedQA-USMLE dataset.
ds = load_dataset(
    path="openlifescienceai/MedQA-USMLE-4-options-hf",
    split="train",
)

# Seeded shuffle so the same N examples are selected on every run.
ds_small = ds.shuffle(seed=0).select(indices=range(N))

def clean_text(x):
    """Return `x` with every run of whitespace collapsed to a single space.

    Newlines count as whitespace, and leading/trailing whitespace is removed.
    Values that are not strings are handed back untouched.
    """
    if isinstance(x, str):
        return " ".join(x.replace("\n", " ").split())
    return x

# Answer letters, indexed by the dataset's integer label.
letters = ["A", "B", "C", "D"]
# Dataset columns holding the four answer options, in letter order.
option_cols = ["ending0", "ending1", "ending2", "ending3"]

# One dict per usable example: cleaned question, cleaned options, gold letter.
rows = []
for ex in ds_small:
    # Drop examples where any required field is missing.
    fields_present = all(
        ex.get(k) is not None
        for k in ("sent1", "sent2", "ending0", "ending1", "ending2", "ending3", "label")
    )
    if not fields_present:
        continue

    # The question is the two sentence fields joined by a single space.
    q = " ".join([clean_text(ex["sent1"]), clean_text(ex["sent2"])])
    opts = [clean_text(ex[c]) for c in option_cols]

    # Drop examples whose label does not index one of the four options.
    lab = int(ex["label"])
    if not 0 <= lab <= 3:
        continue

    rows.append(
        {
            "question": q,
            "options": opts,
            "correct_letter": letters[lab],
        }
    )

print("Usable rows:", len(rows))


In [None]:
def build_mcq_prompt(question_text, options):
    """
    Assemble the multiple-choice prompt for one question.

    `options` is indexed 0..3 and rendered as choices A-D. The prompt ends
    with "Answer:" so the answer letter is the natural continuation.
    """
    option_lines = "\n".join(
        f"{letter}. {options[idx]}" for idx, letter in enumerate("ABCD")
    )
    sections = [
        "You are a medical expert answering USMLE-style questions.",
        f"Question:\n{question_text}",
        f"Options:\n{option_lines}",
        "Give ONLY the single best option letter: A, B, C, or D.\nAnswer:",
    ]
    # Sections are separated by one blank line.
    return "\n\n".join(sections)


In [None]:
@torch.no_grad()
def seq_logprob(prompt, completion_text):
    """
    Computes log P(completion_text | prompt) using teacher forcing.
    Works even when ' A' / ' B' tokenize into multiple tokens.

    Uses the module-level `tokenizer`, `model` and `device`. The prompt and
    the completion are tokenized separately (no special tokens added) and
    concatenated, so the completion's token boundaries do not depend on the
    prompt text.

    Args:
        prompt: Conditioning text; must tokenize to at least one token.
        completion_text: Text whose token log-probabilities are summed.

    Returns:
        The summed log-probability as a Python float (0.0 for a completion
        that tokenizes to nothing).

    Raises:
        ValueError: If `prompt` tokenizes to zero tokens, since there would
            be no position from which to predict the first completion token.
    """
    # tokenize prompt and completion separately
    p_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    c_ids = tokenizer(completion_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)

    n_prompt = p_ids.shape[1]
    n_completion = c_ids.shape[1]
    if n_prompt == 0:
        raise ValueError("prompt must tokenize to at least one token")
    if n_completion == 0:
        # Nothing to score: the log-probability of an empty sequence is 0.
        return 0.0

    # concatenate and run a single forward pass
    input_ids = torch.cat([p_ids, c_ids], dim=1)
    logits = model(input_ids=input_ids).logits

    # The logits at position t score token t+1, so the completion tokens are
    # predicted by positions n_prompt-1 .. n_prompt+n_completion-2.
    start = n_prompt - 1
    end = start + n_completion

    # Slice to the completion positions BEFORE the softmax and upcast to
    # float32: this avoids normalising the whole prompt x vocabulary and keeps
    # precision when the model computes in half precision.
    completion_logits = logits[0, start:end, :].float()
    log_probs = torch.log_softmax(completion_logits, dim=-1)

    targets = c_ids[0].to(log_probs.device)
    token_logprobs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)

    return float(token_logprobs.sum().cpu())


def option_probs_from_model(prompt):
    """
    Score the four answer letters as continuations of `prompt`.

    Each of " A", " B", " C", " D" is scored with `seq_logprob`; the four
    log-probabilities are then softmax-normalised so they sum to 1.
    Returns the result as a numpy array ordered A, B, C, D.
    """
    scores = []
    for completion in (" A", " B", " C", " D"):
        scores.append(seq_logprob(prompt, completion))

    score_tensor = torch.tensor(scores, dtype=torch.float32)
    normalised = torch.softmax(score_tensor, dim=0)
    return normalised.cpu().numpy()


def predict_letter_from_probs(probs):
    """Return the letter (A-D) at the position of the largest entry of `probs`."""
    best_idx = int(probs.argmax())
    return ("A", "B", "C", "D")[best_idx]


In [None]:
@torch.no_grad()
def paraphrase_question_with_qwen(question_text, max_new_tokens=96):
    """
    Ask the module-level model to reword a clinical question.

    Generation is greedy (`do_sample=False`, `num_beams=1`), so the output is
    deterministic for a given model and input.

    Args:
        question_text: The original question to paraphrase.
        max_new_tokens: Upper bound on generated tokens.
            NOTE(review): a paraphrase longer than this is cut off
            mid-sentence; confirm the default is large enough for the
            questions being used.

    Returns:
        The generated continuation only (prompt excluded), stripped of
        surrounding whitespace.
    """
    prompt = (
        "You are a medical doctor.\n\n"
        "Rewrite the following clinical question in different words while keeping "
        "ALL the medical meaning and clinical facts the same.\n"
        "Do NOT change the clinical facts. Do NOT add new information.\n"
        "Output ONLY the rewritten question, nothing else.\n\n"
        f"Original question:\n{question_text}\n\n"
        "Rewritten question:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    out_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens. Cut at the token
    # level instead of matching decoded text against `prompt`: decoding is not
    # guaranteed to reproduce the prompt string exactly, and a failed text
    # match would otherwise drop most of the paraphrase.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = out_ids[0, prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Paraphrase robustness test: answer each question as written and again after
# the model has reworded it, then record whether the predicted letter changed.
records = []
t0 = time.time()

for i, r in enumerate(rows):
    q, opts, correct = r["question"], r["options"], r["correct_letter"]

    # Prediction on the original wording.
    base_probs = option_probs_from_model(build_mcq_prompt(q, opts))
    base_pred = predict_letter_from_probs(base_probs)

    # Prediction on the paraphrased wording (options are left unchanged).
    q_para = paraphrase_question_with_qwen(q)
    para_probs = option_probs_from_model(build_mcq_prompt(q_para, opts))
    para_pred = predict_letter_from_probs(para_probs)

    records.append(
        {
            "qid": i,
            "question_original": q,
            "question_paraphrased": q_para,
            "option_A": opts[0],
            "option_B": opts[1],
            "option_C": opts[2],
            "option_D": opts[3],
            "correct_letter": correct,
            "baseline_pred": base_pred,
            "paraphrased_pred": para_pred,
            "baseline_correct": (base_pred == correct),
            "paraphrased_correct": (para_pred == correct),
            "prediction_flip": (base_pred != para_pred),
        }
    )

    # Progress line every 10 questions.
    done = i + 1
    if done % 10 == 0:
        print(f"Paraphrase test: {done}/{len(rows)} done | {time.time()-t0:.1f}s")

para_df = pd.DataFrame(records)
para_df.to_csv("qwen_paraphrase_results_200.csv", index=False)
print("Saved: qwen_paraphrase_results_200.csv")
para_df.head()


In [None]:
@torch.no_grad()
def get_influential_words_lite(question_text, options, max_words=5, max_new_tokens=24):
    """
    Ask the module-level model which question words drove its answer choice.

    The model is prompted for a comma-separated word list. The reply is
    lower-cased and split into alphabetic tokens, and only tokens that also
    occur in `question_text` are kept.

    Args:
        question_text: The question shown to the model.
        options: The four answer options, indexed 0..3.
        max_words: Maximum number of words to return.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        A list of at most `max_words` distinct lower-case words, in the order
        the model produced them. May be empty.
    """
    opt_block = (
        f"A. {options[0]}\n"
        f"B. {options[1]}\n"
        f"C. {options[2]}\n"
        f"D. {options[3]}"
    )
    ask_prompt = (
        "You are a medical doctor.\n\n"
        f"Question:\n{question_text}\n\n"
        f"Options:\n{opt_block}\n\n"
        "List the single most important words from the QUESTION that influenced your MCQ choice.\n"
        "Output ONLY a comma-separated list of individual words, e.g.\n"
        "fever,cough,chest,pain\n"
        "No sentences. No explanations.\n"
        "Words:"
    )
    inputs = tokenizer(ask_prompt, return_tensors="pt").to(device)
    out_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Keep only the newly generated tokens. Cutting at the token level avoids
    # relying on the decoded text starting with `ask_prompt` character for
    # character.
    prompt_len = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(out_ids[0, prompt_len:], skip_special_tokens=True).strip()

    cand_words = re.findall(r"[a-z]+", completion.lower())
    question_vocab = set(re.findall(r"[a-z]+", question_text.lower()))

    words = []
    for w in cand_words:
        # Check the cap before appending so that max_words is never exceeded
        # (including max_words=0).
        if len(words) >= max_words:
            break
        if w in question_vocab and w not in words:
            words.append(w)
    return words

def redact_words(text, words):
    """
    Replace whole-word, case-insensitive occurrences of `words` with "____".

    Args:
        text: The text to redact.
        words: Terms to blank out. Empty strings are ignored.

    Returns:
        `text` unchanged when there is nothing to redact; otherwise the
        redacted text with runs of whitespace collapsed to single spaces.
    """
    if not words:
        return text

    # Drop empty terms: an empty alternative would match at every word
    # boundary and scatter "____" through the text. dict.fromkeys removes
    # duplicates while keeping the order deterministic.
    terms = [w for w in dict.fromkeys(words) if w]
    if not terms:
        return text

    # Longest first, so a term that is a prefix of another ("c" vs "c++")
    # cannot win the alternation and leave a partial redaction behind.
    terms.sort(key=len, reverse=True)

    # Lookarounds instead of \b: identical for ordinary words, but they also
    # work when a term starts or ends with a non-word character.
    pattern = r"(?<!\w)(?:" + "|".join(re.escape(t) for t in terms) + r")(?!\w)"
    redacted = re.sub(pattern, "____", text, flags=re.IGNORECASE)
    return " ".join(redacted.split())

# Attribution-by-redaction test: ask the model which question words mattered,
# blank those words out, and measure how the probability of the correct
# answer (and the predicted letter) changes.
attrib_rows = []
t0 = time.time()

for i, r in enumerate(rows):
    q = r["question"]
    opts = r["options"]
    correct = r["correct_letter"]

    # Ask for the influential words first: rows without any are skipped, and
    # checking up front avoids scoring a baseline that would be discarded.
    infl = get_influential_words_lite(q, opts, max_words=5)
    if not infl:
        continue

    correct_idx = letters.index(correct)

    # Baseline: the question as written.
    base_prompt = build_mcq_prompt(q, opts)
    base_probs = option_probs_from_model(base_prompt)
    base_pred = predict_letter_from_probs(base_probs)
    base_p_correct = float(base_probs[correct_idx])

    # Redacted: the same question with the influential words blanked out.
    q_red = redact_words(q, infl)
    red_prompt = build_mcq_prompt(q_red, opts)
    red_probs = option_probs_from_model(red_prompt)
    red_pred = predict_letter_from_probs(red_probs)
    red_p_correct = float(red_probs[correct_idx])

    attrib_rows.append({
        "qid": i,
        "question": q,
        "question_redacted": q_red,
        "options": opts,
        "correct_letter": correct,
        "influential_words": infl,
        "num_inf": len(infl),
        "baseline_probs": base_probs.tolist(),
        "redacted_probs": red_probs.tolist(),
        "baseline_p_correct": base_p_correct,
        "redacted_p_correct": red_p_correct,
        # Positive when redaction lowered the probability of the correct answer.
        "delta_p_correct": base_p_correct - red_p_correct,
        "baseline_pred": base_pred,
        "redacted_pred": red_pred,
        "pred_flip": (base_pred != red_pred),
    })

    # Progress line every 10 questions (not printed for skipped rows).
    if (i+1) % 10 == 0:
        print(f"Attribution test: {i+1}/{len(rows)} done | {time.time()-t0:.1f}s")

attrib_df = pd.DataFrame(attrib_rows)
attrib_df.to_csv("qwen_attribution_redaction_200.csv", index=False)
print("Saved: qwen_attribution_redaction_200.csv | rows:", len(attrib_df))
attrib_df.head()
