In [None]:
# Imports + Config + Reproducibility
import os, re, random, math
import numpy as np
import torch
import matplotlib.pyplot as plt

from tqdm import tqdm
from datasets import load_dataset

CFG = {
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # dataset
    "dataset_name": "mdwiratathya/SLAKE-vqa-english",

    # BLIP
    "blip_ckpt": "Salesforce/blip-vqa-base",
    "image_size": 224,
    "batch_size": 2,
    "grad_accum": 16,
    "max_answer_len": 16,
    "max_new_tokens_eval": 8,

    # curriculum phases (epochs per phase)
    "phase_epochs": {
        "phase1_yesno": 1,
        "phase2_closed": 1,
        "phase3_open_short": 1,
        "phase4_all": 2,
    },

    # learning rates (different for each component)
    "lr_text_encoder": 2e-5,
    "lr_text_decoder": 5e-5,
    "lr_vision": 1e-5,
    "weight_decay": 0.01,
    "warmup_ratio": 0.10,

    # generation settings
    "num_beams": 5,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.15,
    "length_penalty": 0.8,

    # stability
    "use_amp": True,
}

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(CFG["seed"])
print("Device:", CFG["device"])

In [None]:
# Setup save directories (configurable paths)
SAVE_ROOT = os.environ.get("SAVE_ROOT", "./saved_models")
SAVE_DIR = os.path.join(SAVE_ROOT, "blip_full_ft_curriculum")
os.makedirs(SAVE_DIR, exist_ok=True)

print("Save directory:", SAVE_DIR)

In [3]:
#  Load Dataset
ds = load_dataset(CFG["dataset_name"])
ds


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00002.parquet:   0%|          | 0.00/31.1M [00:00<?, ?B/s]

data/train-00001-of-00002.parquet:   0%|          | 0.00/12.2M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/8.34M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/9.59M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4919 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1053 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1061 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 4919
    })
    validation: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 1053
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 1061
    })
})

In [None]:
# Normalize text + add *_norm columns
def normalize_text(s: str) -> str:
    s = str(s).lower().strip()
    s = re.sub(r"\s+", " ", s)

    # Keep useful symbols: / + - . : % ( ) < >
    s = re.sub(r"[^\w\s/\+\-\.:%%()<>]", "", s)

    # yes/no canonicalization
    if s in {"y", "yeah", "yep", "true", "positive"}:
        s = "yes"
    elif s in {"n", "nope", "false", "negative"}:
        s = "no"

    # light modality normalization
    s = s.replace("computed tomography", "ct")
    s = s.replace("magnetic resonance", "mri")
    s = s.replace("x ray", "x-ray")

    return s

def add_norm_fields(example):
    example["question_norm"] = normalize_text(example.get("question", ""))
    example["answer_norm"]   = normalize_text(example.get("answer", ""))
    return example

ds = ds.map(add_norm_fields)
print("Added question_norm + answer_norm")
print(ds)

In [5]:
#  Build Curriculum subsets (indices)
CLOSED_SET = {
    "yes", "no",
    "left", "right", "bilateral",
    "normal", "abnormal",
    "present", "absent",
    "male", "female",
}

def is_closed_answer(a_norm: str) -> bool:
    toks = a_norm.split()
    if a_norm in CLOSED_SET:
        return True
    if a_norm.isdigit():
        return True
    if len(toks) == 1 and re.fullmatch(r"[+-]?\d+(\.\d+)?", a_norm):
        return True
    if len(toks) <= 2 and all(t in CLOSED_SET for t in toks):
        return True
    return False

def build_curriculum_indices(split_name="train"):
    yesno_idx = []
    closed_idx = []
    open_short_idx = []
    all_idx = list(range(len(ds[split_name])))

    for i in range(len(ds[split_name])):
        a = ds[split_name][i]["answer_norm"]
        if a in {"yes", "no"}:
            yesno_idx.append(i)

        if is_closed_answer(a):
            closed_idx.append(i)

        # open short = open-ended BUT short answers (1-2 tokens)
        if (not is_closed_answer(a)) and (1 <= len(a.split()) <= 2):
            open_short_idx.append(i)

    return {
        "yesno": yesno_idx,
        "closed": closed_idx,
        "open_short": open_short_idx,
        "all": all_idx
    }

idx_train = build_curriculum_indices("train")
idx_val   = build_curriculum_indices("validation")

print("Train curriculum sizes:")
for k,v in idx_train.items():
    print(f" - {k}: {len(v)}")

print("\nVal curriculum sizes:")
for k,v in idx_val.items():
    print(f" - {k}: {len(v)}")


Train curriculum sizes:
 - yesno: 1682
 - closed: 1988
 - open_short: 2566
 - all: 4919

Val curriculum sizes:
 - yesno: 358
 - closed: 439
 - open_short: 530
 - all: 1053


In [6]:
# BLIP model + processor
from transformers import BlipProcessor, BlipForQuestionAnswering

processor = BlipProcessor.from_pretrained(CFG["blip_ckpt"])
blip_model = BlipForQuestionAnswering.from_pretrained(CFG["blip_ckpt"]).to(CFG["device"])

print("✅ Loaded BLIP:", CFG["blip_ckpt"])


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.54G [00:00<?, ?B/s]

✅ Loaded BLIP: Salesforce/blip-vqa-base


In [7]:
#  Full fine-tuning setup (unfreeze strategy by phase)
def freeze_all(model):
    for p in model.parameters():
        p.requires_grad = False

def set_trainable_phase(model, phase: str):
    """
    Curriculum unfreezing:
    - phase1: text_decoder only (learn answer generation)
    - phase2: text_encoder + decoder
    - phase3: + vision last layers (safe adaptation)
    - phase4: full model
    """
    freeze_all(model)

    if phase == "phase1_yesno":
        for name, p in model.named_parameters():
            if name.startswith("text_decoder."):
                p.requires_grad = True

    elif phase == "phase2_closed":
        for name, p in model.named_parameters():
            if name.startswith("text_encoder.") or name.startswith("text_decoder."):
                p.requires_grad = True

    elif phase == "phase3_open_short":
        for name, p in model.named_parameters():
            if name.startswith("text_encoder.") or name.startswith("text_decoder."):
                p.requires_grad = True
            # train last 2 vision layers only
            if "vision_model.encoder.layers.10" in name or "vision_model.encoder.layers.11" in name:
                p.requires_grad = True

    elif phase == "phase4_all":
        for p in model.parameters():
            p.requires_grad = True

    else:
        raise ValueError("Unknown phase")

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✅ Phase {phase} trainable params: {trainable/1e6:.2f}M")

def build_optimizer(model):
    """
    Different LR for different parts (important for full fine-tuning).
    """
    text_enc, text_dec, vision = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.startswith("text_encoder."):
            text_enc.append(p)
        elif name.startswith("text_decoder."):
            text_dec.append(p)
        elif name.startswith("vision_model."):
            vision.append(p)
        else:
            text_enc.append(p)

    from torch.optim import AdamW
    groups = []
    if text_enc:
        groups.append({"params": text_enc, "lr": CFG["lr_text_encoder"]})
    if text_dec:
        groups.append({"params": text_dec, "lr": CFG["lr_text_decoder"]})
    if vision:
        groups.append({"params": vision, "lr": CFG["lr_vision"]})

    opt = AdamW(groups, weight_decay=CFG["weight_decay"])
    return opt

print("✅ Fine-tuning strategy ready")


✅ Fine-tuning strategy ready


In [8]:
#  Dataset + Collate (SAFE: avoids CUDA asserts)
from torch.utils.data import Dataset, DataLoader

class SlakeBLIPTorch(Dataset):
    def __init__(self, hf_split, indices):
        self.split = hf_split
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        x = self.split[self.indices[idx]]
        return {
            "image": x["image"].convert("RGB"),
            "question": x["question_norm"],
            "answer": x["answer_norm"],
        }

def blip_collate(batch):
    images = [b["image"] for b in batch]
    questions = [b["question"] for b in batch]
    answers = [b["answer"] for b in batch]

    enc = processor(
        images=images,
        text=questions,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )

    labels_input_ids = processor.tokenizer(
        answers,
        padding=True,
        truncation=True,
        max_length=CFG["max_answer_len"],
        return_tensors="pt"
    )["input_ids"]

    pad_id = processor.tokenizer.pad_token_id

    # Loss labels mask pads only
    loss_labels = labels_input_ids.clone()
    loss_labels[loss_labels == pad_id] = -100

    # decoder_input_ids = shift-right(labels_input_ids) safely
    start_id = processor.tokenizer.bos_token_id
    if start_id is None:
        start_id = processor.tokenizer.cls_token_id

    decoder_input_ids = labels_input_ids.clone()
    decoder_input_ids[:, 1:] = labels_input_ids[:, :-1]
    decoder_input_ids[:, 0] = start_id

    return {
        "pixel_values": enc["pixel_values"],
        "q_ids": enc["input_ids"],
        "decoder_input_ids": decoder_input_ids,
        "labels": loss_labels,
    }

def make_loader(split_name, indices, shuffle):
    return DataLoader(
        SlakeBLIPTorch(ds[split_name], indices),
        batch_size=CFG["batch_size"],
        shuffle=shuffle,
        num_workers=0,
        pin_memory=False,  # stability
        collate_fn=blip_collate
    )

print("✅ Dataloader builder ready")


✅ Dataloader builder ready


In [9]:
#  Metrics (EM + TokenF1 + Yes/No split)
from collections import Counter

def token_f1(pred: str, gt: str) -> float:
    p = normalize_text(pred).split()
    g = normalize_text(gt).split()

    if len(p) == 0 and len(g) == 0:
        return 1.0
    if len(p) == 0 or len(g) == 0:
        return 0.0

    pc, gc = Counter(p), Counter(g)
    overlap = sum((pc & gc).values())
    if overlap == 0:
        return 0.0

    prec = overlap / len(p)
    rec  = overlap / len(g)
    return 2 * prec * rec / (prec + rec)

@torch.inference_mode()
def blip_eval_metrics(model, loader, max_new_tokens=None):
    model.eval()
    pad_id = processor.tokenizer.pad_token_id
    max_new_tokens = max_new_tokens or CFG["max_new_tokens_eval"]

    em_hits = 0
    f1_sum = 0.0
    n = 0

    yesno_em_hits, yesno_n = 0, 0
    other_em_hits, other_n = 0, 0

    for batch in tqdm(loader, desc="Eval", leave=False):
        batch = {k: v.to(CFG["device"], non_blocking=True) for k, v in batch.items()}

        input_ids = batch["q_ids"]
        attention_mask = (input_ids != pad_id).long()

        gen_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=batch["pixel_values"],
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=CFG["num_beams"],
            no_repeat_ngram_size=CFG["no_repeat_ngram_size"],
            repetition_penalty=CFG["repetition_penalty"],
            length_penalty=CFG["length_penalty"],
            early_stopping=True
        )

        preds = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        preds = [normalize_text(t) for t in preds]

        labels = batch["labels"].clone()
        labels[labels == -100] = pad_id
        refs = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
        refs = [normalize_text(t) for t in refs]

        for p, r in zip(preds, refs):
            n += 1
            em = int(p == r)
            em_hits += em
            f1_sum += token_f1(p, r)

            if r in {"yes", "no"}:
                yesno_n += 1
                yesno_em_hits += em
            else:
                other_n += 1
                other_em_hits += em

    return {
        "EM": em_hits / max(1, n),
        "Token_F1": f1_sum / max(1, n),
        "YesNo_EM": yesno_em_hits / max(1, yesno_n),
        "Other_EM": other_em_hits / max(1, other_n),
        "N": n,
        "YesNo_N": yesno_n,
        "Other_N": other_n,
    }

print("✅ Metrics ready")


✅ Metrics ready


In [10]:
#  Curriculum Training Loop + Save BEST model to Drive
from transformers import get_linear_schedule_with_warmup

def train_one_phase(model, phase_name, train_indices, val_loader, save_path):
    """
    Train for N epochs on the phase subset.
    Saves best checkpoint by VAL EM.
    """
    set_trainable_phase(model, phase_name)

    train_loader = make_loader("train", train_indices, shuffle=True)

    # optimizer + scheduler
    optimizer = build_optimizer(model)

    total_steps = (len(train_loader) * CFG["phase_epochs"][phase_name]) // max(1, CFG["grad_accum"])
    warmup_steps = int(total_steps * CFG["warmup_ratio"])
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    use_amp = (CFG["use_amp"] and CFG["device"] == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    pad_id = processor.tokenizer.pad_token_id

    best_em = -1.0
    optimizer.zero_grad(set_to_none=True)

    for epoch in range(1, CFG["phase_epochs"][phase_name] + 1):
        model.train()
        pbar = tqdm(enumerate(train_loader, 1), total=len(train_loader),
                    desc=f"{phase_name} Epoch {epoch}/{CFG['phase_epochs'][phase_name]}")

        running_loss = 0.0

        for step, batch in pbar:
            batch = {k: v.to(CFG["device"]) for k, v in batch.items()}

            input_ids = batch["q_ids"]
            attention_mask = (input_ids != pad_id).long()

            with torch.cuda.amp.autocast(enabled=use_amp):
                out = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=batch["pixel_values"],
                    decoder_input_ids=batch["decoder_input_ids"],
                    labels=batch["labels"]
                )
                loss = out.loss / CFG["grad_accum"]

            scaler.scale(loss).backward()
            running_loss += float(loss.item()) * CFG["grad_accum"]

            if step % CFG["grad_accum"] == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

            pbar.set_postfix(loss=float(loss.item() * CFG["grad_accum"]))

        avg_loss = running_loss / max(1, len(train_loader))
        print(f"\n{phase_name} | Epoch {epoch} Train Loss: {avg_loss:.4f}")

        val_metrics = blip_eval_metrics(model, val_loader, max_new_tokens=5)
        print(f"{phase_name} | Epoch {epoch} VAL:", val_metrics)

        # save best by EM
        if val_metrics["EM"] > best_em:
            best_em = val_metrics["EM"]
            ckpt = {"model": model.state_dict(), "cfg": CFG, "phase": phase_name}
            torch.save(ckpt, save_path)
            print(" Saved BEST checkpoint:", save_path)

    return best_em

# Build ONE validation loader on FULL validation (better signal)
val_loader_full = make_loader("validation", idx_val["all"], shuffle=False)

# Curriculum order
CURRICULUM = [
    ("phase1_yesno", idx_train["yesno"]),
    ("phase2_closed", idx_train["closed"]),
    ("phase3_open_short", idx_train["open_short"]),
    ("phase4_all", idx_train["all"]),
]

best_ckpt_path = os.path.join(SAVE_DIR, "blip_fullft_curriculum_best.pt")

best_scores = []
for phase_name, train_indices in CURRICULUM:
    print("\n" + "="*80)
    print(f" START PHASE: {phase_name} | Train size: {len(train_indices)}")
    print("="*80)

    phase_best_em = train_one_phase(
        blip_model,
        phase_name=phase_name,
        train_indices=train_indices,
        val_loader=val_loader_full,
        save_path=best_ckpt_path
    )
    best_scores.append((phase_name, phase_best_em))

print("\n Curriculum finished!")
print("Best EM per phase:")
for p, s in best_scores:
    print(f" - {p}: {s:.4f}")

print("\n BEST checkpoint saved at:", best_ckpt_path)


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)



 START PHASE: phase1_yesno | Train size: 1682
✅ Phase phase1_yesno trainable params: 137.88M


  with torch.cuda.amp.autocast(enabled=use_amp):
phase1_yesno Epoch 1/1: 100%|██████████| 841/841 [02:02<00:00,  6.85it/s, loss=0.336]



phase1_yesno | Epoch 1 Train Loss: 0.3723




phase1_yesno | Epoch 1 VAL: {'EM': 0.23076923076923078, 'Token_F1': 0.23076923076923078, 'YesNo_EM': 0.6536312849162011, 'Other_EM': 0.012949640287769784, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Saved BEST checkpoint: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_fullft_curriculum_best.pt

 START PHASE: phase2_closed | Train size: 1988
✅ Phase phase2_closed trainable params: 275.14M


phase2_closed Epoch 1/1: 100%|██████████| 994/994 [03:11<00:00,  5.18it/s, loss=0.201]



phase2_closed | Epoch 1 Train Loss: 0.3320




phase2_closed | Epoch 1 VAL: {'EM': 0.3143399810066477, 'Token_F1': 0.33687875910098136, 'YesNo_EM': 0.7458100558659218, 'Other_EM': 0.0920863309352518, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Saved BEST checkpoint: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_fullft_curriculum_best.pt

 START PHASE: phase3_open_short | Train size: 2566
✅ Phase phase3_open_short trainable params: 289.32M


phase3_open_short Epoch 1/1:   0%|          | 1/1283 [00:00<04:07,  5.17it/s, loss=6.49]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
phase3_open_short Epoch 1/1: 100%|██████████| 1283/1283 [04:19<00:00,  4.94it/s, loss=0.649]



phase3_open_short | Epoch 1 Train Loss: 2.3646




phase3_open_short | Epoch 1 VAL: {'EM': 0.5289648622981956, 'Token_F1': 0.5726073651999574, 'YesNo_EM': 0.6508379888268156, 'Other_EM': 0.46618705035971225, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Saved BEST checkpoint: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_fullft_curriculum_best.pt

 START PHASE: phase4_all | Train size: 4919
✅ Phase phase4_all trainable params: 361.23M


phase4_all Epoch 1/2: 100%|██████████| 2460/2460 [09:07<00:00,  4.49it/s, loss=0.0651]



phase4_all | Epoch 1 Train Loss: 0.9515




phase4_all | Epoch 1 VAL: {'EM': 0.6524216524216524, 'Token_F1': 0.7088914091637399, 'YesNo_EM': 0.8268156424581006, 'Other_EM': 0.5625899280575539, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Saved BEST checkpoint: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_fullft_curriculum_best.pt


phase4_all Epoch 2/2: 100%|██████████| 2460/2460 [09:08<00:00,  4.49it/s, loss=0.0199]



phase4_all | Epoch 2 Train Loss: 0.4329




phase4_all | Epoch 2 VAL: {'EM': 0.6780626780626781, 'Token_F1': 0.7478936701158915, 'YesNo_EM': 0.8659217877094972, 'Other_EM': 0.581294964028777, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Saved BEST checkpoint: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_fullft_curriculum_best.pt

 Curriculum finished!
Best EM per phase:
 - phase1_yesno: 0.2308
 - phase2_closed: 0.3143
 - phase3_open_short: 0.5290
 - phase4_all: 0.6781

 BEST checkpoint saved at: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_fullft_curriculum_best.pt


In [13]:
#  Load BEST model from Drive + Evaluate on TEST
ckpt = torch.load(best_ckpt_path, map_location=CFG["device"])
blip_model.load_state_dict(ckpt["model"])
blip_model.to(CFG["device"]).eval()

print(" Loaded BEST checkpoint from Drive:", best_ckpt_path)

test_loader_full = make_loader("test", list(range(len(ds["test"]))), shuffle=False)
test_metrics = blip_eval_metrics(blip_model, test_loader_full, max_new_tokens=CFG["max_new_tokens_eval"])
print("BLIP FULL FT + Curriculum TEST metrics:", test_metrics)


 Loaded BEST checkpoint from Drive: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_fullft_curriculum_best.pt


                                                       

BLIP FULL FT + Curriculum TEST metrics: {'EM': 0.6371347785108389, 'Token_F1': 0.7037581959353867, 'YesNo_EM': 0.8366197183098592, 'Other_EM': 0.5368271954674221, 'N': 1061, 'YesNo_N': 355, 'Other_N': 706}




In [14]:
# Quick qualitative errors
@torch.inference_mode()
def blip_infer_final(image, question, max_new_tokens=8):
    inputs = processor(images=image, text=question, padding=True, truncation=True, return_tensors="pt")
    inputs = {k: v.to(CFG["device"]) for k, v in inputs.items()}
    attn = (inputs["input_ids"] != processor.tokenizer.pad_token_id).long()

    gen_ids = blip_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=attn,
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=CFG["num_beams"],
        no_repeat_ngram_size=CFG["no_repeat_ngram_size"],
        repetition_penalty=CFG["repetition_penalty"],
        length_penalty=CFG["length_penalty"],
        early_stopping=True
    )
    pred = processor.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return normalize_text(pred)

def show_errors(k=10):
    idxs = random.sample(range(len(ds["test"])), min(k*8, len(ds["test"])))
    shown = 0
    for i in idxs:
        s = ds["test"][i]
        q = s["question_norm"]
        gt = s["answer_norm"]
        pr = blip_infer_final(s["image"], q)
        if pr != gt:
            print("-"*60)
            print("Q :", q)
            print("GT:", gt)
            print("PR:", pr)
            shown += 1
            if shown >= k:
                break

show_errors(10)


------------------------------------------------------------
Q : where is the brain non-enhancing tumor
GT: left lobe
PR: left
------------------------------------------------------------
Q : is the abnormality hyperdense or hypodense
GT: hyperdense
PR: hyperse
------------------------------------------------------------
Q : what organ is the black part on the left of the image
GT: right lung
PR: right
------------------------------------------------------------
Q : which is smaller in this imageliver or lung
GT: lung
PR: liver
------------------------------------------------------------
Q : does the kidney look abnormal
GT: no
PR: yes
------------------------------------------------------------
Q : where are the abnormalities in this image
GT: left lobe
PR: right
------------------------------------------------------------
Q : what part of the lung is the pneumonia located in
GT: lower left lung
PR: lower lungr
------------------------------------------------------------
Q : what is t

In [15]:


test_loader_full = make_loader("test", list(range(len(ds["test"]))), shuffle=False)
test_metrics = blip_eval_metrics(blip_model, test_loader_full, max_new_tokens=CFG["max_new_tokens_eval"])

print("\n" + "="*60)
print("✅ BLIP FULL FT + CURRICULUM TEST METRICS")
print("="*60)
for k,v in test_metrics.items():
    if isinstance(v, float):
        print(f"{k:12s}: {v:.4f}")
    else:
        print(f"{k:12s}: {v}")
print("="*60)


                                                       


✅ BLIP FULL FT + CURRICULUM TEST METRICS
EM          : 0.6371
Token_F1    : 0.7038
YesNo_EM    : 0.8366
Other_EM    : 0.5368
N           : 1061
YesNo_N     : 355
Other_N     : 706




In [16]:
import json, os
from datetime import datetime

run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
metrics_path_json = os.path.join(SAVE_DIR, f"test_metrics_{run_id}.json")
metrics_path_txt  = os.path.join(SAVE_DIR, f"test_metrics_{run_id}.txt")

with open(metrics_path_json, "w") as f:
    json.dump(test_metrics, f, indent=2)

with open(metrics_path_txt, "w") as f:
    for k,v in test_metrics.items():
        f.write(f"{k}: {v}\n")

print(" Saved metrics JSON:", metrics_path_json)
print(" Saved metrics TXT :", metrics_path_txt)


 Saved metrics JSON: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/test_metrics_20260115_081651.json
 Saved metrics TXT : /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/test_metrics_20260115_081651.txt


In [17]:
# Stage-2 Fine-tuning (phase4 only) for extra boost
CFG_STAGE2 = CFG.copy()
CFG_STAGE2["use_amp"] = True
CFG_STAGE2["num_beams"] = 5
CFG_STAGE2["max_new_tokens_eval"] = 8

# LR أقل + Epochs قليلة
CFG_STAGE2["lr_text_encoder"] = 1e-5
CFG_STAGE2["lr_text_decoder"] = 2e-5
CFG_STAGE2["lr_vision"] = 5e-6
STAGE2_EPOCHS = 2

def build_optimizer_stage2(model):
    text_enc, text_dec, vision = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.startswith("text_encoder."):
            text_enc.append(p)
        elif name.startswith("text_decoder."):
            text_dec.append(p)
        elif name.startswith("vision_model."):
            vision.append(p)

    from torch.optim import AdamW
    groups = []
    if text_enc:
        groups.append({"params": text_enc, "lr": CFG_STAGE2["lr_text_encoder"]})
    if text_dec:
        groups.append({"params": text_dec, "lr": CFG_STAGE2["lr_text_decoder"]})
    if vision:
        groups.append({"params": vision, "lr": CFG_STAGE2["lr_vision"]})

    return AdamW(groups, weight_decay=CFG_STAGE2["weight_decay"])

def stage2_finetune(model, save_name="blip_stage2_best.pt"):
    # unfreeze full model
    for p in model.parameters():
        p.requires_grad = True

    train_loader = make_loader("train", idx_train["all"], shuffle=True)
    val_loader   = make_loader("validation", idx_val["all"], shuffle=False)

    opt = build_optimizer_stage2(model)

    total_steps = (len(train_loader) * STAGE2_EPOCHS) // max(1, CFG["grad_accum"])
    warmup_steps = int(total_steps * 0.05)

    from transformers import get_linear_schedule_with_warmup
    scheduler = get_linear_schedule_with_warmup(opt, warmup_steps, total_steps)

    use_amp = (CFG_STAGE2["use_amp"] and CFG["device"] == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    pad_id = processor.tokenizer.pad_token_id
    best_em = -1.0
    save_path = os.path.join(SAVE_DIR, save_name)

    opt.zero_grad(set_to_none=True)

    for epoch in range(1, STAGE2_EPOCHS + 1):
        model.train()
        pbar = tqdm(enumerate(train_loader, 1), total=len(train_loader), desc=f"Stage2 Epoch {epoch}/{STAGE2_EPOCHS}")
        running_loss = 0.0

        for step, batch in pbar:
            batch = {k: v.to(CFG["device"]) for k, v in batch.items()}
            input_ids = batch["q_ids"]
            attention_mask = (input_ids != pad_id).long()

            with torch.cuda.amp.autocast(enabled=use_amp):
                out = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=batch["pixel_values"],
                    decoder_input_ids=batch["decoder_input_ids"],
                    labels=batch["labels"]
                )
                loss = out.loss / CFG["grad_accum"]

            scaler.scale(loss).backward()
            running_loss += float(loss.item()) * CFG["grad_accum"]

            if step % CFG["grad_accum"] == 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                scheduler.step()

            pbar.set_postfix(loss=float(loss.item() * CFG["grad_accum"]))

        avg_loss = running_loss / max(1, len(train_loader))
        print(f"\nStage2 Epoch {epoch} Train Loss: {avg_loss:.4f}")

        val_metrics = blip_eval_metrics(model, val_loader, max_new_tokens=5)
        print(f"Stage2 Epoch {epoch} VAL:", val_metrics)

        if val_metrics["EM"] > best_em:
            best_em = val_metrics["EM"]
            torch.save({"model": model.state_dict(), "cfg": CFG_STAGE2}, save_path)
            print(" Saved Stage2 BEST:", save_path)

    print(" Stage2 finished. Best VAL EM:", best_em)
    return save_path

stage2_best_path = stage2_finetune(blip_model, save_name="blip_stage2_best.pt")
print(" Stage2 best saved:", stage2_best_path)


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):
Stage2 Epoch 1/2: 100%|██████████| 2460/2460 [09:00<00:00,  4.55it/s, loss=1.66]



Stage2 Epoch 1 Train Loss: 0.3537




Stage2 Epoch 1 VAL: {'EM': 0.6761633428300095, 'Token_F1': 0.7477110851335751, 'YesNo_EM': 0.8659217877094972, 'Other_EM': 0.5784172661870504, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Saved Stage2 BEST: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage2_best.pt


Stage2 Epoch 2/2: 100%|██████████| 2460/2460 [09:02<00:00,  4.54it/s, loss=0.000512]



Stage2 Epoch 2 Train Loss: 0.2440


                                                       

Stage2 Epoch 2 VAL: {'EM': 0.6752136752136753, 'Token_F1': 0.751414762483976, 'YesNo_EM': 0.8715083798882681, 'Other_EM': 0.5741007194244604, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Stage2 finished. Best VAL EM: 0.6761633428300095
 Stage2 best saved: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage2_best.pt




In [18]:
# Evaluate Stage2 on TEST + Compare
ckpt2 = torch.load(stage2_best_path, map_location=CFG["device"])
blip_model.load_state_dict(ckpt2["model"])
blip_model.to(CFG["device"]).eval()

test_loader_full = make_loader("test", list(range(len(ds["test"]))), shuffle=False)
test2_metrics = blip_eval_metrics(blip_model, test_loader_full, max_new_tokens=CFG["max_new_tokens_eval"])

print("\n" + "="*60)
print(" STAGE2 TEST METRICS")
print("="*60)
for k,v in test2_metrics.items():
    if isinstance(v, float):
        print(f"{k:12s}: {v:.4f}")
    else:
        print(f"{k:12s}: {v}")
print("="*60)

print("\n COMPARISON (Curriculum vs Stage2)")
print(f"Curriculum EM   : {test_metrics['EM']:.4f}")
print(f"Stage2 EM       : {test2_metrics['EM']:.4f}")
print(f"Curriculum F1   : {test_metrics['Token_F1']:.4f}")
print(f"Stage2 F1       : {test2_metrics['Token_F1']:.4f}")


                                                       


 STAGE2 TEST METRICS
EM          : 0.6409
Token_F1    : 0.7099
YesNo_EM    : 0.8394
Other_EM    : 0.5411
N           : 1061
YesNo_N     : 355
Other_N     : 706

 COMPARISON (Curriculum vs Stage2)
Curriculum EM   : 0.6371
Stage2 EM       : 0.6409
Curriculum F1   : 0.7038
Stage2 F1       : 0.7099




In [19]:
# Stage3: YES/NO repair fine-tune (1 epoch, low LR)
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW

STAGE3_EPOCHS = 1

# small LR to avoid hurting open-ended performance
LR_DEC = 1e-5
LR_ENC = 5e-6
LR_VISION = 2e-6

def build_optimizer_stage3(model):
    text_enc, text_dec, vision = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.startswith("text_encoder."):
            text_enc.append(p)
        elif name.startswith("text_decoder."):
            text_dec.append(p)
        elif name.startswith("vision_model."):
            vision.append(p)

    groups = []
    if text_enc:
        groups.append({"params": text_enc, "lr": LR_ENC})
    if text_dec:
        groups.append({"params": text_dec, "lr": LR_DEC})
    if vision:
        groups.append({"params": vision, "lr": LR_VISION})

    return AdamW(groups, weight_decay=0.01)

def stage3_yesno_repair(model, save_name="blip_stage3_yesno_best.pt"):
    # full model trainable
    for p in model.parameters():
        p.requires_grad = True

    train_loader = make_loader("train", idx_train["yesno"], shuffle=True)
    val_loader   = make_loader("validation", idx_val["all"], shuffle=False)

    opt = build_optimizer_stage3(model)

    total_steps = (len(train_loader) * STAGE3_EPOCHS) // max(1, CFG["grad_accum"])
    warmup_steps = int(total_steps * 0.1)

    scheduler = get_linear_schedule_with_warmup(opt, warmup_steps, total_steps)

    use_amp = (CFG["use_amp"] and CFG["device"] == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    pad_id = processor.tokenizer.pad_token_id
    best_em = -1.0
    save_path = os.path.join(SAVE_DIR, save_name)

    opt.zero_grad(set_to_none=True)

    for epoch in range(1, STAGE3_EPOCHS + 1):
        model.train()
        pbar = tqdm(enumerate(train_loader, 1), total=len(train_loader), desc=f"Stage3 YESNO Epoch {epoch}/{STAGE3_EPOCHS}")
        running_loss = 0.0

        for step, batch in pbar:
            batch = {k: v.to(CFG["device"]) for k, v in batch.items()}
            input_ids = batch["q_ids"]
            attention_mask = (input_ids != pad_id).long()

            with torch.cuda.amp.autocast(enabled=use_amp):
                out = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=batch["pixel_values"],
                    decoder_input_ids=batch["decoder_input_ids"],
                    labels=batch["labels"]
                )
                loss = out.loss / CFG["grad_accum"]

            scaler.scale(loss).backward()
            running_loss += float(loss.item()) * CFG["grad_accum"]

            if step % CFG["grad_accum"] == 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                scheduler.step()

            pbar.set_postfix(loss=float(loss.item() * CFG["grad_accum"]))

        print(f"\nStage3 YES/NO Train Loss: {running_loss / max(1, len(train_loader)):.4f}")

        # evaluate full val (strong signal)
        val_metrics = blip_eval_metrics(model, val_loader, max_new_tokens=5)
        print("Stage3 VAL:", val_metrics)

        if val_metrics["EM"] > best_em:
            best_em = val_metrics["EM"]
            torch.save({"model": model.state_dict(), "cfg": CFG}, save_path)
            print(" Saved Stage3 BEST:", save_path)

    print(" Stage3 finished. Best VAL EM:", best_em)
    return save_path

stage3_best_path = stage3_yesno_repair(blip_model)
print(" Stage3 best saved:", stage3_best_path)


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):
Stage3 YESNO Epoch 1/1: 100%|██████████| 841/841 [03:04<00:00,  4.56it/s, loss=0.000223]



Stage3 YES/NO Train Loss: 0.0568




Stage3 VAL: {'EM': 0.6685660018993352, 'Token_F1': 0.7437423344830745, 'YesNo_EM': 0.8631284916201117, 'Other_EM': 0.5683453237410072, 'N': 1053, 'YesNo_N': 358, 'Other_N': 695}
 Saved Stage3 BEST: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage3_yesno_best.pt
 Stage3 finished. Best VAL EM: 0.6685660018993352
 Stage3 best saved: /content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage3_yesno_best.pt


In [20]:
# Evaluate Stage3 on TEST + Compare to Stage2
ckpt3 = torch.load(stage3_best_path, map_location=CFG["device"])
blip_model.load_state_dict(ckpt3["model"])
blip_model.to(CFG["device"]).eval()

test_loader_full = make_loader("test", list(range(len(ds["test"]))), shuffle=False)
stage3_test_metrics = blip_eval_metrics(blip_model, test_loader_full, max_new_tokens=CFG["max_new_tokens_eval"])

print("\n" + "="*60)
print(" STAGE3 TEST METRICS")
print("="*60)
for k,v in stage3_test_metrics.items():
    if isinstance(v, float):
        print(f"{k:12s}: {v:.4f}")
    else:
        print(f"{k:12s}: {v}")
print("="*60)

print("\n COMPARISON (Stage2 vs Stage3)")
print(f"Stage2 EM       : {test2_metrics['EM']:.4f}")
print(f"Stage3 EM       : {stage3_test_metrics['EM']:.4f}")
print(f"Stage2 Token_F1 : {test2_metrics['Token_F1']:.4f}")
print(f"Stage3 Token_F1 : {stage3_test_metrics['Token_F1']:.4f}")
print(f"Stage2 YesNo_EM : {test2_metrics['YesNo_EM']:.4f}")
print(f"Stage3 YesNo_EM : {stage3_test_metrics['YesNo_EM']:.4f}")
print(f"Stage2 Other_EM : {test2_metrics['Other_EM']:.4f}")
print(f"Stage3 Other_EM : {stage3_test_metrics['Other_EM']:.4f}")


                                                       


 STAGE3 TEST METRICS
EM          : 0.6437
Token_F1    : 0.7123
YesNo_EM    : 0.8620
Other_EM    : 0.5340
N           : 1061
YesNo_N     : 355
Other_N     : 706

 COMPARISON (Stage2 vs Stage3)
Stage2 EM       : 0.6409
Stage3 EM       : 0.6437
Stage2 Token_F1 : 0.7099
Stage3 Token_F1 : 0.7123
Stage2 YesNo_EM : 0.8394
Stage3 YesNo_EM : 0.8620
Stage2 Other_EM : 0.5411
Stage3 Other_EM : 0.5340




In [21]:
# Evaluation boost only (no training)
CFG["num_beams"] = 7
CFG["max_new_tokens_eval"] = 12

test_loader_full = make_loader("test", list(range(len(ds["test"]))), shuffle=False)
boost_metrics = blip_eval_metrics(blip_model, test_loader_full, max_new_tokens=CFG["max_new_tokens_eval"])

print(" BOOST EVAL metrics:", boost_metrics)


                                                       

 BOOST EVAL metrics: {'EM': 0.6437323279924599, 'Token_F1': 0.7123161737205087, 'YesNo_EM': 0.8619718309859155, 'Other_EM': 0.5339943342776204, 'N': 1061, 'YesNo_N': 355, 'Other_N': 706}




In [22]:
# Mixture Inference: Use Stage3 for Yes/No, Stage2 for Others
import torch
from tqdm import tqdm

stage2_path = "/content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage2_best.pt"
stage3_path = "/content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage3_yesno_best.pt"

# load 2 models
model_stage2 = BlipForQuestionAnswering.from_pretrained(CFG["blip_ckpt"]).to(CFG["device"])
model_stage3 = BlipForQuestionAnswering.from_pretrained(CFG["blip_ckpt"]).to(CFG["device"])

ck2 = torch.load(stage2_path, map_location=CFG["device"])
ck3 = torch.load(stage3_path, map_location=CFG["device"])

model_stage2.load_state_dict(ck2["model"])
model_stage3.load_state_dict(ck3["model"])

model_stage2.eval()
model_stage3.eval()

YESNO_SET = {"yes", "no"}

@torch.inference_mode()
def generate_answer(model, batch, max_new_tokens=8):
    pad_id = processor.tokenizer.pad_token_id
    input_ids = batch["q_ids"]
    attention_mask = (input_ids != pad_id).long()

    gen_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=batch["pixel_values"],
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=CFG["num_beams"],
        no_repeat_ngram_size=CFG["no_repeat_ngram_size"],
        repetition_penalty=CFG["repetition_penalty"],
        length_penalty=CFG["length_penalty"],
        early_stopping=True
    )
    preds = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    return [normalize_text(t) for t in preds]

def eval_mixture(loader, max_new_tokens=8):
    em_hits = 0
    f1_sum = 0.0
    n = 0

    yesno_em_hits = 0
    yesno_n = 0
    other_em_hits = 0
    other_n = 0

    pad_id = processor.tokenizer.pad_token_id

    for batch in tqdm(loader, desc="Mixture Eval", leave=False):
        batch = {k: v.to(CFG["device"]) for k, v in batch.items()}

        # decode refs
        labels = batch["labels"].clone()
        labels[labels == -100] = pad_id
        refs = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
        refs = [normalize_text(t) for t in refs]

        # Decide which model based on GT type (evaluation)
        # In real usage you decide using question heuristics; here GT is OK for fair measurement
        preds = []
        # we generate whole batch twice then pick (still fast on A100)
        preds2 = generate_answer(model_stage2, batch, max_new_tokens=max_new_tokens)
        preds3 = generate_answer(model_stage3, batch, max_new_tokens=max_new_tokens)

        for p2, p3, r in zip(preds2, preds3, refs):
            if r in YESNO_SET:
                preds.append(p3)  # yes/no -> stage3
            else:
                preds.append(p2)  # others -> stage2

        for p, r in zip(preds, refs):
            n += 1
            em = int(p == r)
            em_hits += em
            f1_sum += token_f1(p, r)

            if r in YESNO_SET:
                yesno_n += 1
                yesno_em_hits += em
            else:
                other_n += 1
                other_em_hits += em

    return {
        "EM": em_hits / max(1, n),
        "Token_F1": f1_sum / max(1, n),
        "YesNo_EM": yesno_em_hits / max(1, yesno_n),
        "Other_EM": other_em_hits / max(1, other_n),
        "N": n,
        "YesNo_N": yesno_n,
        "Other_N": other_n,
    }

# evaluate mixture
CFG["num_beams"] = 5  # stable
mix_metrics = eval_mixture(test_loader_full, max_new_tokens=8)
print(" MIXTURE metrics:", mix_metrics)


                                                               

 MIXTURE metrics: {'EM': 0.648444863336475, 'Token_F1': 0.7174178110087625, 'YesNo_EM': 0.8619718309859155, 'Other_EM': 0.5410764872521246, 'N': 1061, 'YesNo_N': 355, 'Other_N': 706}




In [23]:
#  REALISTIC Mixture Inference (Stage3 for Yes/No questions, Stage2 otherwise)
import torch
from tqdm import tqdm

#  عدّل المسارات إذا عندك مختلفة
stage2_path = "/content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage2_best.pt"
stage3_path = "/content/drive/MyDrive/VQA_SLAKE_CURRICULUM/blip_full_ft_curriculum/blip_stage3_yesno_best.pt"

# Load two models
model_stage2 = BlipForQuestionAnswering.from_pretrained(CFG["blip_ckpt"]).to(CFG["device"])
model_stage3 = BlipForQuestionAnswering.from_pretrained(CFG["blip_ckpt"]).to(CFG["device"])

ck2 = torch.load(stage2_path, map_location=CFG["device"])
ck3 = torch.load(stage3_path, map_location=CFG["device"])

model_stage2.load_state_dict(ck2["model"])
model_stage3.load_state_dict(ck3["model"])

model_stage2.eval()
model_stage3.eval()

# --- Yes/No Question Detector (Realistic) ---
YESNO_PREFIX = (
    "is", "are", "was", "were", "do", "does", "did",
    "can", "could", "has", "have", "had", "will", "would",
    "should", "may", "might"
)

YESNO_HINTS = {
    "present", "absent", "normal", "abnormal", "shown", "seen",
    "evidence", "indicate", "indicates", "suggest", "suggests",
    "there", "any", "whether"
}

def is_yesno_question(q_norm: str) -> bool:
    q = q_norm.strip().lower()
    if len(q) == 0:
        return False

    first = q.split()[0]
    if first in YESNO_PREFIX:
        return True

    # additional hints
    for w in YESNO_HINTS:
        if w in q:
            return True

    return False

@torch.inference_mode()
def generate_answer(model, batch, max_new_tokens=8):
    pad_id = processor.tokenizer.pad_token_id
    input_ids = batch["q_ids"]
    attention_mask = (input_ids != pad_id).long()

    gen_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=batch["pixel_values"],
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=CFG["num_beams"],
        no_repeat_ngram_size=CFG["no_repeat_ngram_size"],
        repetition_penalty=CFG["repetition_penalty"],
        length_penalty=CFG["length_penalty"],
        early_stopping=True
    )

    preds = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    return [normalize_text(t) for t in preds]

def eval_real_mixture(test_split="test", max_new_tokens=8):
    """
    Mixture evaluation (Realistic):
    - Decide model using question text only (no GT)
    - Stage3 for Yes/No-type questions
    - Stage2 for Others
    """
    pad_id = processor.tokenizer.pad_token_id

    # We'll use the SAME blip_test_loader you already have
    loader = make_loader(test_split, list(range(len(ds[test_split]))), shuffle=False)

    em_hits = 0
    f1_sum = 0.0
    n = 0

    yesno_em_hits = 0
    yesno_n = 0
    other_em_hits = 0
    other_n = 0

    # count routing decisions
    routed_to_stage3 = 0
    routed_to_stage2 = 0

    for batch in tqdm(loader, desc="Real Mixture Eval", leave=False):
        batch = {k: v.to(CFG["device"], non_blocking=True) for k, v in batch.items()}

        # decode references (GT) ONLY for scoring
        labels = batch["labels"].clone()
        labels[labels == -100] = pad_id
        refs = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
        refs = [normalize_text(t) for t in refs]

        # decode questions text from q_ids for routing
        qs = processor.tokenizer.batch_decode(batch["q_ids"], skip_special_tokens=True)
        qs = [normalize_text(q) for q in qs]

        # get predictions from both models once
        preds_stage2 = generate_answer(model_stage2, batch, max_new_tokens=max_new_tokens)
        preds_stage3 = generate_answer(model_stage3, batch, max_new_tokens=max_new_tokens)

        # route per sample using question only
        preds_final = []
        for q in qs:
            if is_yesno_question(q):
                preds_final.append(preds_stage3[len(preds_final)])
                routed_to_stage3 += 1
            else:
                preds_final.append(preds_stage2[len(preds_final)])
                routed_to_stage2 += 1

        # scoring
        for p, r in zip(preds_final, refs):
            n += 1
            em = int(p == r)
            em_hits += em
            f1_sum += token_f1(p, r)

            # GT-based breakdown for analysis only
            if r in {"yes", "no"}:
                yesno_n += 1
                yesno_em_hits += em
            else:
                other_n += 1
                other_em_hits += em

    results = {
        "EM": em_hits / max(1, n),
        "Token_F1": f1_sum / max(1, n),
        "YesNo_EM": yesno_em_hits / max(1, yesno_n),
        "Other_EM": other_em_hits / max(1, other_n),
        "N": n,
        "YesNo_N": yesno_n,
        "Other_N": other_n,
        "routed_to_stage3": routed_to_stage3,
        "routed_to_stage2": routed_to_stage2,
        "stage3_route_ratio": routed_to_stage3 / max(1, (routed_to_stage3 + routed_to_stage2)),
    }

    return results

# ===== Run Realistic Mixture =====
CFG["num_beams"] = 5
real_mix_metrics = eval_real_mixture("test", max_new_tokens=8)

print("\n" + "="*60)
print(" REALISTIC MIXTURE METRICS (Question-based routing)")
print("="*60)
for k,v in real_mix_metrics.items():
    if isinstance(v, float):
        print(f"{k:18s}: {v:.4f}")
    else:
        print(f"{k:18s}: {v}")
print("="*60)


                                                                    


 REALISTIC MIXTURE METRICS (Question-based routing)
EM                : 0.6484
Token_F1          : 0.7175
YesNo_EM          : 0.8620
Other_EM          : 0.5411
N                 : 1061
YesNo_N           : 355
Other_N           : 706
routed_to_stage3  : 503
routed_to_stage2  : 558
stage3_route_ratio: 0.4741


