In [1]:
import fiftyone as fo 
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset(
    "quickstart",
    max_samples=2,
    dataset_name="smol_start", 
    shuffle=True,
    overwrite=True
)

dataset.compute_metadata()

Overwriting existing directory '/home/harpreet/fiftyone/quickstart'
Downloading dataset to '/home/harpreet/fiftyone/quickstart'
Downloading dataset...
 100% |████|  187.5Mb/187.5Mb [488.3ms elapsed, 0s remaining, 384.0Mb/s]      
Extracting dataset...
Parsing dataset metadata
Found 200 samples
Dataset info written to '/home/harpreet/fiftyone/quickstart/info.json'
Loading existing dataset 'smol_start'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


In [59]:
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

model_id = "google/paligemma2-3b-mix-224"

url = dataset.first().filepath

image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id, use_fast=True)

prompt = """<image> segment hen; grass; baby chick\n"""
model_inputs = processor(
    text=prompt, 
    images=image, 
    padding="longest",
    return_tensors="pt").to(torch.bfloat16).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=3092)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

<loc0001><loc0000><loc1010><loc1020><seg010><seg090><seg090><seg000><seg054><seg082><seg082><seg027><seg035><seg082><seg082><seg027><seg023><seg090><seg082><seg010>


<loc0001><loc0243><loc0963><loc0666><seg074><seg074><seg069><seg038><seg104><seg056><seg023><seg030><seg070><seg099><seg099><seg073><seg014><seg087><seg019><seg118>

In [43]:
def create_prompt(task: str, prompt: str | list | None = None) -> str:
    """Create a formatted prompt string for PaliGemma2 vision-language tasks.
    
    Args:
        task: The vision task to perform. Must be one of:
            - "cap", "caption", "describe": Captioning with different detail levels
            - "ocr": Optical character recognition
            - "answer", "question": Visual QA tasks
            - "detect": Object detection
            - "segment": Instance segmentation
        prompt: Main task input. Could be:
            - Question for "answer" task
            - Answer for "question" task
            - Objects for "detect" task (string or list, joined with " ; ")
            - Object for "segment" task
    
    Returns:
        str: Formatted prompt string ready for model input
        
    Examples:
        >>> create_prompt("detect", ["car", "person"])
        '<image> detect car ; person'
        >>> create_prompt("answer", "What color is the car?")
        '<image> answer en What color is the car?'
        >>> create_prompt("ocr")
        '<image> ocr'
    """
    # Handle OCR as special case with no parameters
    if task == "ocr":
        return "<image> ocr"
        
    # Process list inputs for detection/segmentation
    if isinstance(prompt, (list, tuple)):
        prompt = " ; ".join(str(p) for p in prompt)
    
    # Build task-specific prompt
    if task in ["cap", "caption", "describe"]:
        return f"<image> {task} en"
    elif task in ["answer", "question"]:
        return f"<image> {task} en {prompt}"
    elif task in ["detect", "segment"]:
        return f"<image> {task} {prompt}"
    else:
        raise ValueError(f"Unknown task: {task}")

In [45]:
create_prompt(task="caption")

'<image> caption en'

In [49]:
create_prompt(task="answer", prompt="What color is the car?")

'<image> answer en What color is the car?'

In [52]:
create_prompt(task="detect", prompt="car; boat; house")

'<image> detect car; boat; house'

In [53]:
create_prompt(task="detect", prompt=["car", "boat", "house"])

'<image> detect car ; boat ; house'

In [54]:
create_prompt(task="segment", prompt=["car", "boat", "house"])

'<image> segment car ; boat ; house'