In [1]:
# import modules
import os

# Expose only physical GPU 3 to this process. This is set before `torch` is
# imported so the restriction is already in place when CUDA initialises; the
# remaining GPU then appears inside the process as "cuda" / "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

from PIL import Image
import requests  # NOTE(review): not used in the visible cells
from transformers import (
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
    AutoModelForPreTraining,
)
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt  # NOTE(review): not used in the visible cells
from torch.utils.data import Dataset, DataLoader
import json
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run configuration: chart images + per-chart TSV ground truth in, one JSON line
# per image out. All data paths hang off a single root so the notebook can be
# pointed at another machine by editing one line.
data_root = "/home/mingi/data"
image_folder = f"{data_root}/vistext/data/simple_images/vertical"
tsv_folder = f"{data_root}/vistext/data/test_tsv"
save_file = f"{data_root}/v_llama/one_by_one_results/vertical.jsonl"
model_name = "vision_llama"
assert model_name in ["chartgemma", "vision_llama"], (
    f"model_name must be 'chartgemma' or 'vision_llama', got {model_name!r}"
)

In [3]:
# Custom dataset class
class CustomDataset(Dataset):
    """Chart-QA dataset: one item per image, holding every question for that image.

    For each image in ``image_folder`` the TSV with the same stem is read from
    ``tsv_folder`` (tab-separated). Every TSV row becomes one question about
    ``row[0]`` whose ground-truth answer is ``row[1]``; the question is encoded
    together with the image by ``processor`` in the prompt format of
    ``model_name``.

    Each item is ``(inputs, prompt_lengths, image_file, value_list, category_list)``:
    ``inputs[i]`` is the processor output for question ``i`` (not moved to any
    device — placement is the caller's job, which keeps items safe to build in
    DataLoader worker processes) and ``prompt_lengths[i]`` is the token length
    of ``inputs[i]["input_ids"]``.

    Args:
        image_folder: directory containing the chart images.
        tsv_folder: directory containing ``<image stem>.tsv`` files.
        processor: Hugging Face processor matching ``model_name``.
        model_name: ``"chartgemma"`` or ``"vision_llama"``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """

    SUPPORTED_MODELS = ("chartgemma", "vision_llama")

    def __init__(self, image_folder, tsv_folder, processor, model_name):
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"model_name must be one of {self.SUPPORTED_MODELS}, got {model_name!r}"
            )
        self.image_folder = image_folder
        self.processor = processor
        self.tsv_folder = tsv_folder
        # Sorted so item order (and therefore the output file) is reproducible;
        # hidden entries and sub-directories are skipped since they are not images.
        self.file_names = sorted(
            name
            for name in os.listdir(image_folder)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(image_folder, name))
        )
        self.model_name = model_name

    def _build_prompt(self, category):
        """Return the text prompt asking for the value of ``category``."""
        question = f"What is the value of {category}?\nAnswer the question using only number."
        if self.model_name == "chartgemma":
            return question
        # vision_llama: wrap the question in the chat template with an image slot.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        return self.processor.apply_chat_template(messages, add_generation_prompt=True)

    def _encode(self, image, prompt):
        """Run the processor on one (image, prompt) pair using this model's call convention."""
        if self.model_name == "chartgemma":
            return self.processor(text=prompt, images=image, return_tensors="pt")
        # The chat template already contains the special tokens, so don't add them again.
        return self.processor(
            image, prompt, add_special_tokens=False, return_tensors="pt"
        )

    def __getitem__(self, index):
        # prepare all questions for one file
        image_file = self.file_names[index]
        # splitext (not split(".")) so stems containing dots are kept intact.
        tsv_file = os.path.splitext(image_file)[0] + ".tsv"

        with Image.open(os.path.join(self.image_folder, image_file)) as source:
            # convert() returns an independent, fully loaded image, so the file
            # handle can be released immediately.
            image = source.convert("RGB")
        tsv = pd.read_csv(os.path.join(self.tsv_folder, tsv_file), sep="\t")

        inputs = []
        prompt_lengths = []
        value_list = []
        category_list = []
        for row in tsv.itertuples(index=False, name=None):
            category, value = row[0], row[1]
            encoding = self._encode(image, self._build_prompt(category))
            inputs.append(encoding)
            prompt_lengths.append(encoding["input_ids"].shape[1])
            category_list.append(category)
            value_list.append(value)

        return inputs, prompt_lengths, image_file, value_list, category_list

    def __len__(self):
        return len(self.file_names)


def collate_fn(batch):
    """Regroup a list of 5-field dataset items into five per-field tuples.

    Nothing is stacked or padded (unlike the default collate): every item carries
    a variable number of encodings, so each field simply becomes a tuple with one
    entry per item, in batch order.
    """
    encodings, lengths, names, values, categories = zip(*batch)
    return encodings, lengths, names, values, categories


def create_data_loader(
    image_folder,
    tsv_folder,
    processor,
    batch_size=1,
    model_name=None,
    num_workers=0,
):
    """Build an in-order DataLoader over ``CustomDataset``.

    Args:
        image_folder: directory of chart images.
        tsv_folder: directory of per-image TSV ground-truth files.
        processor: Hugging Face processor for the chosen model.
        batch_size: only 1 is supported — callers unwrap element ``[0]`` of every
            collated field.
        model_name: forwarded to ``CustomDataset`` to select the prompt format.
        num_workers: DataLoader worker processes (default 0 = load in the main
            process, the DataLoader default). Only raise it if the dataset does no
            CUDA work in ``__getitem__``.

    Raises:
        AssertionError: if ``batch_size`` is not 1.
    """
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDataset(image_folder, tsv_folder, processor, model_name)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # keep output lines in dataset order
        collate_fn=collate_fn,
    )

In [4]:
# Load the model/processor pair selected in the config cell, open the output
# file, and build the data loader.
if model_name == "chartgemma":
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        "ahmed-masry/chartgemma", torch_dtype=torch.float16
    ).cuda()
    processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
elif model_name == "vision_llama":
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # device_map="auto" already lets accelerate place the weights on the visible
    # GPU(s); chaining .cuda() afterwards is redundant at best and overrides
    # that placement when more than one GPU is visible.
    model = AutoModelForPreTraining.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id)
else:
    # Fail here rather than with a NameError on `model` in a later cell.
    raise ValueError(f"Unsupported model_name: {model_name!r}")


os.makedirs(os.path.dirname(save_file), exist_ok=True)
# Opened (and truncated) here; the inference cell writes one JSON line per image
# and closes it.
ans_file = open(save_file, "w")

data_loader = create_data_loader(
    image_folder=image_folder,
    tsv_folder=tsv_folder,
    processor=processor,
    model_name=model_name,
)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 5/5 [00:06<00:00,  1.29s/it]


In [5]:
# Answer every question of every chart and stream one JSON line per image to
# `save_file`. `with ans_file` guarantees the file is closed even if generation
# raises partway through a long run.
with ans_file:
    for batch in tqdm(data_loader):  # DataLoader has __len__, so tqdm knows the total
        # batch_size is 1, so every collated field is a 1-tuple: unwrap it.
        inputs, prompt_lengths, image_file, value_list, category_list = (
            field[0] for field in batch
        )

        output_text_list = []
        for encoding, prompt_length in zip(inputs, prompt_lengths):
            # Processor outputs (BatchFeature) provide .to(); a plain dict of
            # tensors does not, so move its values one by one.
            if hasattr(encoding, "to"):
                encoding = encoding.to(device=model.device, non_blocking=True)
            else:
                encoding = {
                    k: v.to(model.device, non_blocking=True) for k, v in encoding.items()
                }
            with torch.inference_mode():
                generate_ids = model.generate(**encoding, num_beams=4, max_new_tokens=30)
            # generate() echoes the prompt; keep only the newly generated tokens.
            output_text = processor.batch_decode(
                generate_ids[:, prompt_length:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            output_text_list.append(output_text)

        ans_file.write(
            json.dumps(
                {
                    "image_file": image_file,
                    "text": output_text_list,
                    "value_list": value_list,
                    "category_list": category_list,
                }
            )
            + "\n"
        )
        # Flush per image so results already computed survive a crash or kernel restart.
        ans_file.flush()

100%|██████████| 373/373 [33:48:58<00:00, 326.38s/it]   
