In [None]:
import warnings
from tqdm import TqdmWarning
warnings.filterwarnings("ignore", category=TqdmWarning)

import torch
from paligemma_preprocessor import PaliGemmaPreprocessor
from utils import load_huggingface_weights_into_model, inference

In [None]:
# parameters — every value a reader might want to tune lives in this cell.
MODEL_PATH = "paligemma-3b-pt-224"  # checkpoint location handed to load_huggingface_weights_into_model
IMAGE_PATH = "images/car.jpg"       # image to caption
PROMPT = "caption es"               # text prompt; NOTE(review): "es" presumably selects a Spanish caption — confirm
MAX_TOKENS = 100                    # generation budget passed to inference()
DO_SAMPLE = True                    # stochastic decoding -> output depends on torch's RNG state
TEMPERATURE = 0.8
TOP_P = 0.9
SEED = 42                           # fixes the RNG so sampled captions are reproducible

# With DO_SAMPLE=True the generated caption depends on torch's random state.
# Seed it here so "Restart Kernel -> Run All" reproduces the same output.
# torch.manual_seed also seeds CUDA generators. Bit-exact GPU results may
# additionally require deterministic kernels; see the PyTorch reproducibility notes.
# Re-run this cell to reset the RNG before generating again.
torch.manual_seed(SEED)

In [None]:
# Pick the compute device: use the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f'Using device: {DEVICE}')

# Load the checkpoint at MODEL_PATH onto DEVICE; the helper returns the model
# together with its matching tokenizer.
model, tokenizer = load_huggingface_weights_into_model(MODEL_PATH, DEVICE)
# Switch to evaluation mode (NOTE(review): assumes `model` is a torch nn.Module,
# where this disables training-only behaviour such as dropout).
model.eval()

# Build the preprocessor from the model's own config. The image size and
# image-token count then always agree with the loaded checkpoint.
preprocessor = PaliGemmaPreprocessor(
    tokenizer=tokenizer,
    num_image_tokens=model.config.num_image_tokens,
    image_size=model.config.vision_config.image_size,
)

In [None]:
# Generate the caption with autograd turned off: no gradients are needed at
# inference time, and skipping graph construction saves memory.
with torch.no_grad():
    inference(
        # what runs, and where
        model=model,
        preprocessor=preprocessor,
        device=DEVICE,
        # the (image, prompt) input pair
        image_path=IMAGE_PATH,
        prompt=PROMPT,
        # decoding settings from the parameters cell
        max_tokens=MAX_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )