In [14]:
import base64
import csv
import json
import os
from io import BytesIO

import requests
from datasets import load_dataset, get_dataset_infos
from groq import Groq
from PIL import Image

In [31]:
# Get dataset information
# Dataset identifier: passed to the `datasets` helpers below and later written
# to the CSV output as the dataset_name column.
dataset_name = "keremberke/german-traffic-sign-detection"
# Look up the info records for this dataset. NOTE(review): `infos` is not
# referenced again in the cells visible here — confirm it is still needed.
infos = get_dataset_infos(dataset_name)

In [None]:
# Load dataset
# Loads the "test" split of the "full" configuration. Later cells index it as
# dataset[0] and read that example's "image_id" and "image" fields.
dataset = load_dataset(dataset_name, "full", split="test")

In [42]:
# Get image_id
# Only the first example (index 0) is processed; its id is stored next to the
# generated caption when the CSV is written.
image_id = dataset[0]["image_id"]

In [43]:
# Create base64 image
def encode_image(image: Image.Image) -> str:
    """Serialize a PIL image as PNG and return it as base64 text.

    Args:
        image: The PIL image to encode.

    Returns:
        The base64 encoding of the PNG bytes, as a ``str`` suitable for
        embedding in a ``data:image/png;base64,...`` URL for the Groq API.
    """
    with BytesIO() as png_buffer:
        # Render the image into memory rather than to a file on disk.
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return str(encoded, "utf-8")

# Encode the image of the first example (index 0, the same example whose
# image_id is recorded) so it can be embedded in the API request.
base64_image = encode_image(dataset[0]["image"])

In [15]:
# Groq API client. The key is taken from the GROQ_API_KEY environment variable
# (the lookup yields None when the variable is unset) rather than hardcoded.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))


In [16]:
# Ask the vision model to describe the image. The user message carries two
# content parts: the text question and the image, inlined as a base64 data URL.
text_part = {"type": "text", "text": "What's the content of the image?"}
image_part = {
    "type": "image_url",
    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
}
chat_completion = client.chat.completions.create(
    model="llama-3.2-11b-vision-preview",
    messages=[{"role": "user", "content": [text_part, image_part]}],
)

In [44]:
# Write image_captions.csv with the columns dataset_name, image_id, image_caption.
# The file is opened in "w" mode, so an existing file is replaced by the header
# plus a single data row.
#
# Rows go through csv.writer instead of string formatting: the caption is free
# text and may contain commas, quotes, or newlines, which would otherwise spill
# into extra columns or rows. newline="" lets the csv module control line
# endings, and the explicit UTF-8 encoding avoids platform-dependent defaults
# for non-ASCII caption text.
with open("image_captions.csv", "w", newline="", encoding="utf-8") as f:
    # csv's default row terminator is "\r\n"; keep the bare "\n" written before.
    writer = csv.writer(f, lineterminator="\n")
    writer.writerow(["dataset_name", "image_id", "image_caption"])
    writer.writerow([dataset_name, image_id, chat_completion.choices[0].message.content])
