In [32]:
# Install uv with pip, then use uv to install the third-party packages
# that the import cell of this notebook relies on.
!pip install uv
!uv pip install mistralai datasets python-dotenv pillow rich -q



In [33]:
import base64
from mistralai import Mistral
from dotenv import load_dotenv
from pathlib import Path
from PIL import Image
import os
from io import BytesIO
from datasets import load_dataset
from IPython.display import display, Image as IPythonImage
from rich import print

# Load settings via python-dotenv; presumably this is what makes
# MISTRAL_API_KEY available in os.environ from a local .env file.
load_dotenv()
# Alias for values that may be given as either a pathlib.Path or a plain string.
PathLike = Path | str

In [21]:
# Credentials come from the environment; a missing variable raises KeyError
api_key = os.environ["MISTRAL_API_KEY"]

# Client object used for the chat request further down
client = Mistral(api_key=api_key)

# Identifier of the model passed to the chat endpoint
model = "pixtral-12b-2409"

In [22]:


# Load the "train" split of the H&M fashion caption dataset.
# Later cells read an "image" and a "text" field from each record.
ds = load_dataset("tomytjandra/h-and-m-fashion-caption", split="train")

In [23]:
# Choose the first 1000 samples. The count is clamped to the dataset size
# so that a dataset with fewer rows does not make select() fail on
# out-of-range indices.
ds = ds.select(range(min(1000, len(ds))))


In [24]:
def encode_pil_to_base64(pil_image):
    """
    Convert a PIL Image to a base64-encoded JPEG data URL.

    Images whose mode JPEG cannot store (for example "RGBA", "LA" or "P")
    are converted to "RGB" before saving, so they no longer make the save
    fail. The conversion discards any alpha channel.

    Args:
        pil_image: PIL Image object

    Returns:
        str: the image as a "data:image/jpeg;base64,<payload>" string
    """
    # Modes the JPEG writer accepts unchanged; every other mode is converted.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert("RGB")

    # Serialize the image into an in-memory buffer as JPEG
    # (format and quality can be adjusted as needed)
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG", quality=95)

    # Base64-encode the raw bytes and wrap them in a data URL
    img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return f"data:image/jpeg;base64,{img_base64}"

In [42]:

def display_record(record):
    """
    Display a dataset record inline in a Jupyter notebook.

    The record's image is rendered as a JPEG, then its text is printed.

    Args:
        record: mapping with an "image" entry (PIL Image object) and a
            "text" entry (the text to print under the image)
    """
    pil_image = record["image"]

    # JPEG cannot store every PIL mode (e.g. "RGBA" or "P"); convert those
    # to RGB so that save() does not fail on them.
    if pil_image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        pil_image = pil_image.convert("RGB")

    # Serialize the image into an in-memory buffer
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")

    # Show the image inline, then the accompanying text
    display(IPythonImage(data=buffered.getvalue()))
    print(record["text"])

In [41]:
def prepare_messages(record):
    """
    Build the list of chat messages sent to the model for one record.

    Args:
        record: mapping with an "image" entry (PIL Image object) and a
            "text" entry

    Returns:
        list: a system message describing the expected JSON layout,
        followed by three user messages: the task instruction, the image
        as a base64 data URL, and the record's text.
    """
    # System prompt: fixes the JSON structure the answer must follow
    json_layout_prompt = (
        "Return the answer in a JSON object with the next structure: "
        '{"attributes": [{"attribute": "some name of attribute1", '
        '"values": ["some value of attribute 1", "some value of attribute 1"]}, '
        '{"attribute": "some name of attribute2", "values": '
        '["some value of attribute 2", "some value of attribute 2"]}]}'
    )

    # Task instruction sent as the first user message
    task_prompt = "Describe the image and text for an e-commerce catalog. Include all the attributes and values that are present in the image. For instance, collar, sleeves, etc."

    # The image travels as a single content chunk holding a data URL
    image_chunk = {
        "type": "image_url",
        "image_url": encode_pil_to_base64(record["image"]),
    }

    return [
        {"role": "system", "content": json_layout_prompt},
        {"role": "user", "content": task_prompt},
        {"role": "user", "content": [image_chunk]},
        {"role": "user", "content": record["text"]},
    ]


# Pick one sample from the dataset and show it inline
record = ds[37]
display_record(record)

# Send the prepared messages for this sample to the chat endpoint
chat_response = client.chat.complete(
    messages=prepare_messages(record),
    model=model,
)

# Show the text of the first returned choice
first_choice = chat_response.choices[0]
print(first_choice.message.content)