In [None]:
# Dataset locations. Each path must point to the folder that contains
# dataset_dict.json (NOT directly to train/).
# NOTE(review): the /content/drive/MyDrive prefix presumably refers to a mounted
# Google Drive in Colab — confirm the drive is mounted before running.
VQARAD_PATH = "/content/drive/MyDrive/data/VQA_RAD"

# Do the same for SLAKE when you place it (adjust this to your actual path):
SLAKE_PATH  = "/content/drive/MyDrive/data/SLAKE"

# Folder for generated artefacts; later cells read the prediction .jsonl files from here.
OUT_DIR = "/content/drive/MyDrive/data/outputs_cfproxy"
# Shell call: create OUT_DIR including missing parents; no error if it already exists.
!mkdir -p "$OUT_DIR"

In [None]:
from datasets import load_from_disk

# SLAKE_PATH is the folder holding dataset_dict.json, so load_from_disk yields a DatasetDict.
slake = load_from_disk(SLAKE_PATH)
print(slake)  # expected splits: 'train' and 'test'

# Keep a handle on each split.
slake_train, slake_test = slake["train"], slake["test"]

# Sanity check: look at the first test example (keys, question/answer pair, image).
row = slake_test[0]
print(row.keys())
print(row["question"], "->", row["answer"])
row["image"].show()

DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 1793
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 451
    })
})
dict_keys(['image', 'question', 'answer'])
is there evidence of an aortic aneurysm? -> yes


In [None]:
# Install runtime dependencies; the extra index serves the CUDA 12.1 (cu121) PyTorch wheels.
# NOTE(review): this cell comes after a cell that already imports `datasets`; on a fresh
# kernel it presumably has to run first — confirm the intended cell order.
# NOTE(review): only transformers has a lower bound; the other packages are unpinned.
!pip -q install "transformers>=4.44" accelerate torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
!pip -q install datasets pillow opencv-python rapidfuzz tqdm matplotlib

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━[0m [32m2.2/3.2 MB[0m [31m84.4 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━[0m [32m2.5/3.2 MB[0m [31m38.2 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.2/3.2 MB[0m [31m32.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m27.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Paths: same values as in the first configuration cell, repeated here
# (presumably so this cell can be run on its own after a kernel restart).
SLAKE_PATH = "/content/drive/MyDrive/data/SLAKE"
OUT_DIR = "/content/drive/MyDrive/data/outputs_cfproxy"

# Make sure the output folder exists (pure-Python equivalent of `mkdir -p`).
from pathlib import Path
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

# Load SLAKE from disk and expose the two splits used by the rest of the notebook.
from datasets import load_from_disk
slake = load_from_disk(SLAKE_PATH)
ds_train, ds_test = slake["train"], slake["test"]

def iter_samples(ds, n=None):
    """Yield ``(index, image, question, answer)`` tuples from the start of ``ds``.

    Args:
        ds: indexable collection with ``len()`` whose rows expose the keys
            ``"image"``, ``"question"`` and ``"answer"``.
        n: maximum number of rows to yield; ``None`` means every row. Values
            larger than ``len(ds)`` are capped at the dataset size.
    """
    total = len(ds)
    limit = total if n is None else min(n, total)
    for idx in range(limit):
        sample = ds[idx]
        yield idx, sample["image"], sample["question"], sample["answer"]

In [None]:
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch

# Run on the GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

# The processor builds the chat text and the image tensors consumed by the model.
# NOTE(review): trust_remote_code=True permits running code shipped with the checkpoint
# repository — confirm it is actually required for this model ID.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# Half precision on GPU, full precision on CPU; .eval() switches the model to inference mode.
# NOTE(review): device_map="auto" lets the library choose weight placement, while inputs are
# later moved to DEVICE — presumably identical on a single-GPU runtime; verify on other setups.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16 if DEVICE=="cuda" else torch.float32,
    trust_remote_code=True
).eval()



model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/429M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/272 [00:00<?, ?B/s]

In [None]:
def prompt_think_answer(q):
    """Build the "think, then answer" prompt for question ``q``.

    The prompt asks the model to reply with a single JSON object that has one
    key, ``"answer"``. ``q`` is interpolated verbatim between double quotes.

    Returns:
        The prompt text as a ``str``.
    """
    return f'''You are a medical VQA assistant. Think carefully but return ONLY valid JSON:
{{"answer":"<short answer>"}}

Rules:
- Do not use code fences or backticks.
- No extra keys or text beyond the JSON.
- If your output is not valid JSON, immediately try again and output ONLY valid JSON.
Question: "{q}"'''

def prompt_caption_reason_answer(q):
    """Build the "caption, reason, answer" prompt for question ``q``.

    The prompt asks the model for one JSON object with exactly four keys:
    ``"caption"``, ``"reasoning"`` (a list of steps), ``"boxes"`` (at most one
    ``[x1,y1,x2,y2]`` integer box) and ``"answer"``. ``q`` is interpolated
    verbatim between double quotes.

    Returns:
        The prompt text as a ``str``.
    """
    return f'''You are a medical VQA assistant. First DESCRIBE the image, then REASON, then ANSWER.
Return ONLY valid JSON with exactly these keys:
{{
 "caption":"<1-2 precise sentences about visible anatomy/findings>",
 "reasoning":["<step1>","<step2>","<step3>"],
 "boxes":[[x1,y1,x2,y2]],
 "answer":"<short answer>"
}}
Rules:
- Do not use code fences or backticks.
- No extra keys or text beyond the JSON.
- Use integers for box coordinates within image bounds.
- If your output is not valid JSON, immediately try again and output ONLY valid JSON.
- Output at most 1 box tightly enclosing the most diagnostic finding (avoid full-organ boxes).
- The box must be as small as possible while still covering the key evidence.

Question: "{q}"'''

In [None]:
import json, re

def parse_json_safe(text):
    """Extract the JSON object embedded in free-form model output.

    The text may wrap the object in code fences or surround it with prose.
    Parsing is attempted in this order, returning the first success:

    1. the span from the first ``{`` to the last ``}`` exactly as written;
    2. the same span with trailing commas before ``}`` / ``]`` removed;
    3. the first complete object starting at the first ``{`` (covers output
       where extra text containing ``}`` follows the object).

    Trying the untouched span first keeps the comma repair from altering valid
    JSON whose string values happen to contain ``,}`` or ``,]``.

    Args:
        text: raw model output; anything that is not a ``str`` yields ``None``.

    Returns:
        The decoded ``dict``, or ``None`` when no JSON object can be recovered.
    """
    if not isinstance(text, str):
        return None
    m = re.search(r"\{.*\}", text, re.DOTALL)
    if not m:
        return None
    chunk = m.group(0)
    # minimal repair: strip trailing commas before } or ]
    repaired = re.sub(r",(\s*[}\]])", r"\1", chunk)
    candidates = (chunk, repaired)

    # Only decoding failures are swallowed (JSONDecodeError is a ValueError);
    # unrelated errors such as KeyboardInterrupt are no longer hidden.
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):
            continue

    # Fallback: the greedy span may run past the object into trailing text;
    # raw_decode stops at the end of the first complete JSON value.
    decoder = json.JSONDecoder()
    for candidate in candidates:
        try:
            obj, _end = decoder.raw_decode(candidate)
        except (ValueError, RecursionError):
            continue
        return obj
    return None


@torch.inference_mode()
def call_vlm(image_pil, prompt, max_new_tokens=384, temperature=0.2, top_p=0.9):
    """Run the vision-language model on one image + prompt and return its text reply.

    Relies on the module-level ``processor``, ``model`` and ``DEVICE``.

    Args:
        image_pil: image handed to the processor (a PIL image in this notebook).
        prompt: text of the user turn; it follows the image placeholder.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature. ``None`` or a value <= 0 selects
            greedy decoding, in which case no sampling arguments are forwarded.
        top_p: nucleus-sampling threshold; only used when sampling.

    Returns:
        The decoded completion only (prompt tokens are sliced off), stripped of
        surrounding whitespace. Feed it to ``parse_json_safe`` to get the JSON.
    """
    # Messages with system role + image attached to the user turn
    messages = [
        {"role": "system", "content": "You are a helpful medical VQA assistant."},
        {"role": "user", "content": [
            {"type": "image"},                      # placeholder token in chat text
            {"type": "text", "text": prompt}
        ]}
    ]

    # 1) Get chat string (not tensors)
    chat_str = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # 2) Build tensors from (chat string + image)
    inputs = processor(
        text=[chat_str],
        images=[image_pil],
        return_tensors="pt"
    ).to(DEVICE)

    input_len = inputs["input_ids"].shape[-1]  # number of tokens in the prompt

    # 3) Generate. Sampling knobs are forwarded only when sampling is active, so a
    #    greedy call does not pass arguments that generation would ignore.
    sampling = temperature is not None and temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "pad_token_id": processor.tokenizer.eos_token_id,
        "eos_token_id": processor.tokenizer.eos_token_id,
    }
    if sampling:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    output_ids = model.generate(**inputs, **gen_kwargs)

    # 4) Decode only the NEW tokens
    new_tokens = output_ids[0, input_len:]
    out = processor.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return out

In [None]:
# Smoke test: push the first test sample through the caption -> reason -> answer prompt.
# iter_samples already returns a generator, so next() can consume it directly.
i, img, q, gold = next(iter_samples(ds_test, n=1))
print("Q:", q, "| gold:", gold)

raw = call_vlm(img, prompt_caption_reason_answer(q))
# First 500 characters of the raw completion (should be the model's reply, not the prompt).
print(raw[:500])
# Parsed form: a dict on success, None if the reply could not be decoded.
print(parse_json_safe(raw))

Q: is there evidence of an aortic aneurysm? | gold: yes
```json
{
  "caption": "The image shows a chest X-ray with a pacemaker and no evidence of an aortic aneurysm.",
  "reasoning": [
    "The pacemaker is located in the chest cavity, not in the aorta.",
    "There are no visible signs of an aortic aneurysm such as bulging or enlargement of the aorta."
  ],
  "boxes": [
    [100, 100, 850, 900]
  ],
  "answer": "No, there is no evidence of an aortic aneurysm."
}
```
{'caption': 'The image shows a chest X-ray with a pacemaker and no evidence of an aortic aneurysm.', 'reasoning': ['The pacemaker is located in the chest cavity, not in the aorta.', 'There are no visible signs of an aortic aneurysm such as bulging or enlargement of the aorta.'], 'boxes': [[100, 100, 850, 900]], 'answer': 'No, there is no evidence of an aortic aneurysm.'}


In [None]:
# run_protocol(ds_test, f"{OUT_DIR}/vqarad_think_20.jsonl", prompt_think_answer, limit=20)
# run_protocol(ds_test, f"{OUT_DIR}/vqarad_cap_20.jsonl",   prompt_caption_reason_answer, limit=20)

100%|██████████| 20/20 [00:23<00:00,  1.16s/it]
100%|██████████| 20/20 [03:08<00:00,  9.45s/it]


In [None]:
# Inference with limit=200 on the test split: one prediction file per prompt style
# (think->answer, and caption->reason->answer).
# NOTE(review): run_protocol is not defined in any cell visible in this part of the
# notebook — confirm it is defined (and executed) before this cell.
run_protocol(ds_test, f"{OUT_DIR}/slake_think.jsonl", prompt_think_answer, limit=200)
run_protocol(ds_test, f"{OUT_DIR}/slake_cap.jsonl",   prompt_caption_reason_answer, limit=200)

100%|██████████| 200/200 [04:00<00:00,  1.20s/it]
100%|██████████| 200/200 [27:09<00:00,  8.15s/it]


In [None]:
import json, re
from rapidfuzz.fuzz import partial_ratio

# ---------- helpers ----------
def _get(d, *keys, default=None):
    for k in keys:
        if k in d: return d[k]
    return default

def has_json(r):
    """Return True when ``r`` is a dict whose ``"json"`` entry is itself a dict."""
    if not isinstance(r, dict):
        return False
    return isinstance(r.get("json"), dict)

def has_boxes(r):
    """Return True when record ``r`` has parsed JSON with a non-empty ``"boxes"`` list."""
    if has_json(r):
        candidate = r["json"].get("boxes", None)
        return isinstance(candidate, list) and bool(candidate)
    return False

# Exact-match aliases for binary answers.
YN = {"y":"yes","yes":"yes","true":"yes","1":"yes",
      "n":"no","no":"no","false":"no","0":"no"}

# A leading whole word "yes"/"no", e.g. "No, there is no evidence of ...".
# The look-ahead rejects longer words ("normal", "north") and hyphenated
# compounds ("no-contrast"), which must stay open-ended answers.
_LEADING_YES_NO = re.compile(r"(yes|no)(?![\w-])")

def normalize_answer(a: str) -> str:
    """Canonicalise an answer string so predictions and gold labels can be compared.

    Steps: stringify, trim, lower-case and collapse whitespace. Then:

    * an exact alias in ``YN`` (``"y"``, ``"true"``, ``"0"``, ...) maps to
      ``"yes"`` / ``"no"``;
    * a sentence that *starts with the word* yes or no maps to that word;
    * anything else keeps only ``a-z``, ``0-9``, space, dot and dash.

    Only whole leading words count as yes/no. Matching on the first letter
    alone would turn open answers such as "normal", "neck" or "nodule" into
    "no" and "yellow" into "yes".

    Args:
        a: the answer; ``None`` and empty strings give ``""``. Other non-string
            values are stringified first, so a JSON ``false`` becomes ``"no"``.

    Returns:
        The normalised answer, possibly ``""``.
    """
    if a is None:
        return ""
    a = re.sub(r"\s+", " ", str(a).strip().lower())
    if not a:
        return ""
    # unify yes/no style answers
    if a in YN:
        return YN[a]
    lead = _LEADING_YES_NO.match(a)
    if lead:
        return lead.group(1)
    # remove non-alphanum except space, dot, dash
    a = re.sub(r"[^a-z0-9 .-]", "", a)
    # collapse whitespace left behind by the removal
    a = re.sub(r"\s+", " ", a).strip()
    return a

def accuracy(records):
    """Exact-match accuracy over the records that can be scored.

    A record is scored only if it carries parsed JSON and both the prediction
    (``json["answer"]`` or ``json["pred"]``) and the gold label (``gold``,
    ``label`` or ``answer`` on the record) are non-empty after normalisation.
    Unscored records are left out of the denominator. Returns 0.0 when nothing
    can be scored.
    """
    hits = 0
    scored = 0
    for rec in records:
        if not has_json(rec):
            continue
        pred = normalize_answer(_get(rec["json"], "answer", "pred", default=""))
        gold = normalize_answer(_get(rec, "gold", "label", "answer", default=""))
        if not (pred and gold):
            continue
        scored += 1
        if pred == gold:
            hits += 1
    return hits / max(1, scored)

# --------- caption–question consistency (for risk–coverage) ---------
# Expanded vocab for SLAKE/VQA-style clinical content
# Vocabulary used by consistency_score: a term found in the question is expected
# to show up in the caption as well.
# Anatomy / body-region terms.
ORGANS = {
    "head","skull","brain","spine","cervical","thoracic","lumbar","rib","chest","lung",
    "heart","mediastinum","diaphragm","abdomen","liver","spleen","kidney","pancreas",
    "pelvis","hip","femur","knee","tibia","fibula","ankle","foot","humerus","elbow",
    "forearm","wrist","hand","shoulder","clavicle","sinus"
}
# Laterality terms.
SIDES = {"left","right","bilateral","unilateral"}
# Findings, plus a few device terms ("metal", "catheter", "tube", "line").
FINDINGS = {
    "mass","nodule","lesion","fracture","dislocation","opacity","effusion","consolidation",
    "atelectasis","pneumothorax","pneumonia","edema","calcification","hemorrhage",
    "enlargement","dilation","hernia","stone","obstruction","metal","catheter","tube","line"
}

def norm(s):
    """Lower-case ``s`` and drop every character except ``a-z``, ``0-9`` and space.

    Falsy input (``None``, ``""``) gives ``""``. Non-ASCII text, such as the
    Chinese SLAKE questions, is removed entirely by the character filter.
    """
    if not s:
        return ""
    return re.sub(r"[^a-z0-9 ]", "", s.lower())

def consistency_score(q, cap):
    """Score in [0, 1] for how well caption ``cap`` matches question ``q``.

    Two signals are blended:

    * term coverage — the share of vocabulary terms (ORGANS | SIDES | FINDINGS)
      found in the question that are also found in the caption;
    * string similarity — rapidfuzz ``partial_ratio`` of the normalised texts.

    Coverage gets weight 0.7 when the question contains at least one vocabulary
    term. Otherwise coverage is 0 and the similarity term carries weight 0.7.
    An empty caption scores 0.0.

    NOTE(review): terms are matched as substrings of the normalised text, so a
    short term can match inside a longer word — confirm this is intended.
    """
    if not cap:
        return 0.0
    question_text = norm(q)
    caption_text = norm(cap)

    # vocabulary coverage
    asked = [term for term in (ORGANS | SIDES | FINDINGS) if term in question_text]
    covered = sum(1 for term in asked if term in caption_text)
    term_cov = covered / max(1, len(asked))

    # string similarity (robust to wording)
    sem = partial_ratio(question_text, caption_text) / 100.0

    # vocabulary match dominates when the question names any term
    if asked:
        alpha = 0.7
    else:
        alpha = 0.3
    return alpha * term_cov + (1 - alpha) * sem

def risk_coverage(records, tau):
    """Selective-prediction summary at consistency threshold ``tau``.

    A record is kept when it has parsed JSON and its caption/question
    consistency score is at least ``tau``.

    Returns:
        ``(coverage, accuracy, risk)``, where coverage is kept / all records
        (records without JSON count in the denominator), accuracy is computed
        on the kept records (0.0 if none are kept) and risk is ``1 - accuracy``.
    """
    def _score(rec):
        question = _get(rec, "question", "Question", default="")
        caption = _get(rec["json"], "caption", "Caption", default="")
        return consistency_score(question, caption)

    kept = [rec for rec in records if has_json(rec) and _score(rec) >= tau]
    coverage = len(kept) / max(1, len(records))
    acc = accuracy(kept) if kept else 0.0
    return coverage, acc, 1.0 - acc

In [None]:
# Make sure you've already created these with run_protocol before running this cell:
# slake_think.jsonl, slake_cap.jsonl in OUT_DIR

# Read each JSONL file inside a `with` block so the handle is closed right away;
# blank lines are skipped instead of crashing json.loads.
with open(f"{OUT_DIR}/slake_think.jsonl", 'r', encoding='utf-8') as fh:
    think200 = [json.loads(line) for line in fh if line.strip()]
with open(f"{OUT_DIR}/slake_cap.jsonl", 'r', encoding='utf-8') as fh:
    cap200 = [json.loads(line) for line in fh if line.strip()]

# Parse / box statistics for the caption prompt, then accuracy for both prompts.
print("Parsed:", sum(has_json(r) for r in cap200), "/", len(cap200))
print("Has boxes:", sum(has_boxes(r) for r in cap200), "/", len(cap200))
print("Acc Think→Answer (200):", accuracy(think200))
print("Acc Caption→Reason→Answer (200):", accuracy(cap200))

# Risk–coverage sweep over the caption/question consistency threshold.
for t in [0.3,0.4,0.5,0.6]:
    cov, acc, risk = risk_coverage(cap200, t)
    print(f"Risk–Coverage τ={t:.1f}: coverage={cov:.2f}, acc={acc:.2f}, risk={risk:.2f}")

Parsed: 183 / 200
Has boxes: 151 / 200
Acc Think→Answer (200): 0.385
Acc Caption→Reason→Answer (200): 0.35911602209944754
Risk–Coverage τ=0.3: coverage=0.39, acc=0.29, risk=0.71
Risk–Coverage τ=0.4: coverage=0.39, acc=0.29, risk=0.71
Risk–Coverage τ=0.5: coverage=0.39, acc=0.29, risk=0.71
Risk–Coverage τ=0.6: coverage=0.36, acc=0.27, risk=0.73


In [None]:
# --- Counterfactuals on 200 (ROI-CF with shrink + inpaint) ---

import json, re, random
import numpy as np, cv2
from PIL import Image

random.seed(0); np.random.seed(0)

# assumes: normalize_answer, parse_json_safe, call_vlm, prompt_caption_reason_answer already defined

def answer_from_json(js):
    """Return the normalised ``"answer"`` of a parsed reply; ``""`` if ``js`` is falsy or has none."""
    payload = js if js else {}
    return normalize_answer(payload.get("answer", ""))

def boxes_from_json(js):
    """Return the usable bounding boxes of a parsed model reply.

    The reply is model output, so it is validated instead of trusted. A box is
    kept only if it is a list/tuple of exactly four finite numbers (numeric
    strings are converted). Malformed entries are dropped so they cannot crash
    the box helpers in the middle of a long inference loop. A single flat
    ``[x1, y1, x2, y2]`` is accepted and wrapped.

    Args:
        js: parsed reply; anything that is not a dict yields ``[]``.

    Returns:
        A list of ``[x1, y1, x2, y2]`` lists, possibly empty.
    """
    import math  # local import: only needed for the finiteness check

    def _clean_box(box):
        # Returns the four coordinates, or None when the entry is unusable.
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            return None
        coords = []
        for v in box:
            if isinstance(v, str):
                try:
                    v = float(v)
                except ValueError:
                    return None
            # bool is an int subclass but is never a coordinate
            if isinstance(v, bool) or not isinstance(v, (int, float)):
                return None
            if isinstance(v, float) and not math.isfinite(v):
                return None
            coords.append(v)
        return coords

    raw = js.get("boxes", []) if isinstance(js, dict) else []
    if not isinstance(raw, list):
        return []
    flat = _clean_box(raw)  # the model emitted one box without the outer list
    if flat is not None:
        return [flat]
    return [box for box in map(_clean_box, raw) if box is not None]

def box_area_frac(boxes, w, h):
    """Fraction of a ``w`` x ``h`` image covered by the first box in ``boxes``.

    Coordinates are truncated to int and clipped to ``[0, w-1]`` / ``[0, h-1]``
    before the area is taken. Inverted or empty boxes contribute 0. Returns 0.0
    for an empty ``boxes`` list.
    """
    if not boxes:
        return 0.0

    def _clip(value, upper):
        return max(0, min(upper, value))

    left, top, right, bottom = [int(v) for v in boxes[0]]
    left, right = _clip(left, w - 1), _clip(right, w - 1)
    top, bottom = _clip(top, h - 1), _clip(bottom, h - 1)
    area = max(0, right - left) * max(0, bottom - top)
    return area / max(1, w * h)

def shrink_box(b, w, h, frac=0.2):
    """Shrink box ``b`` around its centre by ``frac`` and clip it to the image.

    Width and height are scaled by ``1 - frac`` (at least 1 px each), the box is
    re-centred on the original integer centre and clipped to ``[0, w-1]`` /
    ``[0, h-1]``. If the result has no area, the original integer box is
    returned unchanged.
    """
    x1, y1, x2, y2 = [int(v) for v in b]
    centre_x = (x1 + x2) // 2
    centre_y = (y1 + y2) // 2

    # Half extents of the shrunk box.
    half_w = max(1, int(max(1, x2 - x1) * (1 - frac))) // 2
    half_h = max(1, int(max(1, y2 - y1) * (1 - frac))) // 2

    def _clip(value, upper):
        return max(0, min(upper, value))

    new_x1 = _clip(centre_x - half_w, w - 1)
    new_x2 = _clip(centre_x + half_w, w - 1)
    new_y1 = _clip(centre_y - half_h, h - 1)
    new_y2 = _clip(centre_y + half_h, h - 1)

    degenerate = new_x2 <= new_x1 or new_y2 <= new_y1
    return [x1, y1, x2, y2] if degenerate else [new_x1, new_y1, new_x2, new_y2]

def shrink_boxes(boxes, w, h, frac=0.2):
    """Shrink only the first box of ``boxes``; returns a one-element list, or ``[]`` if empty."""
    if not boxes:
        return []
    return [shrink_box(boxes[0], w, h, frac)]

def inpaint_boxes(image_pil, boxes):
    """Erase the regions in ``boxes`` from ``image_pil`` by inpainting.

    Each box is truncated to int and clipped to the image. Boxes left without
    area are ignored. If nothing remains to mask, the input image object is
    returned unchanged. Otherwise the masked pixels are filled with OpenCV's
    Telea inpainting (radius 3) and a new RGB PIL image is returned.
    """
    rgb = np.array(image_pil.convert("RGB"))
    height, width = rgb.shape[:2]
    mask = np.zeros((height, width), np.uint8)

    for box in boxes:
        left, top, right, bottom = [int(v) for v in box]
        left = max(0, min(width - 1, left))
        right = max(0, min(width - 1, right))
        top = max(0, min(height - 1, top))
        bottom = max(0, min(height - 1, bottom))
        if right > left and bottom > top:
            mask[top:bottom, left:right] = 255

    # Nothing to erase: hand back the original image untouched.
    if mask.max()==0:
        return image_pil

    filled = cv2.inpaint(rgb, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
    return Image.fromarray(filled)

# Load the 200-sample predictions and index the matching SLAKE test items.
import torch  # already required by call_vlm; imported here so the seed call is self-contained

# call_vlm samples by default (temperature=0.2), so seed torch as well: otherwise
# answer flips caused by sampling noise make the flip rate irreproducible.
torch.manual_seed(0)

# `with` closes the file handle; blank lines are skipped.
with open(f"{OUT_DIR}/slake_cap.jsonl", 'r', encoding='utf-8') as fh:
    cap200 = [json.loads(line) for line in fh if line.strip()]
# Bounded by the split size so a test split shorter than 200 does not raise IndexError.
items200 = {i: ds_test[i] for i in range(min(200, len(ds_test)))}

cf_records = []
with_box_total = 0

for r in cap200:
    js = r.get("json")
    if not isinstance(js, dict):
        continue
    boxes = boxes_from_json(js)
    if not boxes:
        continue
    with_box_total += 1

    # r["idx"] should point into ds_test. Records without it are skipped: falling
    # back to the file position could silently pair a prediction with the wrong image.
    idx = r.get("idx", None)
    if idx is None:
        continue

    item = items200.get(idx)
    if item is None:
        continue

    img, q = item["image"], item["question"]
    w, h = img.size

    # skip non-diagnostic huge ROIs
    if box_area_frac(boxes, w, h) > 0.70:
        continue

    # shrink + inpaint (no extra blur)
    boxes_s = shrink_boxes(boxes, w, h, frac=0.2)  # try 0.3 if still coarse
    img_cf  = inpaint_boxes(img, boxes_s)

    # Re-ask the same question on the edited image.
    raw_cf = call_vlm(img_cf, prompt_caption_reason_answer(q))
    js_cf  = parse_json_safe(raw_cf)
    ans0   = answer_from_json(js)
    ans_cf = answer_from_json(js_cf)
    # "faithful" = both answers are non-empty and removing the ROI changed the answer.
    faithful = (bool(ans0) and bool(ans_cf) and (ans_cf != ans0))

    r.update({"json_cf": js_cf, "answer_cf": ans_cf, "faithful": faithful})
    cf_records.append(r)

# Share of box-bearing records that were actually re-queried, and their flip rate.
cf_cov   = len(cf_records) / max(1, with_box_total)
flip_roi = sum(1 for r in cf_records if r.get("faithful") is True) / max(1, len(cf_records))

print(f"CF coverage (200, box-cases): {cf_cov:.2f}")
print(f"ROI-CF flip rate (200): {flip_roi:.2f}")

In [None]:
import numpy as np
import torch  # already required by call_vlm

# Seed the box sampler and the model's sampling RNG. call_vlm samples by default
# (temperature=0.2), so without the torch seed the flip rate is not reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Ensure cap200 and items200 exist (SLAKE versions)
try:
    cap200
except NameError:
    import json
    # `with` closes the file handle; blank lines are skipped.
    with open(f"{OUT_DIR}/slake_cap.jsonl", 'r', encoding='utf-8') as fh:
        cap200 = [json.loads(line) for line in fh if line.strip()]

try:
    items200
except NameError:
    # Map the first 200 test items for quick lookup; bounded by the split size so a
    # shorter split does not raise IndexError.
    items200 = {i: ds_test[i] for i in range(min(200, len(ds_test)))}

def random_box_like(boxes, w, h):
    """Draw a random box with the same size as the first box in ``boxes``.

    The top-left corner comes from NumPy's global RNG (x first, then y) so that
    the box fits inside a ``w`` x ``h`` image whenever the box is smaller than
    the image. Returns a one-element list of ``[x1, y1, x2, y2]``, or ``[]``
    when ``boxes`` is empty.
    """
    if not boxes:
        return []
    x1, y1, x2, y2 = [int(v) for v in boxes[0]]
    box_w = max(1, x2 - x1)
    box_h = max(1, y2 - y1)
    # Upper bounds are exclusive; max(1, ...) keeps them valid for oversized boxes.
    left = np.random.randint(0, max(1, w - box_w))
    top = np.random.randint(0, max(1, h - box_h))
    return [[left, top, left + box_w, top + box_h]]

# Random-region control: erase a same-sized box at a random location and count how
# often the answer changes, mirroring the filters and the intervention of ROI-CF.
rand_flips = 0
rand_total = 0
for r in cap200:  # cap200 must be loaded at this point
    js = r.get("json")
    if not isinstance(js, dict) or not boxes_from_json(js):
        continue

    # Map the record back to its test item through items200.
    idx = r.get("idx", None)
    item = None if idx is None else items200.get(idx)
    if item is None:
        continue

    img = item["image"]
    q = item["question"]
    w, h = img.size
    boxes = boxes_from_json(js)

    # Same skip rule as ROI-CF: ignore boxes that cover most of the image.
    if box_area_frac(boxes, w, h) > 0.70:
        continue

    # Same shrink as ROI-CF so the random box has a comparable size.
    boxes_s = shrink_boxes(boxes, w, h, frac=0.2)
    rboxes  = random_box_like(boxes_s, w, h)

    # Same intervention as ROI-CF: inpaint, then re-ask the question.
    img_cf = inpaint_boxes(img, rboxes)
    raw_cf = call_vlm(img_cf, prompt_caption_reason_answer(q))
    js_cf  = parse_json_safe(raw_cf)

    ans0   = answer_from_json(js)
    ans_cf = answer_from_json(js_cf)

    # Only pairs with two non-empty answers are counted.
    if not (bool(ans0) and bool(ans_cf)):
        continue
    rand_total += 1
    if ans_cf != ans0:
        rand_flips += 1

rand_flip_rate = rand_flips / max(1, rand_total)
print("Random-CF flip rate (200):", f"{rand_flip_rate:.2f}")