In [1]:
# import modules
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import torch

import matplotlib.pyplot as plt
import re

In [2]:
data_path = "/home/mingi/data/vistext/data/test_images/vertical"
file_names = sorted(
    os.listdir(data_path), key=lambda x: int(re.search(r"\d+", x).group())
)

In [3]:
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "ahmed-masry/chartgemma", torch_dtype=torch.float16
).cuda()
processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
# Format text inputs



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
def inference_plot(model, processor, input_text, data_path, file_names, idx):

    # Encode chart figure and tokenize text
    image = Image.open(os.path.join(data_path, file_names[idx])).convert("RGB")
    inputs = processor(text=input_text, images=image, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[1]
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Generate
    generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
    output_text = processor.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print(output_text)

    plt.imshow(image)
    plt.axis("off")
    plt.show()


def inference_plot2(model, processor, input_text, data_path, image_name):
    path = "/home/mingi/data/test_charts/img"

    # Encode chart figure and tokenize text
    image = Image.open(os.path.join(path, image_name)).convert("RGB")
    inputs = processor(text=input_text, images=image, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[1]
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Generate
    generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
    output_text = processor.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print(output_text)

    # plt.imshow(image)
    # plt.axis("off")
    # plt.show()

In [18]:
# Define the path to your txt file
file_path = "ann.txt"

# Initialize an empty list to store the dictionaries
data_list = []

# Open and read the file
with open(file_path, "r", encoding="utf-8") as file:
    # Initialize an empty dictionary to temporarily store each item's data
    current_dict = {}
    query_lines = []  # List to accumulate query lines

    # Loop through each line in the file
    for line in file:
        # Remove leading/trailing whitespace
        line = line.strip()

        if line.startswith("- imagename:"):
            # Append any previous item to the list
            if current_dict:
                # Join the accumulated query lines
                current_dict["query"] = " ".join(query_lines).strip()
                data_list.append(current_dict)

            # Start a new dictionary for the next item
            current_dict = {"imagename": line.split(": ")[1].strip()}
            query_lines = []  # Reset the query lines list

        elif line.startswith("label:"):
            current_dict["label"] = line.split(": ")[1].strip().strip("'")

        elif line.startswith("query:"):
            # Start a new query line accumulation, remove "query:" part
            query_lines = [line.split(": ", 1)[1].strip()]

        else:
            # For continuation of a query across multiple lines
            query_lines.append(line)

    # Append the last dictionary if it exists
    if current_dict:
        current_dict["query"] = " ".join(query_lines).strip()
        data_list.append(current_dict)

# Print the list of dictionaries
print(data_list)

[{'imagename': '30.png', 'label': '50', 'query': 'What was the average ticket price for Nashville Predators games in the 2011/12 season?'}, {'imagename': '30_rot90.png', 'label': '50', 'query': 'What was the average ticket price for Nashville Predators games in the 2011/12 season?'}, {'imagename': '30.png', 'label': '25', 'query': 'How much did the ticket price increase from the 2005/06 season to the 2014/15 season?'}, {'imagename': '30_rot90.png', 'label': '25', 'query': 'How much did the ticket price increase from the 2005/06 season to the 2014/15 season?'}, {'imagename': '30.png', 'label': '2012/13', 'query': 'In which season did the average ticket price first reach 60 U.S. dollars?'}, {'imagename': '30_rot90.png', 'label': '2012/13', 'query': 'In which season did the average ticket price first reach 60 U.S. dollars?'}, {'imagename': '78.png', 'label': 'Singapore', 'query': 'Which country had the highest trade value with India in FY 2019?'}, {'imagename': '78_rot90.png', 'label': 'S

In [20]:
input_text = "Make a table that corresponds to this chart."
input_text = "Can you extract the data points from the chart in the image and convert them into a JSON format? : {"
input_text = "What was the average ticket price for Nashville Predators games in the 2011/12 season?"
# input_text = "convert the top of chart to a table"
for data in data_list:
    image_name = data["imagename"]
    label = data["label"]
    query = data["query"]
    inference_plot2(model, processor, query, data_path, image_name)

The average ticket price for Nashville Predators games in the 2011/12 season was around 50 US dollars.
The average ticket price for Nashville Predators games in the 2011/12 season was approximately 50 US dollars.
The average ticket price increased by approximately 10 US dollars from the 2005/06 season to the 2014/15 season.
The ticket price increased by approximately 15 dollars from the 2005/06 season to the 2014/15 season.
The average ticket price first reached 60 U.S. dollars in the 2013-2014 season.
The average ticket price first reached 60 U.S. dollars in the 2013/14 season.
Singapore
Singapore
Approximately 22,000 million U.S. dollars.
The approximate trade value with Indonesia in FY 2019 was 22,000 million U.S. dollars.
7
4 countries had a trade value below 10,000 million U.S. dollars in FY 2019: Myanmar, Philippines, Vietnam, and Cambodia.
2014
The number of employees first reached 4,500 in 2015.
The lowest number of employees was around 3,500 in 2009 and 2010.
Approximately 2,0

In [None]:
def inference_2(model, processor, input_text, data_path, file_name):
    # Encode chart figure and tokenize text
    data_path = "/home/mingi/data/vistext/data/test_images"

    ver_image = Image.open(os.path.join(data_path, "vertical", file_name)).convert(
        "RGB"
    )
    ver_inputs = processor(text=input_text, images=ver_image, return_tensors="pt")
    ver_prompt_length = ver_inputs["input_ids"].shape[1]
    ver_inputs = {k: v.to("cuda") for k, v in ver_inputs.items()}

    hor_image = Image.open(os.path.join(data_path, "horizontal", file_name)).convert(
        "RGB"
    )
    hor_inputs = processor(text=input_text, images=hor_image, return_tensors="pt")
    hor_prompt_length = hor_inputs["input_ids"].shape[1]
    hor_inputs = {k: v.to("cuda") for k, v in hor_inputs.items()}

    # Generate
    generate_ids = model.generate(**ver_inputs, num_beams=4, max_new_tokens=512)
    ver_output_text = processor.batch_decode(
        generate_ids[:, ver_prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print(ver_output_text)

    generate_ids = model.generate(**hor_inputs, num_beams=4, max_new_tokens=512)
    hor_output_text = processor.batch_decode(
        generate_ids[:, hor_prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print(hor_output_text)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(ver_image)
    axes[0].axis("off")
    axes[1].imshow(hor_image)
    axes[1].axis("off")

    plt.show()

In [None]:
file_names[11]

'7411.png'

In [6]:
data_path = "/home/mingi/data/vistext/data/test_images/vertical"
input_text = "What was the average ticket price for Nashville Predators games in the 2011/12 season?"
input_text = "Values in the chart: "
# input_text = "Make a table that corresponds to this chart."
file_names = os.listdir(data_path)
# input_text = "convert the top of chart to a table"
for i, file_name in enumerate(file_names):
    file_name = file_names[500]
    inference_2(model, processor, input_text, data_path, file_name)
    if i == 0:
        break

NameError: name 'inference_2' is not defined