# Chart Understanding with GPT-4o and Tool Use (aka Structured Generation)

NOTE: NOT ELIGIBLE FOR USE IN THE CHALLENGE

In [2]:
!pip install -q openai

In [3]:
import os
import json
import base64

import pandas as pd

In [4]:
from openai import OpenAI

# Read the API key from the environment rather than embedding it in the
# notebook, so the secret is never saved or shared along with the file.
# Raises KeyError if OPENAI_API_KEY is not set.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

In [9]:
def convert_image_to_base64(image_path: str):
    """Return the contents of the file at ``image_path`` as a base64 string.

    The file is read in binary mode, base64-encoded, and the encoded bytes
    are decoded as UTF-8 so the result can be embedded in text.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")

In [15]:
def infer_chart_type(image_path: str) -> str:
    """Guess the chart type from a marker embedded in the image file name.

    The path is split on "/" and only the last component is inspected.
    Markers are checked in a fixed order and the first one found wins;
    if none is present, "Other Chart Type" is returned. The split path is
    printed for debugging.
    """
    image_path_parts = image_path.split("/")
    print(f"{image_path_parts = }")
    file_name = image_path_parts[-1]

    # Ordered (marker, label) pairs; order matters when several markers occur.
    marker_labels = (
        ("_vbar_", "Vertical Bar Chart"),
        ("_hbar_", "Horizontal Bar Chart"),
        ("_line_", "Line Chart"),
        ("_sct_", "Scatter Plot"),
        ("_pie_", "Pie Chart"),
    )
    for marker, label in marker_labels:
        if marker in file_name:
            return label
    return "Other Chart Type"

In [10]:
def build_tool():
    """Build the function-tool definition used for chart extraction.

    Returns a new dict on every call describing one function tool named
    ``chart_information_extraction_tool``. Its JSON-schema parameters are
    ``categories`` and ``groups`` (arrays of strings) plus ``reasoning`` and
    ``answer`` (strings); all four are listed as required.
    """

    def string_array(description):
        # Schema fragment for a list of strings with a descriptive hint.
        return {
            "type": "array",
            "items": {"type": "string"},
            "description": description,
        }

    properties = {
        "categories": string_array(
            "Categories in the chart. They are on the X-axis in vertical bar charts, line charts, and scatter plots. But they are on the Y-axis in horizontal bar charts."
        ),
        "groups": string_array(
            "Groups of data having the same color in the chart. Can be found in the legend."
        ),
        "reasoning": {"type": "string"},
        "answer": {
            "type": "string",
            "description": "Concise answer to the user question.",
        },
    }
    parameters = {
        "type": "object",
        "properties": properties,
        "required": ["categories", "groups", "reasoning", "answer"],
    }
    function_spec = {
        "name": "chart_information_extraction_tool",
        "description": "Extract information from a chart",
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}

In [11]:
# System message sent with every chat request; split across lines for
# readability, the pieces concatenate into a single string.
SYSTEM_PROMPT = (
    "You are a scientific chart explainer. "
    "You will receive an image with a chart in it as an input and you must "
    "answer the user's question based on the data from the chart. "
    "Be concise and don't yap. "
    "If you are uncertain, provide a range of possible answers. "
    "Output in json format."
)

In [20]:
def run_inference(image_path, question, model="gpt-4o-2024-05-13", seed=0):
    """Ask the model a question about a chart image and return its answer.

    The image is sent inline as a base64 data URL together with the chart
    type inferred from the file name. The model is forced (via
    ``tool_choice``) to call the tool from ``build_tool()``, and the
    ``answer`` field of the tool-call arguments is returned.

    Args:
        image_path: Path to the chart image; sent with an ``image/png``
            data-URL prefix.
        question: The user question. A leading ``"<image>\\n"`` placeholder,
            if present, is removed before the question is sent.
        model: Chat-completions model name.
        seed: Seed forwarded to the API request.

    Returns:
        The value of ``"answer"`` in the JSON-decoded tool-call arguments.

    Raises:
        ValueError: If the response contains no tool call.
        json.JSONDecodeError: If the tool-call arguments are not valid JSON.
        KeyError: If the decoded arguments have no ``"answer"`` key.
    """
    image_base64 = convert_image_to_base64(image_path)

    # Strip the image placeholder only when it is actually there; slicing
    # unconditionally would cut the start off questions that lack it.
    image_placeholder = "<image>\n"
    if question.startswith(image_placeholder):
        question_trimmed = question[len(image_placeholder):]
    else:
        question_trimmed = question

    chart_type = infer_chart_type(image_path)
    print(f"{chart_type = } | {question_trimmed = }")
    tool = build_tool()
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT,
                    }
                ]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                    {
                        "type": "text",
                        # Send the trimmed question: the image is already
                        # attached above, so the textual placeholder is noise.
                        "text": f"This is a {chart_type}. {question_trimmed}",
                    }
                ]
            },
        ],
        temperature=1,
        seed=seed,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        tools=[tool],
        # Force the model to answer through the extraction tool.
        tool_choice={
            "type": "function",
            "function": {"name": tool["function"]["name"]},
        },
    )
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        # Give a clear error instead of "'NoneType' is not subscriptable".
        raise ValueError("Model response did not contain a tool call")
    response_args = tool_calls[0].function.arguments
    print(f"{response_args = }")
    return json.loads(response_args)["answer"]

In [21]:
# Example chart image passed to run_inference for a single trial call.
image_path = "data/raw_datasets/mychart/images/133_vbar_01ac38ed852061cf4eac77cd85c402bb6c8c9dc26c5fd2071a12aeae98ec84a3_81.png"

In [None]:
# Try the pipeline on one image/question pair before running the full dataset.
run_inference(image_path, "<image>\nHow many deferred trust units were granted during the year?")

In [None]:
# Locate and load the annotation file (questions without answers).
dataset_name = "mychart"
dataset_path = os.path.join("data/raw_datasets", dataset_name, "annot_wo_answer.json")
print(dataset_path)
# Validate with an explicit exception rather than `assert`, which is removed
# when Python runs with -O and gives no hint about what went wrong.
if not os.path.exists(dataset_path):
    raise FileNotFoundError(f"Annotation file not found: {dataset_path}")

df_data = pd.read_json(dataset_path)

In [24]:
!mkdir -p inference_results/gpt-4o

In [None]:
# Run inference for every dataset row that has no saved answer yet and write
# each answer to its own text file. Rows that raise are recorded in
# failed_idx (by DataFrame index label) and the loop carries on.

# Make sure the output directory exists: otherwise each API call would be
# made and its answer then lost when open() fails inside the try block.
output_dir = "inference_results/gpt-4o"
os.makedirs(output_dir, exist_ok=True)

failed_idx = set()
for idx, row in df_data.iterrows():
    # Named sample_id (not `id`) to avoid shadowing the builtin.
    sample_id = row["id"]
    answer_txt_path = f"{output_dir}/{sample_id}.txt"
    if os.path.exists(answer_txt_path):
        # Already answered in a previous run; skip so the loop is resumable.
        continue
    print(sample_id, idx)

    image_path = f"data/raw_datasets/{dataset_name}/images/{row['image']}"
    question = row["conversations"][0]["value"]

    try:
        answer = run_inference(image_path, question)
        with open(answer_txt_path, "w") as f:
            f.write(answer)
    except Exception as e:
        # Deliberate best-effort: log the failure and remember it for retry.
        print(idx, row, e)
        failed_idx.add(idx)

In [None]:
# The inference loop already creates failed_idx as a set; set() here rebinds
# the name to a copy with the same elements.
failed_idx = set(failed_idx)
# Show the collected index labels.
failed_idx

In [81]:
# Add to failed_idx every row whose answer file is missing or whose saved
# answer is too long to count as a usable answer.
# NOTE(review): the 50-character cutoff presumably reflects the "concise
# answer" requirement in the prompt — confirm the intended threshold.
max_answer_chars = 50
for idx, row in df_data.iterrows():
    # Named sample_id (not `id`) to avoid shadowing the builtin.
    sample_id = row["id"]
    answer_txt_path = f"inference_results/gpt-4o/{sample_id}.txt"
    if not os.path.exists(answer_txt_path):
        failed_idx.add(idx)
        continue

    with open(answer_txt_path, "r") as f:
        answer = f.read()

    if len(answer) >= max_answer_chars:
        failed_idx.add(idx)

In [None]:
# Show the index labels currently marked as failed.
failed_idx

In [None]:
# failed_idx holds index labels taken from df_data.iterrows(), so select by
# label with .loc (.iloc is positional and only coincides with the labels for
# a default RangeIndex). Sorting gives a stable, readable row order.
df_data.loc[sorted(failed_idx)]

In [None]:
# Retry every row in failed_idx with a different seed. Any existing answer
# file is deleted first, so a row whose retry also fails ends up with no
# answer file. Rows that fail again are collected in failed_idx_2.

# Make sure the output directory exists so a successful API call is not
# wasted by open() failing inside the try block.
os.makedirs("inference_results/gpt-4o", exist_ok=True)

failed_idx_2 = set()
# failed_idx holds index labels from iterrows(), so select with .loc rather
# than the positional .iloc; sorted() makes the retry order deterministic.
for idx, row in df_data.loc[sorted(failed_idx)].iterrows():
    # Named sample_id (not `id`) to avoid shadowing the builtin.
    sample_id = row["id"]
    answer_txt_path = f"inference_results/gpt-4o/{sample_id}.txt"
    if os.path.exists(answer_txt_path):
        os.remove(answer_txt_path)
    print(sample_id, idx)

    image_path = f"data/raw_datasets/{dataset_name}/images/{row['image']}"
    question = row["conversations"][0]["value"]

    try:
        answer = run_inference(image_path, question, seed=42)
        with open(answer_txt_path, "w") as f:
            f.write(answer)
    except Exception as e:
        # Deliberate best-effort: log the failure and keep going.
        print(idx, row, e)
        failed_idx_2.add(idx)

In [None]:
# Show the index labels that failed again on retry.
failed_idx_2