In [18]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import pandas as pd
import torch
import mlflow
import numpy as np

In [8]:
# Send all MLflow tracking calls to the server listening on localhost, port 5000.
mlflow.set_tracking_uri("http://localhost:5000")

In [14]:
# Smoke-test the tracking server: activate a throwaway experiment, log two
# dummy metrics to it, then delete the experiment again.
CHECK_EXPERIMENT_NAME = "check-connection"

# A previous run of this cell leaves the experiment soft-deleted, and
# `mlflow.set_experiment` refuses to activate a deleted experiment. Restore it
# first so the cell can be re-run without manual cleanup.
client = mlflow.MlflowClient()
previous = client.get_experiment_by_name(CHECK_EXPERIMENT_NAME)
if previous is not None and previous.lifecycle_stage == "deleted":
    client.restore_experiment(previous.experiment_id)

exp = mlflow.set_experiment(CHECK_EXPERIMENT_NAME)

try:
    with mlflow.start_run():
        mlflow.log_metric("foo", 1)
        mlflow.log_metric("bar", 2)
finally:
    # Delete the experiment even when logging fails, so a failed check does
    # not leave a stray experiment on the server.
    mlflow.delete_experiment(exp.experiment_id)

2024/09/16 20:32:45 INFO mlflow.tracking.fluent: Experiment with name 'check-connection' does not exist. Creating a new experiment.
2024/09/16 20:32:45 INFO mlflow.tracking._tracking_service.client: 🏃 View run sneaky-fawn-598 at: http://localhost:5000/#/experiments/588219251072499462/runs/e62831d71d4041019de56e8a38be633e.
2024/09/16 20:32:45 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5000/#/experiments/588219251072499462.


In [2]:
# Checkpoint identifier handed to `from_pretrained` for both the model and its processor.
model_id = "google/paligemma-3b-mix-224"


In [4]:
# Load the pretrained weights, then put the module into evaluation mode.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model.eval()

# Processor for the same checkpoint; it turns a prompt plus image into model inputs
# and decodes generated token ids back into text.
processor = AutoProcessor.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [16]:
# Load the train/test CSVs and draw a random subset of up to SAMPLE_SIZE rows
# from each.
SAMPLE_SIZE = 1000
SAMPLE_SEED = 42  # fixed seed so the same rows are drawn on every run

train_df = pd.read_csv("../dataset/train.csv")
test_df = pd.read_csv("../dataset/test.csv")

# Cap the request at the frame length: sampling without replacement raises
# when asked for more rows than the frame holds.
train_sample = train_df.sample(min(SAMPLE_SIZE, len(train_df)), random_state=SAMPLE_SEED)
test_sample = test_df.sample(min(SAMPLE_SIZE, len(test_df)), random_state=SAMPLE_SEED)

In [19]:
def crop_and_resize(image, target_size):
    """Resize ``image`` to ``target_size`` with Lanczos resampling.

    Despite the name, nothing is cropped: the image is resized directly to
    ``target_size``, so its aspect ratio changes whenever the target's differs.

    Args:
        image: Object exposing a PIL-style ``resize(size, resample)`` method.
        target_size: Size forwarded unchanged to ``image.resize``.

    Returns:
        The value returned by ``image.resize``.
    """
    resample_filter = Image.Resampling.LANCZOS
    resized = image.resize(target_size, resample_filter)
    return resized

def read_image(url, target_size, timeout=30):
    """Download an image and return it as a resized RGB NumPy array.

    Args:
        url: Location of the image to fetch with ``requests.get``.
        target_size: Size forwarded to ``crop_and_resize``.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout a stalled connection would block the caller indefinitely.

    Returns:
        ``numpy.ndarray`` built by ``np.array`` from the resized RGB image.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status (via ``raise_for_status``).
    """
    # The context manager releases the connection once the image is decoded.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # error page returned in place of the image.
        response.raise_for_status()
        # `raw` bypasses requests' transparent gzip/deflate decoding unless asked.
        response.raw.decode_content = True
        # convert() forces the pixel data to be read while the stream is open.
        image = Image.open(response.raw).convert('RGB')
    image = crop_and_resize(image, target_size)
    return np.array(image)

In [21]:
# Empty accumulator with the two output columns; rows are appended by the inference loop.
result = pd.DataFrame({column: [] for column in ("index", "prediction")})

In [None]:
# Run the model over the sampled test rows and append one prediction per row
# to `result`.
MAX_ITEMS = 5000  # upper bound on the number of rows processed in one pass
WEIGHT_ENTITIES = {"item_weight", "maximum_weight_recommendation"}

# Predictions are collected as plain records and concatenated once at the end;
# growing the frame with `pd.concat` on every iteration copies it each time.
predictions = []
try:
    for (idx, (index, url, _id, ent_type)) in enumerate(test_sample.values[:MAX_ITEMS]):
        print(f"{idx}: {url}")
        image = read_image(url, (224, 224))

        # Both weight-style entities are queried with the same "net weight" wording.
        if ent_type in WEIGHT_ENTITIES:
            ent_type = "net weight"

        prompt = f'answer en What is the item {ent_type}?\n'

        model_inputs = processor(text=prompt, images=image, return_tensors="pt")
        input_len = model_inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
            # Drop the first `input_len` tokens (the prompt) and decode the rest.
            generation = generation[0][input_len:]
            output = processor.decode(generation, skip_special_tokens=True)
            print(output)

        # Key the prediction by the row's own `index` value, not the loop
        # position: `test_sample` is a random sample, so the two do not line up.
        # NOTE(review): assumes the first CSV column is the row identifier
        # expected downstream — confirm against the dataset schema.
        predictions.append({"index": index, "prediction": output.split("\n")[0]})
finally:
    # Keep whatever was predicted so far, even if the loop is interrupted or a
    # download fails part-way through.
    if predictions:
        result = pd.concat([result, pd.DataFrame(predictions)], ignore_index=True)
