In [None]:
from datasets import load_dataset
import pandas as pd

# Load the dataset; the result is indexed by split name below
dataset = load_dataset("nielsr/countbench")

# Use the 'train' split
train_data = dataset["train"]

# Convert every example of the split to a Pandas DataFrame
# (the later cells run on the whole frame, not just a sample)
df = pd.DataFrame(train_data)

# Preview only the first 10 rows instead of dumping the entire frame
print(df.head(10))


In [None]:
import openai
import base64
from io import BytesIO
import pandas as pd

import os

# Prefer the OPENAI_API_KEY environment variable so the secret never has to be
# written into the notebook. The placeholder is kept as a fallback so the
# previous "edit this string" workflow still works.
API_KEY = os.environ.get("OPENAI_API_KEY", "INSERT API KEY")

client = openai.OpenAI(api_key=API_KEY)

# Encode a PIL image as a base64 PNG data URL wrapped in an image_url dict
def pil_to_base64_image_url(image: "PIL.Image.Image"):
    """Serialize ``image`` to PNG and return it as a data-URL payload.

    Args:
        image: Object exposing ``save(file_obj, format=...)``.

    Returns:
        A dict with ``"url"`` (a ``data:image/png;base64,...`` string built
        from the saved PNG bytes) and ``"detail"`` set to ``"high"``.
    """
    png_stream = BytesIO()
    image.save(png_stream, format="PNG")
    png_bytes = png_stream.getvalue()
    encoded = str(base64.b64encode(png_bytes), "utf-8")
    payload = {}
    payload["url"] = "data:image/png;base64," + encoded
    payload["detail"] = "high"
    return payload

# Query GPT-4o with image+caption
def get_named_object_from_image(image, caption=None, model="gpt-4o"):
    """Ask a vision chat model to name the main (repeated) object in an image.

    Args:
        image: Image accepted by ``pil_to_base64_image_url``. ``None`` means
            there is nothing to send.
        caption: Optional caption. When truthy, a "The caption says: ..."
            text part is placed before the instruction text.
        model: Model name forwarded to ``client.chat.completions.create``.

    Returns:
        The stripped text of the first returned choice, ``"IMAGE_MISSING"``
        when ``image`` is ``None``, or ``"ERROR"`` when any step of building
        or sending the request raises.
    """
    if image is None:
        return "IMAGE_MISSING"

    instruction = (
        "You will be shown an image. Identify the main object(s) in the image. "
        "Respond with just the object name(s) in lowercase, one or two words max. "
        "If there are multiple objects (like two kids), just say 'kids'. Don't list more than one object, only the object that is repeating. "
        "Avoid sentences and do not include counts."
    )

    try:
        # Order of the user content parts: caption (if any), instruction, image.
        user_content = []
        if caption:
            user_content.append({"type": "text", "text": f"The caption says: {caption}"})
        user_content.append({"type": "text", "text": instruction})
        user_content.append({"type": "image_url", "image_url": pil_to_base64_image_url(image)})

        messages = [
            {"role": "system", "content": "You are a helpful visual assistant."},
            {"role": "user", "content": user_content},
        ]

        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=50,
            temperature=0,
            top_p=1,
        )

        return response.choices[0].message.content.strip()

    except Exception as e:
        # Best-effort: one failing row must not abort the whole batch, but the
        # cause is reported so "ERROR" rows can be diagnosed instead of guessed at.
        print(f"get_named_object_from_image failed: {type(e).__name__}: {e}")
        return "ERROR"

# Run the object-naming query over every row of a DataFrame
def name_objects_in_dataframe(df, model="gpt-4o"):
    """Add a ``gpt_named_object`` column holding one model answer per row.

    Each row's ``"image"`` value (and its ``"text"`` value, when the row has
    one) is passed to ``get_named_object_from_image``. Progress is printed per
    row. The column is written onto ``df`` itself, which is also returned.
    """
    def _name_one(row_label, record):
        print(f"\n Processing row {row_label}")
        return get_named_object_from_image(
            record["image"],
            caption=record.get("text", None),
            model=model,
        )

    # All rows are processed before the column is attached.
    answers = [_name_one(row_label, record) for row_label, record in df.iterrows()]
    df["gpt_named_object"] = answers
    return df


In [None]:
# Name the main object for every row; this sends one chat-completion request per row.
df = name_objects_in_dataframe(df, model="gpt-4o")

In [None]:
# Drop rows whose "name" is a sentinel rather than an object:
# "IMAGE_MISSING" (no image) and "ERROR" (the request raised). Keeping "ERROR"
# would yield prompts such as "How many ERROR are in this image?".
df = df[~df["gpt_named_object"].isin(["IMAGE_MISSING", "ERROR"])]

In [None]:
# Drop refusals ("I'm sorry, ..." / "Sorry, ..."): match "sorry" as a literal,
# case-insensitive substring; na=False keeps the mask strictly boolean so `~` is safe.
df = df[~df["gpt_named_object"].str.contains("sorry", case=False, regex=False, na=False)]

In [None]:
# Write the filtered responses to CSV without the index column.
# NOTE(review): df presumably still carries the `image` column here, which a CSV
# can only hold as text — confirm that is intended.
df.to_csv("gpt_responses_cleaned.csv", index=False)

In [None]:
import numbers

# df comes from boolean filtering in earlier cells; take an independent copy so
# adding columns does not trigger pandas' SettingWithCopyWarning.
df = df.copy()

# Prompt 1: ask for the count of the named object
df["prompt_count"] = df["gpt_named_object"].apply(
    lambda obj: f"How many {obj} are in this image?"
)

# Prompts 2–6: Describe the (number + i) <objects>, for i = 1..5.
# numbers.Integral accepts both Python ints and NumPy integer scalars; any
# other value falls back to the prompt without a number.
for i in range(1, 6):
    df[f"prompt_describe_plus_{i}"] = df.apply(
        lambda row: f"Describe the {row['number'] + i} {row['gpt_named_object']} in this image."
        if isinstance(row['number'], numbers.Integral)
        else f"Describe the {row['gpt_named_object']} in this image.",
        axis=1
    )

# Show relevant columns
columns_to_show = ["gpt_named_object", "number", "prompt_count"] + [f"prompt_describe_plus_{i}" for i in range(1, 6)]
print(df[columns_to_show])

In [None]:
# train_data is a Hugging Face Dataset: it cannot be indexed with a list of
# column names or passed to DataFrame.merge, so convert it to a DataFrame first
# (same conversion as used when df was first built).
df_hf = pd.DataFrame(train_data)

# Merge on shared key: we assume 'image_url' is present in both frames
merged_df = df.drop(columns=["image"], errors="ignore").merge(
    df_hf[["image_url", "image"]], on="image_url", how="left"
)

In [None]:

# `os` was only imported in a later cell; import it here so this cell runs on a fresh kernel.
import os

# Save images and create 'path' column
os.makedirs("countbench_images", exist_ok=True)

def save_image(row, folder="countbench_images"):
    """Save one row's image as a zero-padded PNG file and return its path.

    Args:
        row: A DataFrame row (as passed by ``DataFrame.apply(axis=1)``) whose
            ``"image"`` entry exposes ``.save(path)``. ``row.name`` is used
            for the file name and must format with ``:04d``.
        folder: Directory the file is written into.

    Returns:
        The path of the written file, or ``None`` if anything failed
        (the failure is printed with the row label).
    """
    try:
        img = row["image"]
        filename = f"{row.name:04d}.png"
        path = os.path.join(folder, filename)
        img.save(path)
        return path
    except Exception as e:
        # Best-effort: report and continue so one bad image does not stop the run.
        print(f"Failed to save image at index {row.name}: {e}")
        return None

merged_df["path"] = merged_df.apply(save_image, axis=1)


In [None]:
import os

# Count the image files (by extension, case-insensitive) in the output folder
image_folder = "countbench_images"
num_images = sum(
    entry.lower().endswith(('.png', '.jpg', '.jpeg'))
    for entry in os.listdir(image_folder)
)

print(f"Number of images in '{image_folder}': {num_images}")


In [None]:
import numbers

# Extra "describe" prompts with larger offsets (number + 10 / 20 / 50).
# numbers.Integral accepts both Python ints and NumPy integer scalars; any
# other value falls back to the prompt without a number.
for i in [10, 20, 50]:
    merged_df[f"prompt_describe_plus_{i}"] = merged_df.apply(
        lambda row: f"Describe the {row['number'] + i} {row['gpt_named_object']} in this image."
        if isinstance(row['number'], numbers.Integral)
        else f"Describe the {row['gpt_named_object']} in this image.",
        axis=1
    )



In [None]:
# Write merged_df to CSV without the index column.
merged_df.to_csv("counting_with_prompts.csv", index=False)

In [6]:
# Names of the "describe" prompt columns for offsets 1-5, 10, 20 and 50
[f"prompt_describe_plus_{i}" for i in [1, 2, 3, 4, 5, 10, 20, 50]]

['prompt_describe_plus_1',
 'prompt_describe_plus_2',
 'prompt_describe_plus_3',
 'prompt_describe_plus_4',
 'prompt_describe_plus_5',
 'prompt_describe_plus_10',
 'prompt_describe_plus_20',
 'prompt_describe_plus_50']