In [1]:
import os

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from vllm import LLM, SamplingParams

# Custom libraries
from vqa_dataset import PromptDataset,prompt_collate,create_template

INFO 11-10 17:43:38 [__init__.py:216] Automatically detected platform cuda.


In [2]:
import os

from huggingface_hub import login

# SECURITY: never commit an access token to a notebook. A literal token used to be
# embedded in this cell; treat it as leaked and revoke it in the Hugging Face settings.
# The token is now read from the HF_TOKEN environment variable. When it is unset,
# token=None makes login() fall back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [3]:
## User Input ##
# Alternative model that can be swapped in:
#model   = "HuggingFaceM4/Idefics3-8B-Llama3"
model_name = "google/medgemma-4b-it"    # Hugging Face repo id of the model to evaluate
task = "classification"                 # placeholder; reassigned in the task-definition cell
save_every = 50                         # buffered answers are written to disk every N items
options = True
out_dit = "out"                         # output directory (later cells read this exact name)
model_dir = "/pasteur/u/rdcunha/models" # directory used for every model/cache download

## Envs:
# Point all Hugging Face / vLLM caches at model_dir.
# NOTE(review): these are assigned after the libraries were imported in the first cell;
# libraries that read them at import time may not pick them up -- consider moving this
# cell above the imports. TODO confirm.
for cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE", "VLLM_CACHE_ROOT"):
    os.environ[cache_var] = model_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Disable Pillow's maximum-pixel-count guard so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

print(os.environ["HF_HOME"] )

/pasteur/u/rdcunha/models


# Define the model 

In [4]:
# pip install accelerate
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch


# Load the checkpoint in bfloat16: every batch is cast to torch.bfloat16 before
# generate() in the inference cells, so the weights are requested in the same dtype
# instead of the loader's default precision.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    cache_dir=model_dir,
).eval()

# Use the same cache directory as the model so both downloads land in model_dir.
processor = AutoProcessor.from_pretrained(model_name, cache_dir=model_dir)



Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

NameError: name 'model_id' is not defined

In [6]:
def create_template(item):
    """Build a single-turn chat conversation for one sample.

    The user turn carries the sample's image first, followed by its question text.

    Args:
        item: mapping that provides an "image" entry and a "question" entry.

    Returns:
        A one-element list containing the user message dict.
    """
    image_part = {"type": "image", "image": item["image"]}
    text_part = {"type": "text", "text": item["question"]}
    user_turn = {"role": "user", "content": [image_part, text_part]}
    return [user_turn]


# Define the task

In [7]:
# Task name; it selects the data sub-directory and is part of the output file name.
task = "all_cls"
mbu_root = f"/pasteur/u/rdcunha/data_cache/mmbu/final_data/VLMEvalData_v2/LMUData/{task}"
data_root = os.path.join(mbu_root, 'all_cls_closed_subsampled.tsv')

# Full question table (tab-separated).
data = pd.read_csv(data_root, sep='\t')

# Keep only "expert" questions, then drop the datasets excluded from this run.
excluded_datasets = ["isic2018", 'herlev', "breakhis_400x", "breakhis_200x"]
is_expert = data["question_type"] == "expert"
filtered_ = data[is_expert]
is_excluded = filtered_["dataset"].isin(excluded_datasets)
df_ = filtered_[~is_excluded]

In [8]:
from torch.utils.data import Dataset, DataLoader

# Dataset over the filtered question table.
ds = PromptDataset(df=df_)

# Loader settings gathered in one place. Batches are produced in file order
# (shuffle=False) and assembled by the project's prompt_collate function.
loader_options = dict(
    batch_size=10,
    shuffle=False,
    collate_fn=prompt_collate,
    num_workers=10,
    persistent_workers=True,
    pin_memory=True,
    prefetch_factor=4,
)
questions_data_loaders = DataLoader(ds, **loader_options)


In [9]:
# Smoke test: push only the first batch through the model to check the pipeline
# end to end (the loop is left immediately after one iteration).
for items in tqdm(questions_data_loaders, desc="Processing batches"):
    messages = [create_template(item) for item in items]
    # padding=True matches the full inference cell and lets prompts of different
    # token lengths be stacked into a single batch tensor.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        padding=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        # Drop the first input_len positions (the prompt) and decode the remainder.
        outputs = [
            processor.decode(g[input_len:], skip_special_tokens=True)
            for g in generation
        ]
    break



Processing batches:   0%|          | 0/735 [00:09<?, ?it/s]


# Run inference over the full dataset (resumable)

In [None]:
import os
import json
from tqdm import tqdm

# Keep only the repository name ("org/name" -> "name"). Taking the last component
# leaves a name without "/" unchanged, so re-running this cell is safe and no
# catch-all except is needed.
model_name = model_name.split('/')[-1]

save_path = f"{model_name}_{task}_expert_closed.jsonl"
save_file = os.path.join(out_dit, save_path)
# Opening a file in append mode does not create missing parent directories.
os.makedirs(out_dit, exist_ok=True)

# --- Step 1: Collect already processed IDs ---
existing_ids = set()
if os.path.exists(save_file):
    with open(save_file, "r") as f:
        for line in f:
            try:
                # `record` (not `data`) so the DataFrame named `data` is not overwritten.
                record = json.loads(line)
                existing_ids.add(record["index"])
            except (json.JSONDecodeError, KeyError, TypeError):
                continue  # skip corrupted or incomplete lines

print(f"Found {len(existing_ids)} already processed items. Skipping them...")

# --- Step 2: Run inference only for new IDs ---
saved_items = []   # answers buffered since the last write
counter = 0        # number of answers produced in this run
# NOTE: not consumed by the transformers generate() call below; kept so the name stays defined.
sampling_params = SamplingParams(temperature=0, max_tokens=512)

with open(save_file, "a") as f:
    for items in tqdm(questions_data_loaders, desc="Processing batches"):
        # Filter out items whose IDs already exist
        new_items = [it for it in items if it["index"] not in existing_ids]
        if not new_items:
            continue  # nothing new in this batch

        ### THIS USUALLY NEEDS TO BE EDITED ###
        try:
            # Build prompts from new_items (not items): outputs are zipped with
            # new_items below, so both sequences must line up one-to-one.
            messages = [create_template(item) for item in new_items]
            inputs = processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                padding=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device, dtype=torch.bfloat16)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
                # Drop the first input_len positions (the prompt) and decode the remainder.
                outputs = [
                    processor.decode(g[input_len:], skip_special_tokens=True)
                    for g in generation
                ]
        except Exception as exc:
            # Best effort: report the failing batch by index and move on.
            # `Exception` (not a bare except) lets KeyboardInterrupt stop the run.
            failed_ids = [it["index"] for it in new_items]
            print(f"could not generate for indices {failed_ids}: {exc!r}")
            continue
        ### THIS USUALLY NEEDS TO BE EDITED ###

        for it, output in zip(new_items, outputs):
            saved_items.append({
                "index": it["index"],
                "question": it["question"],
                "options": it["options"],
                "image_path": it["image_path"],
                "image_scale": it["image_scale"],
                "scaled_width": it["scaled_width"],
                "scaled_height": it["scaled_height"],
                "dataset": it["dataset"],
                "class_label": it["class_label"],
                "answer": output,
            })
            existing_ids.add(it["index"])  # avoid re-processing the same index in this run
            counter += 1

            # Save every N examples
            if counter % save_every == 0:
                print(f"Saving at {save_file}")
                f.writelines(json.dumps(s) + "\n" for s in saved_items)
                f.flush()
                saved_items = []

    # Save remaining items
    if saved_items:
        f.writelines(json.dumps(s) + "\n" for s in saved_items)
        f.flush()
        saved_items = []


Found 0 already processed items. Skipping them...


Processing batches:   1%| | 5/735 [00:41<1:40:41,  8.28s/it

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   1%| | 10/735 [01:22<1:39:37,  8.24s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   2%| | 15/735 [02:05<1:41:19,  8.44s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   3%| | 20/735 [02:45<1:38:20,  8.25s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   3%| | 25/735 [03:27<1:37:05,  8.20s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   4%| | 30/735 [04:07<1:35:32,  8.13s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   5%| | 35/735 [04:50<1:39:11,  8.50s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   5%| | 40/735 [05:33<1:39:04,  8.55s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   6%| | 45/735 [06:16<1:38:43,  8.59s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   7%| | 50/735 [06:59<1:39:41,  8.73s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   7%| | 55/735 [07:43<1:37:50,  8.63s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   8%| | 60/735 [08:26<1:36:41,  8.59s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:   9%| | 65/735 [09:04<1:29:46,  8.04s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  10%| | 70/735 [09:48<1:34:15,  8.51s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  10%| | 75/735 [10:29<1:32:16,  8.39s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  11%| | 80/735 [11:12<1:34:10,  8.63s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  12%| | 85/735 [11:54<1:31:41,  8.46s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  12%| | 90/735 [12:37<1:32:12,  8.58s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  13%|▏| 95/735 [13:14<1:17:52,  7.30s/i

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  14%|▏| 100/735 [13:53<1:24:25,  7.98s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  14%|▏| 105/735 [14:33<1:19:58,  7.62s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  15%|▏| 110/735 [15:11<1:22:36,  7.93s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  16%|▏| 115/735 [15:53<1:24:41,  8.20s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  16%|▏| 120/735 [16:34<1:24:09,  8.21s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  17%|▏| 125/735 [17:17<1:26:52,  8.55s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  18%|▏| 130/735 [18:00<1:26:26,  8.57s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  18%|▏| 135/735 [18:37<1:16:49,  7.68s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  19%|▏| 140/735 [19:18<1:19:27,  8.01s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  20%|▏| 145/735 [20:00<1:19:57,  8.13s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  20%|▏| 150/735 [20:42<1:20:00,  8.21s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  21%|▏| 155/735 [21:21<1:17:50,  8.05s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  22%|▏| 160/735 [22:04<1:21:38,  8.52s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  22%|▏| 165/735 [22:45<1:15:51,  7.99s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  23%|▏| 170/735 [23:26<1:18:55,  8.38s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  24%|▏| 175/735 [24:04<1:11:05,  7.62s/

Saving at out/medgemma-4b-it_all_cls_expert_closed.jsonl


Processing batches:  24%|▏| 179/735 [24:39<1:17:55,  8.41s/