In [1]:
# import modules
import os

# Restrict this process to physical GPU 2, so it is the only device CUDA exposes.
# This must run before torch initialises CUDA, which is why it precedes the
# transformers/torch imports below.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import json

In [2]:
# Input/output locations. The defaults are the original author's absolute paths;
# set CHARTGEMMA_IMAGE_FOLDER / CHARTGEMMA_SAVE_FILE to run the notebook on
# another machine without editing this cell.
image_folder = os.environ.get(
    "CHARTGEMMA_IMAGE_FOLDER", "/home/mingi/data/vistext/data/test_images/horizontal"
)
# Prompt passed verbatim to the processor together with every image.
question_prompt = "Values in the chart: "
# JSONL results file: one {"image_file": ..., "text": ...} object per line.
save_file = os.environ.get(
    "CHARTGEMMA_SAVE_FILE", "/home/mingi/data/chartgemma/results/horizontal.jsonl"
)

In [3]:
# Custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset yielding one processor-encoded (prompt, image) pair per file.

    Every file in ``image_folder`` is paired with the same text prompt.

    Args:
        questions: Text prompt passed to ``processor`` for every image. Despite
            the plural name, one prompt string is reused for all items.
        image_folder: Directory containing the images to run inference on.
        processor: Callable accepting ``text=``, ``images=`` and
            ``return_tensors="pt"`` and returning a mapping of tensors that
            includes ``"input_ids"`` (e.g. a Hugging Face ``AutoProcessor``).
        device: Device every encoded tensor is moved to in ``__getitem__``.
            Defaults to ``"cuda"``, matching the previous hard-coded behaviour.
    """

    def __init__(self, questions, image_folder, processor, device="cuda"):
        self.questions = questions
        self.image_folder = image_folder
        self.processor = processor
        self.device = device
        # os.listdir order is arbitrary and filesystem-dependent: sort it so the
        # item (and therefore output JSONL) order is reproducible. Skip
        # sub-directories, which Image.open cannot read.
        self.file_names = sorted(
            name
            for name in os.listdir(image_folder)
            if os.path.isfile(os.path.join(image_folder, name))
        )

    def __getitem__(self, index):
        """Return ``(inputs, prompt_length, image_file)`` for item ``index``.

        ``inputs`` is the processor output with every tensor moved to
        ``self.device``; ``prompt_length`` is ``input_ids.shape[1]`` of the
        encoded prompt, used by the caller to strip the prompt from the
        generated sequence; ``image_file`` is the bare file name.
        """
        image_file = self.file_names[index]

        # The context manager closes the underlying file handle; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(os.path.join(self.image_folder, image_file)) as raw_image:
            image = raw_image.convert("RGB")

        inputs = self.processor(text=self.questions, images=image, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[1]
        # NOTE(review): creating CUDA tensors here means DataLoader worker
        # processes (num_workers > 0) should not be used with a CUDA device.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        return inputs, prompt_length, image_file

    def __len__(self):
        """Number of image files found in ``image_folder``."""
        return len(self.file_names)


def collate_fn(batch):
    """Transpose a list of per-sample triples into three per-field tuples.

    ``batch`` holds ``(inputs, prompt_length, image_file)`` triples as returned
    by ``CustomDataset.__getitem__``. Nothing is stacked or padded, so samples
    with different sequence lengths pass through untouched.

    Returns:
        ``(inputs_tuple, prompt_length_tuple, image_file_tuple)``.
    """
    transposed = list(zip(*batch))
    # Unpacking into exactly three names rejects malformed or empty batches
    # with a ValueError.
    all_inputs, all_lengths, all_names = transposed
    return all_inputs, all_lengths, all_names


def create_data_loader(
    questions,
    image_folder,
    processor,
    batch_size=1,
):
    """Build a sequential DataLoader over ``CustomDataset``.

    Args:
        questions: Prompt string forwarded to ``CustomDataset``.
        image_folder: Directory of images forwarded to ``CustomDataset``.
        processor: Processor forwarded to ``CustomDataset``.
        batch_size: Must be 1. The collate function does not pad, and the
            inference loop unwraps element ``[0]`` of each batch.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``collate_fn`` as its collator.

    Raises:
        AssertionError: if ``batch_size`` is not 1.
    """
    assert batch_size == 1, "batch_size must be 1"
    # Loading stays in the main process (default num_workers=0).
    return DataLoader(
        CustomDataset(questions, image_folder, processor),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

In [4]:
# Load ChartGemma in float16 on the GPU left visible by CUDA_VISIBLE_DEVICES.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "ahmed-masry/chartgemma", torch_dtype=torch.float16
).cuda()
processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")

# os.path.dirname() is "" for a bare file name and os.makedirs("") raises,
# so only create the parent directory when there is one.
save_dir = os.path.dirname(save_file)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
# "w" truncates results from any previous run. The file is closed by the
# inference cell, so re-run this cell before re-running inference.
ans_file = open(save_file, "w", encoding="utf-8")

data_loader = create_data_loader(
    questions=question_prompt, image_folder=image_folder, processor=processor
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Beam-search decoding settings.
NUM_BEAMS = 4
MAX_NEW_TOKENS = 512

# `with ans_file:` closes (and flushes) the results file even if generation
# raises partway through, so already-written lines are kept and no handle leaks.
with ans_file:
    # tqdm takes its total from len(data_loader), which always matches the
    # number of items the dataset actually yields.
    for inputs, prompt_length, image_file in tqdm(data_loader):
        # batch_size is 1, so unwrap the single sample from each collated tuple.
        inputs = inputs[0]
        image_file = image_file[0]
        prompt_length = prompt_length[0]
        with torch.inference_mode():
            generate_ids = model.generate(
                **inputs, num_beams=NUM_BEAMS, max_new_tokens=MAX_NEW_TOKENS
            )
        # Drop the prompt tokens so only the generated continuation is decoded.
        output_text = processor.batch_decode(
            generate_ids[:, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        ans_file.write(
            json.dumps(
                {
                    "image_file": image_file,
                    "text": output_text,
                }
            )
            + "\n"
        )

100%|██████████| 518/518 [1:08:29<00:00,  7.93s/it]
