In [None]:
import os
import glob
import pandas as pd
from datasets import Dataset, Image as HFImage

def build_split_df(split_dir: str, normal_sub="NORMAL", pneu_sub="PNEUMONIA"):
    """Collect image paths and class labels for a single dataset split.

    Args:
        split_dir: Directory that contains one sub-directory per class.
        normal_sub: Name of the sub-directory whose images are labelled "normal".
        pneu_sub: Name of the sub-directory whose images are labelled "pneumonia".

    Returns:
        A ``pandas.DataFrame`` with the columns ``image_path`` and ``label``.
        Rows for ``normal_sub`` come first, then rows for ``pneu_sub``; within
        each class the paths are sorted.

    Raises:
        RuntimeError: If no image file is found under ``split_dir``.
    """
    img_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    rows = []
    for sub_dir, label in ((normal_sub, "normal"), (pneu_sub, "pneumonia")):
        # Escape the directory part so characters such as "[" or "*" in the
        # path are matched literally instead of being read as glob syntax.
        pattern = os.path.join(glob.escape(os.path.join(split_dir, sub_dir)), "*")
        # glob yields entries in arbitrary, filesystem-dependent order; sorting
        # keeps the row order (and any seeded shuffle applied later) reproducible.
        for p in sorted(glob.glob(pattern)):
            if os.path.isfile(p) and p.lower().endswith(img_exts):
                rows.append({"image_path": p, "label": label})
    df = pd.DataFrame(rows, columns=["image_path", "label"])
    if df.empty:
        raise RuntimeError(f"No images found under: {split_dir}")
    return df

def make_hf_dataset_from_dir(root_dir: str):
    """Build one ``Dataset`` per split from an image folder tree.

    Expected layout::

        root_dir/
          train/
            NORMAL/
            PNEUMONIA/
          val/
            NORMAL/
            PNEUMONIA/
          test/
            NORMAL/
            PNEUMONIA/

    Every split directory is scanned with ``build_split_df``. The resulting
    frame is converted with ``Dataset.from_pandas`` (index dropped), an
    ``image`` column is filled from ``image_path`` and that column is then
    cast to ``HFImage()``.

    Args:
        root_dir: Directory holding the ``train``, ``val`` and ``test`` folders.

    Returns:
        A dict mapping ``"train"``, ``"val"`` and ``"test"`` to their dataset.
    """

    def _load_split(name):
        # One split: folder scan -> DataFrame -> Dataset with an image column.
        frame = build_split_df(os.path.join(root_dir, name))
        dataset = Dataset.from_pandas(frame, preserve_index=False)
        dataset = dataset.map(lambda ex: {"image": ex["image_path"]})
        return dataset.cast_column("image", HFImage())

    return {name: _load_split(name) for name in ("train", "val", "test")}

# Example usage
data_root = "/workspace/kosombe2025/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray"
# `datasets_dict` has to be assigned before it is indexed on the next line;
# without this call the cell raises NameError on a freshly started kernel.
datasets_dict = make_hf_dataset_from_dir(data_root)
train_ds, val_ds, test_ds = datasets_dict["train"], datasets_dict["val"], datasets_dict["test"]


In [None]:
import os
import re
import glob
import random
import pandas as pd
from typing import List

import torch
from datasets import Dataset, Image as HFImage
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from peft import LoraConfig, get_peft_model
from trl import GRPOTrainer, GRPOConfig

# Model identifier string. NOTE(review): not referenced again in the visible
# cells, which use BASE_MODEL_ID instead - confirm whether it is still needed.
model_id  = "Qwen/Qwen2.5-VL-3B-Instruct"
# Passed as `cache_dir=` to the `from_pretrained` calls in later cells.
cache_dir = "/workspace/huggingface/models/"  

# System message inserted by `to_chat_and_prompt`. It asks for reasoning inside
# <think>...</think> followed by a lowercase "yes"/"no" inside <answer>...</answer>,
# which is the layout the reward functions below look for.
SYSTEM_PROMPT = """Your task:
1. Think through the question step by step, and enclose your reasoning process inside <think>...</think> tags.
2. Then provide ONLY the final answer - either "yes" or "no" - inside <answer>...</answer> tags, written in lowercase letters.
3. Do not include anything else outside of these tags."""

# Alternative phrasings of the same yes/no pneumonia question; one of them is
# drawn per example by `random_question`.
QUESTION_POOL = [
    "Does this chest X-ray show pneumonia?",
    "Is there evidence of pneumonia in this X-ray?",
    "Does the scan indicate pneumonia?",
    "Can you see signs of pneumonia in this image?",
    "Is pneumonia present in the chest X-ray?",
    "Does this image suggest a pneumonia diagnosis?",
    "Is pneumonia detected in this chest X-ray image?",
    "Does the chest radiograph reveal pneumonia?",
]


def random_question() -> str:
    """Return one entry of ``QUESTION_POOL`` drawn with ``random.choice``."""
    picked = random.choice(QUESTION_POOL)
    return picked

def label_to_answer(label: str) -> str:
    """Translate a dataset class label into the expected yes/no answer.

    Matching ignores letter case and surrounding whitespace.

    Args:
        label: Class label, ``"pneumonia"`` or ``"normal"``.

    Returns:
        ``"yes"`` for pneumonia, ``"no"`` for normal.

    Raises:
        ValueError: If the label is neither of the two known classes. An
            unrecognised label is rejected instead of being counted as "no",
            because that would silently create a wrong ground-truth answer.
    """
    answers = {"pneumonia": "yes", "normal": "no"}
    key = str(label).strip().lower()
    if key not in answers:
        raise ValueError(
            f"Unknown label {label!r}; expected 'pneumonia' or 'normal'"
        )
    return answers[key]

In [None]:
import os
from peft import PeftModel

# Checkpoint the processor is loaded from, and the directory that
# `load_sft_for_grpo` inspects for the supervised fine-tuning result.
BASE_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
SFT_DIR = "/workspace/kosombe2025/MedVLM-SFT-Qwen2_5VL-YESNO_text_1012/final_model"

processor = AutoProcessor.from_pretrained(
    BASE_MODEL_ID,
    cache_dir=cache_dir,
    use_fast=True,
    padding_side="left",
)
tok = processor.tokenizer

# Reuse EOS as the padding token, but only when the tokenizer has no padding
# token of its own and an EOS token actually exists.
if tok.pad_token_id is None:
    if tok.eos_token_id is not None:
        tok.pad_token = tok.eos_token


def looks_like_lora_adapter(path: str) -> bool:
    """Report whether ``path`` contains an ``adapter_config.json`` entry."""
    marker = os.path.join(path, "adapter_config.json")
    return os.path.exists(marker)

def load_sft_for_grpo():
    """Load the model found at ``SFT_DIR`` for GRPO training.

    When ``SFT_DIR`` looks like a LoRA adapter (see ``looks_like_lora_adapter``),
    ``BASE_MODEL_ID`` is loaded first and the adapter is attached on top of it
    with ``is_trainable=True``. Otherwise ``SFT_DIR`` itself is loaded as the
    model. In both cases ``use_cache`` is switched off on the model config and a
    missing ``pad_token_id`` is filled in from the tokenizer ``tok``.

    Returns:
        The loaded model (a ``PeftModel`` in the adapter case).
    """
    is_adapter = looks_like_lora_adapter(SFT_DIR)
    # Loading options shared by both branches.
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        cache_dir=cache_dir,
    )

    if not is_adapter:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(SFT_DIR, **load_kwargs)
    else:
        backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            BASE_MODEL_ID, **load_kwargs
        )
        model = PeftModel.from_pretrained(backbone, SFT_DIR, is_trainable=True)
        print('load completed LoRA adapter from SFT_DIR')

    model.config.use_cache = False
    config_lacks_pad = getattr(model.config, "pad_token_id", None) is None
    if config_lacks_pad and tok.pad_token_id is not None:
        model.config.pad_token_id = tok.pad_token_id

    return model

model = load_sft_for_grpo()

# Print the trainable-parameter summary when the loaded object offers one;
# otherwise just state that no adapter wrapper is present.
if not hasattr(model, "print_trainable_parameters"):
    print("Loaded full model (no LoRA adapters).")
else:
    model.print_trainable_parameters()

In [None]:
def to_chat_and_prompt(example):
    """Turn one raw row into the record used for GRPO training.

    A question is drawn with ``random_question`` and combined with
    ``SYSTEM_PROMPT`` and an image placeholder into a two-message
    conversation, which is rendered to text with the processor's chat
    template (generation prompt appended, no tokenisation).

    Args:
        example: Row providing the keys ``"image"`` and ``"label"``.

    Returns:
        A dict with ``"prompt"`` (rendered text), ``"images"`` (a one-element
        list holding the row's image) and ``"solutions"`` (the yes/no answer
        derived from the label).
    """
    question = random_question()
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": question}],
    }
    rendered = processor.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        tokenize=False,
    )

    record = {"prompt": rendered}
    record["images"] = [example["image"]]
    record["solutions"] = label_to_answer(example["label"])
    return record

def prepare_split_for_training(ds: Dataset):
    """Apply ``to_chat_and_prompt`` to every row and drop the original columns.

    Args:
        ds: Dataset to convert.

    Returns:
        The mapped dataset, holding only the keys ``to_chat_and_prompt`` returns.
    """
    source_columns = ds.column_names
    converted = ds.map(to_chat_and_prompt, remove_columns=source_columns)
    return converted

In [None]:
import re
from typing import List

def format_reward(completions, **kwargs):
    """Score each completion on whether it follows the required tag layout.

    After stripping surrounding whitespace, a completion has to consist of a
    ``<think>...</think>`` section followed by ``<answer>yes</answer>`` or
    ``<answer>no</answer>`` and nothing else. Matching is case-insensitive.

    Args:
        completions: Iterable of completion strings.
        **kwargs: Accepted and ignored.

    Returns:
        A list with 0.2 for every well-formed completion and 0.0 otherwise.
    """
    pattern = r"^<think>[\s\S]*?</think>\s*<answer>\s*(?:yes|no)\s*</answer>\s*$"
    layout = re.compile(pattern, flags=re.IGNORECASE)
    return [0.2 if layout.match(text.strip()) else 0.0 for text in completions]

def _normalize(s: str) -> str:
    """Lower-case ``s``, trim it, collapse whitespace runs and drop the dot in "yes."/"no."."""
    collapsed = re.sub(r"\s+", " ", s.strip().lower())
    for dotted, bare in (("yes.", "yes"), ("no.", "no")):
        collapsed = collapsed.replace(dotted, bare)
    return collapsed


# Accuracy reward: large (1.0) when the answer is correct, otherwise 0
def accuracy_reward(completions: List[str], solutions: List[str], **kwargs) -> List[float]:
    """Score each completion against its reference yes/no answer.

    Per completion:
      * no ``<answer>yes|no</answer>`` tag found -> -0.1
      * first tagged answer equals the normalised solution -> 1.0, else 0.0
      * 0.05 is subtracted when the first ``<think>`` section holds more than
        60 whitespace-separated words

    Args:
        completions: Completion strings.
        solutions: Reference answers, paired with ``completions`` by position.
        **kwargs: Accepted and ignored.

    Returns:
        One reward per (completion, solution) pair.
    """
    answer_tag = re.compile(r"<answer>\s*(yes|no)\s*</answer>", flags=re.IGNORECASE)
    think_tag = re.compile(r"<think>([\s\S]*?)</think>", flags=re.IGNORECASE)

    def _score(text, expected):
        target = _normalize(expected)
        found = answer_tag.search(text)
        if found is None:
            return -0.1
        score = 1.0 if _normalize(found.group(1)) == target else 0.0
        reasoning = think_tag.search(text)
        # Small penalty for long reasoning sections.
        if reasoning is not None and len(reasoning.group(1).split()) > 60:
            score -= 0.05
        return score

    return [_score(text, expected) for text, expected in zip(completions, solutions)]

In [None]:
# Smoke check of accuracy_reward: the tagged answer "no" equals the solution and
# the <think> section has only two words, so the cell should display [1.0].
accuracy_reward(['<think>awonoqcoala aclalknnc </think><answer>no</answer>'],['no'])

In [None]:
# Switch for binding train_raw/val_raw/test_raw by hand instead of reading the
# image folders; that path is not implemented and raises when enabled.
USE_EXISTING_DATASETS = False

if USE_EXISTING_DATASETS:
    raise NotImplementedError("train_raw/val_raw/test_rawÎ•º Î∞îÏù∏Îî©ÌïòÏÑ∏Ïöî.")

data_root = "/workspace/kosombe2025/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray"
datasets_dict = make_hf_dataset_from_dir(data_root)
train_raw, val_raw, test_raw = (
    datasets_dict[split_name] for split_name in ("train", "val", "test")
)

# Preprocessing for GRPO
train_ds = prepare_split_for_training(train_raw)
val_ds = prepare_split_for_training(val_raw)
test_ds = prepare_split_for_training(test_raw)

# Keys of the first converted row: 'prompt', 'images', 'solutions'.
print(train_ds[0].keys())

print(train_ds)

In [None]:
from datasets import interleave_datasets
def _shuffled_answer_subset(answer):
    """Rows of ``train_ds`` whose lower-cased solution equals ``answer``, shuffled with seed 42."""
    return train_ds.filter(lambda ex: ex["solutions"].lower() == answer).shuffle(42)


yes_ds = _shuffled_answer_subset("yes")
no_ds = _shuffled_answer_subset("no")

# Draw from both answer classes with equal probability.
train_ds_bal = interleave_datasets(
    [yes_ds, no_ds],
    probabilities=[0.5, 0.5],
    seed=42,
)

In [None]:
# Show the rendered chat prompt of the first training row.
print(train_ds[0]['prompt'])

In [None]:
# Display the row at index 2 of the balanced (interleaved) training set.
train_ds_bal[2]

In [None]:
# Shuffle with a fixed seed before the dataset is handed to the trainer.
train_ds = train_ds.shuffle(seed=42)

training_args = GRPOConfig(
    # Output, logging and checkpointing
    output_dir="MedVLM-R1-Qwen2.5-VL-3B-CXR-YESNO-1012-newconfig",
    report_to=["tensorboard"],
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    eval_strategy="no",
    # Schedule, batching and precision
    num_train_epochs=3,
    per_device_train_batch_size=6,
    gradient_accumulation_steps=2,
    learning_rate=5e-6,
    bf16=True,
    gradient_checkpointing=True,
    remove_unused_columns=False,
    # Generation
    num_generations=6,
    max_prompt_length=None,
    max_completion_length=512,
    temperature=0.7,
    repetition_penalty=1.05,
    # Loss settings
    beta=0.02,
    loss_type="dapo",
    importance_sampling_level="sequence",
    mask_truncated_completions=True,
    scale_rewards="none",
    top_entropy_quantile=0.2,
)

# Both reward functions are passed to the trainer: tag layout and answer accuracy.
trainer = GRPOTrainer(
    args=training_args,
    model=model,
    processing_class=processor,
    train_dataset=train_ds,
    reward_funcs=[format_reward, accuracy_reward],
)

trainer.train()

In [None]:
from peft import PeftModel

# Inference setup. This cell rebinds `cache_dir`, `processor` and `model`,
# replacing the objects created for training in the cells above.
base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
adapter_path = "/workspace/kosombe2025/MedVLM-R1-Qwen2.5-VL-3B-CXR-YESNO-interleave-1011/checkpoint-2600"
cache_dir = "/workspace/huggingface/models/"

print("1. Î™®Îç∏ Î∞è ÌîÑÎ°úÏÑ∏ÏÑúÎ•º Î°úÎî©Ìï©ÎãàÎã§...")

processor = AutoProcessor.from_pretrained(
    base_model_id,
    cache_dir=cache_dir,
    use_fast=True,
)

# Base weights first, then the adapter checkpoint on top of them.
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_id, torch_dtype=torch.bfloat16, device_map="auto", cache_dir=cache_dir
)
model = PeftModel.from_pretrained(base_model, adapter_path)

print("Î™®Îç∏ Î°úÎî© ÏôÑÎ£å!")

In [None]:
print("\n2. ÌÖåÏä§Ìä∏ Îç∞Ïù¥ÌÑ∞Î•º Ï§ÄÎπÑÌï©ÎãàÎã§...")


val_ds = val_ds.shuffle(seed=42)
train_ds = train_ds.shuffle()
test_ds = test_ds.shuffle()

# Generate an answer for up to ten test rows and print it next to the ground
# truth. The count is capped at the split size so that a test set with fewer
# than ten rows does not raise IndexError.
for i in range(min(10, len(test_ds))):
    # Index the row first: `test_ds[i]` reads only this example, whereas
    # `test_ds['images'][i]` goes through the whole column on every iteration
    # (and, depending on the `datasets` version, may decode every image in it).
    sample = test_ds[i]
    ground_truth_answer = sample['solutions']
    prompt = sample['prompt']
    image = sample['images'][0]

    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(model.device)

    print("\n4. Î™®Îç∏Ïùò ÎãµÎ≥Ä ÏÉùÏÑ±ÏùÑ ÏãúÏûëÌï©ÎãàÎã§...")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    # Drop the prompt tokens so that only the newly generated part is decoded.
    output_ids = output_ids[:, inputs['input_ids'].shape[1]:]
    generated_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

    print("ÏÉùÏÑ± ÏôÑÎ£å!")
    print(f"‚úÖ Ïã§Ï†ú Ï†ïÎãµ: {ground_truth_answer}")
    print(f"ü§ñ Î™®Îç∏ ÏÉùÏÑ± ÎãµÎ≥Ä:\n{generated_text.strip()}")
    print("--------------------")