In [6]:
import os

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from vllm import LLM, SamplingParams

# Custom libraries
from vqa_dataset import PromptDataset,prompt_collate,create_template

In [7]:
from huggingface_hub import login

# Never hardcode access tokens in a notebook: they end up in version control and
# shared copies. Export HF_TOKEN in the shell that launches Jupyter instead.
# SECURITY: the token previously hardcoded in this cell must be considered leaked
# and revoked/rotated in the Hugging Face account settings.
# If HF_TOKEN is unset, token=None lets huggingface_hub fall back to its
# interactive login prompt.
login(token=os.environ.get("HF_TOKEN"))

In [8]:
## User Input ##
#model   = "HuggingFaceM4/Idefics3-8B-Llama3"
model_name = "lingshu-medical-mllm/Lingshu-32B"
# Placeholder only: reassigned to "all_cls" in the "Define the task" cell before use.
task    = "classification"
save_every = 50     # write buffered results to the .jsonl every N answered items
options = True
out_dit = "out"     # output directory for result files (name kept: later cells read `out_dit`)
model_dir = "/pasteur/u/rdcunha/models"

## Envs:
# NOTE(review): huggingface_hub and vllm are imported in earlier cells, and
# huggingface_hub resolves HF_HOME / HUGGINGFACE_HUB_CACHE when it is imported.
# vllm presumably pulls in transformers as well. These cache settings may
# therefore be ignored in this kernel: pass `cache_dir=model_dir` explicitly when
# loading, or move this block above the imports. Likewise, CUDA_VISIBLE_DEVICES
# only applies if it is set before CUDA is initialised.
os.environ["HF_HOME"] = model_dir
os.environ["TRANSFORMERS_CACHE"] = model_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = model_dir
os.environ["VLLM_CACHE_ROOT"]  = model_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"

# Disable PIL's decompression-bomb size limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

print(os.environ["HF_HOME"] )

/pasteur/u/rdcunha/models


# Define the model 

In [9]:
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Load the checkpoint with the Qwen2.5-VL model class in bf16.
# device_map="auto" shards the weights over the visible GPUs. The captured output
# of this cell shows some parameters were offloaded to CPU, which slows generation.
# Check GPU memory / CUDA_VISIBLE_DEVICES if throughput matters.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
    cache_dir=model_dir,
)

# Pass cache_dir explicitly: the HF cache env vars are set after huggingface_hub
# was imported, so the processor would otherwise go to the default cache.
processor = AutoProcessor.from_pretrained(model_name, cache_dir=model_dir)

# Batched generation with a decoder-only LM needs left padding. Otherwise the
# shorter prompts in a batch end in pad tokens and generation continues after them.
processor.tokenizer.padding_side = "left"

Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


In [10]:
def create_template(item):
    """Wrap one dataset item as a single-turn, single-user chat conversation.

    The user turn carries two content parts, in order: the image referenced by
    ``item["image_path"]`` and the text ``item["question"]``.

    NOTE(review): this definition shadows the ``create_template`` imported from
    ``vqa_dataset`` in the imports cell. Every later call uses this version.

    Returns:
        A one-element list containing the user message dict.
    """
    image_part = {"type": "image", "image": item["image_path"]}
    text_part = {"type": "text", "text": item["question"]}
    user_turn = {"role": "user", "content": [image_part, text_part]}
    return [user_turn]


# Define the task

In [11]:
task = "all_cls"
mbu_root = f"/pasteur/u/rdcunha/data_cache/mmbu/final_data/VLMEvalData_v2/LMUData/{task}"
data_root = os.path.join(mbu_root, 'all_cls_closed_subsampled.tsv')

# Load the closed-form classification questions (tab-separated), keep only the
# "expert" question type, then drop the datasets excluded from this evaluation.
data = pd.read_csv(data_root, sep="\t")
is_expert = data["question_type"] == "expert"
filtered_ = data.loc[is_expert]
excluded = ["isic2018", "herlev", "breakhis_400x", "breakhis_200x"]
df_ = filtered_.loc[~filtered_["dataset"].isin(excluded)]

In [16]:
from torch.utils.data import Dataset, DataLoader

# Dataset over the filtered expert questions; batching is delegated to prompt_collate.
ds = PromptDataset(df=df_)

loader_kwargs = dict(
    batch_size=10,
    shuffle=False,             # deterministic order, so interrupted runs can resume
    collate_fn=prompt_collate,
    num_workers=2,
    persistent_workers=True,   # keep workers alive between the smoke test and the full run
    pin_memory=True,
    prefetch_factor=4,
)
questions_data_loaders = DataLoader(ds, **loader_kwargs)


In [18]:
import torch
from tqdm import tqdm

# Smoke test: run the model on the FIRST batch only and inspect `output_text` below.
# Taking one batch with next(iter(...)) avoids a `for ...: break` loop that must be
# remembered and edited before a full run.
items = next(iter(questions_data_loaders))

# 1. One chat conversation per batch item
messages = [create_template(item) for item in items]

# 2. Render the chat template to text and collect the image/video inputs
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)

# 3. Tokenize and pad into a single batch of model inputs
inputs = processor(
    text=text,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)

# 4. Move tensors to the default CUDA device; floating-point tensors are cast to
#    bf16 to match the model weights (dtype=None leaves integer tensors unchanged).
for k, v in inputs.items():
    if torch.is_tensor(v):
        inputs[k] = v.to("cuda", dtype=torch.bfloat16 if v.is_floating_point() else None)

# 5. Generate without tracking gradients
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    generated_ids = model.generate(**inputs, max_new_tokens=128)

# 6. Drop the prompt tokens. All rows share the padded prompt length, so slicing
#    each output by its (padded) input length leaves only the new tokens.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
]

# 7. Decode the generated tokens to text
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

Processing batches:   0%|         | 0/735 [00:16<?, ?it/s]


# Full run (resumable: indices already saved in the output .jsonl are skipped)

In [19]:
# Inspect the decoded answers from the one-batch smoke test above.
output_text

['C)', 'B)', 'B)', 'C)', 'A)', 'B)', 'B)', 'B)', 'B)', 'C)']

In [None]:
import os
import json

import torch
from tqdm import tqdm

# Short model name (e.g. "Lingshu-32B") for the output filename. `model_name` is
# NOT rebound, so re-running the model-loading cell still sees the full repo id.
model_tag = model_name.split("/")[-1]

save_path = f"{model_tag}_{task}_expert_closed.jsonl"
save_file = os.path.join(out_dit, save_path)
os.makedirs(out_dit, exist_ok=True)

# WARNING: earlier versions of this cell built prompts from a stale `messages`
# variable left over from the smoke-test cell. Every batch therefore re-ran the
# FIRST batch's questions, and those answers were saved under other indices.
# Indices already in `save_file` are skipped below, so delete any file produced
# by that version before re-running, or its wrong answers are kept.


def _generate_answers(batch_items):
    """Run the VLM on ``batch_items`` and return one decoded string per item.

    Prompts are built from the items passed in, never from notebook state.
    """
    messages = [create_template(it) for it in batch_items]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Floating-point tensors are cast to bf16 to match the weights; others only move device.
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda", dtype=torch.bfloat16 if v.is_floating_point() else None)

    with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
        generated_ids = model.generate(**inputs, max_new_tokens=128)

    # Drop the (padded) prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )


# --- Step 1: Collect already processed indices (resume support) ---
existing_ids = set()
if os.path.exists(save_file):
    with open(save_file, "r") as f:
        for line in f:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # skip a corrupted / partially written line
            existing_ids.add(record["index"])

print(f"Found {len(existing_ids)} already processed items. Skipping them...")

# --- Step 2: Run inference only for new indices ---
saved_items = []
counter = 0

with open(save_file, "a") as f:
    for items in tqdm(questions_data_loaders, desc="Processing batches"):
        new_items = [it for it in items if it["index"] not in existing_ids]
        if not new_items:
            continue  # nothing new in this batch

        try:
            outputs = _generate_answers(new_items)
        except Exception as exc:
            # Best-effort: log and move on. Skipped items are not written, so a
            # later re-run retries them. Ctrl-C (KeyboardInterrupt) still stops the loop.
            failed = [it["index"] for it in new_items]
            print(f"could not generate for indices {failed}: {exc!r}")
            continue

        assert len(outputs) == len(new_items), "one decoded answer expected per item"

        for it, answer in zip(new_items, outputs):
            saved_items.append({
                "index": it["index"],
                "question": it["question"],
                "options": it["options"],
                "image_path": it["image_path"],
                "image_scale": it["image_scale"],
                "scaled_width": it["scaled_width"],
                "scaled_height": it["scaled_height"],
                "dataset": it["dataset"],
                "class_label": it["class_label"],
                "answer": answer,
            })
            existing_ids.add(it["index"])
            counter += 1

            # Append buffered records to disk every `save_every` answered items.
            if counter % save_every == 0:
                print(f"Saving at {save_file}")
                for s in saved_items:
                    f.write(json.dumps(s) + "\n")
                f.flush()
                saved_items = []

    # Save whatever is still buffered after the last batch.
    if saved_items:
        for s in saved_items:
            f.write(json.dumps(s) + "\n")
        f.flush()
        saved_items = []
print("DONE")

Found 0 already processed items. Skipping them...


Processing batches:   1%|▏                  | 5/735 [01:21<3:17:38, 16.24s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   1%|▏                 | 10/735 [02:42<3:16:31, 16.26s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   2%|▎                 | 15/735 [04:03<3:14:35, 16.22s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   3%|▍                 | 20/735 [05:24<3:13:40, 16.25s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   3%|▌                 | 25/735 [06:46<3:11:59, 16.22s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   4%|▋                 | 30/735 [08:07<3:10:33, 16.22s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   5%|▊                 | 35/735 [09:28<3:09:57, 16.28s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   5%|▉                 | 40/735 [10:50<3:08:48, 16.30s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   6%|█                 | 45/735 [12:11<3:05:53, 16.16s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   7%|█▏                | 50/735 [13:36<3:16:40, 17.23s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   7%|█▎                | 55/735 [15:09<3:29:50, 18.52s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   8%|█▍                | 60/735 [16:43<3:30:43, 18.73s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:   9%|█▌                | 65/735 [18:17<3:29:09, 18.73s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  10%|█▋                | 70/735 [19:51<3:28:18, 18.79s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  10%|█▊                | 75/735 [21:25<3:26:48, 18.80s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  11%|█▉                | 80/735 [22:59<3:25:38, 18.84s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  12%|██                | 85/735 [24:33<3:23:55, 18.82s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  12%|██▏               | 90/735 [26:07<3:22:11, 18.81s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  13%|██▎               | 95/735 [27:41<3:20:13, 18.77s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  14%|██▎              | 100/735 [29:16<3:20:51, 18.98s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  14%|██▍              | 105/735 [30:50<3:17:27, 18.81s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  15%|██▌              | 110/735 [32:24<3:16:24, 18.86s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  16%|██▋              | 115/735 [33:59<3:15:07, 18.88s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  16%|██▊              | 120/735 [35:34<3:15:11, 19.04s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  17%|██▉              | 125/735 [37:09<3:13:05, 18.99s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  18%|███              | 130/735 [38:43<3:10:25, 18.88s/it]

Saving at out/Lingshu-32B_all_cls_expert_closed.jsonl


Processing batches:  18%|███              | 132/735 [39:21<3:09:27, 18.85s/it]