In [1]:
import sys 
sys.path.insert(0, '/home/ray/default')

import os
from util.utils import (
    generate_output_path,
    prompt_for_hugging_face_token
)
import ray
import os
from typing import Dict
import numpy as np

In [2]:
# Source dataset on the Hugging Face Hub:
# https://huggingface.co/datasets/DBQ/Burberry.Product.prices.United.States?row=0
HF_DATA = "DBQ/Burberry.Product.prices.United.States"

# Root of the artifact storage location; every remote path below hangs off it.
BASE_PATH = "s3://anyscale-customer-dataplane-data-production-us-east-2/artifact_storage/org_6687q89lgh27q3z41zesm2fsq6/cld_j25ipm5kli358v41pn9c96gjg3/BurberryData:john_:kpbdm"
IMG_PATH = f"{BASE_PATH}/images"
DATA_PATH = f"{BASE_PATH}/data"
CAPTION_PATH = f"{BASE_PATH}/captions/2"

# Local image directory (not referenced by the cells visible in this notebook).
IMG_PATH_TEST = "/home/ray/default/data/images"

In [3]:
# Hugging Face Hub id of the PaliGemma checkpoint used for captioning.
# NOTE(review): PaliGemmaPredictor carries its own copy of this id and does not
# read this constant, so changing it here alone has no effect on inference.
HF_MODEL = "google/paligemma-3b-mix-224"

#### Run Config
There are two modes, `test` and `prod`. `test` operates on only a small subset of the data (the first 10 images); note that `prod` is currently also capped, at the first 100 images.

In [4]:
from enum import Enum, IntEnum
from pydantic import BaseModel, ValidationError

class RunMode(str, Enum):
    """Execution mode of the notebook; members compare equal to their string value."""

    test = "test"
    prod = "prod"

In [5]:
# Active run mode; the image-loading cell compares this against RunMode.test
# to decide how many images to read.
mode = RunMode.prod

## Read Images

In [6]:
# Cap on the number of images pulled into the dataset: 10 in test mode, 100 otherwise.
if mode == RunMode.test:
    LIMIT = 10
else:
    LIMIT = 100

# include_paths=True keeps each image's source path alongside the pixel data.
img_data = ray.data.read_images(IMG_PATH, include_paths=True).limit(LIMIT)

2024-09-22 06:06:02,171	INFO worker.py:1596 -- Connecting to existing Ray cluster at address: 100.95.59.40:6379...
2024-09-22 06:06:02,178	INFO worker.py:1772 -- Connected to Ray cluster. View the dashboard at [1m[32mhttps://session-33eser3czbha2i3jm7g2t5am82.i.anyscaleuserdata.com [39m[22m
2024-09-22 06:06:02,180	INFO packaging.py:358 -- Pushing file package 'gcs://_ray_pkg_8c3a37d44ac6671b254b4766a2009901138346c5.zip' (0.04MiB) to Ray cluster...
2024-09-22 06:06:02,181	INFO packaging.py:371 -- Successfully pushed file package 'gcs://_ray_pkg_8c3a37d44ac6671b254b4766a2009901138346c5.zip'.


## Inference with PaliGemma

In [12]:
class PaliGemmaPredictor:
    """Callable that captions a batch of images with a PaliGemma checkpoint.

    Instances are plain callables taking a column-oriented batch (a dict of
    arrays), so the class can be called directly or handed to
    ``Dataset.map_batches``.

    Args:
        prompt: Text prompt paired with every image in the batch.
        image_col: Key of the batch column that holds the images.
        model_id: Hugging Face Hub id of the checkpoint to load.
        max_new_tokens: Upper bound on the number of tokens generated per image.
    """

    def __init__(
        self,
        prompt="caption en",
        image_col="image",
        model_id="google/paligemma-3b-mix-224",
        max_new_tokens=100,
    ):
        # Imported here rather than at module level so that transformers is
        # only needed in the process that actually constructs the predictor.
        from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        self.prompt = prompt
        self.image_col = image_col
        self.model_id = model_id
        self.max_new_tokens = max_new_tokens
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval()
        self.processor = AutoProcessor.from_pretrained(self.model_id)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        """Generate one caption per image in ``batch``.

        Args:
            batch: Mapping that contains the images under ``self.image_col``
                and the corresponding source paths under ``"path"``.

        Returns:
            A dict with ``"captions"`` (decoded text, one entry per image) and
            ``"path"`` (the input paths as a Python list, in the same order).
        """
        import torch

        images = list(batch[self.image_col])
        # The same prompt is used for every image in the batch.
        prompts = [self.prompt] * len(images)
        model_inputs = self.processor(text=prompts, images=images, return_tensors="pt")
        input_len = model_inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            # do_sample=False -> greedy decoding, so captions are deterministic.
            generation = self.model.generate(
                **model_inputs, max_new_tokens=self.max_new_tokens, do_sample=False
            )

        # Each generated row starts with the input_len prompt tokens; keep only
        # the tokens produced after them before decoding.
        new_tokens = generation[:, input_len:]
        decoded = self.processor.batch_decode(new_tokens, skip_special_tokens=True)

        return {
            "captions": decoded,
            "path": batch["path"].tolist(),
        }


In [8]:

# Materialize up to 10 rows of the image dataset as a single in-memory batch
# so the predictor can be tried out directly in the next cell.
batch = img_data.take_batch(10)

2024-09-22 06:06:46,882	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-09-21_20-17-17_494561_3381/logs/ray-data
2024-09-22 06:06:46,883	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ExpandPaths] -> TaskPoolMapOperator[ReadFiles] -> LimitOperator[limit=100] -> LimitOperator[limit=10]


- ExpandPaths 1: 0 bundle [00:00, ? bundle/s]

- ReadFiles 2: 0 bundle [00:00, ? bundle/s]

- limit=100 3: 0 bundle [00:00, ? bundle/s]

- limit=10 4: 0 bundle [00:00, ? bundle/s]

Running 0: 0 bundle [00:00, ? bundle/s]

In [13]:
# Smoke test: build a predictor with its default arguments in this process and
# caption the sample batch; the returned dict is shown as the cell output.
PaliGemmaPredictor()(batch)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

{'captions': ['a plaid bag with a zipper on it',
  'A tan jacket with a zipper on the front.',
  'A black boot with a plaid pattern on the inside.',
  'A beige and black and white cape with a fringe.',
  'A tan turtleneck shirt with a white logo on the front.',
  'A blue shirt with a long-sleeved design.',
  "a navy blue sweatshirt with a bear on it that says 'burberry'",
  'A pair of plaid shorts with a black string.',
  'A black sweater with a zipper on the front.',
  'A pillow with a picture of a horse on it.'],
 'path': ['anyscale-customer-dataplane-data-production-us-east-2/artifact_storage/org_6687q89lgh27q3z41zesm2fsq6/cld_j25ipm5kli358v41pn9c96gjg3/BurberryData:john_:kpbdm/images/2B7BA48A-64F4-464F-8568-E9A4C47968B1.png',
  'anyscale-customer-dataplane-data-production-us-east-2/artifact_storage/org_6687q89lgh27q3z41zesm2fsq6/cld_j25ipm5kli358v41pn9c96gjg3/BurberryData:john_:kpbdm/images/2BB5C565-03E1-41FE-BB61-B1A3D8C77AC0.png',
  'anyscale-customer-dataplane-data-production-us

In [17]:
# Caption the whole image dataset: passing the class (not an instance) lets
# Ray construct the predictor itself, using fn_constructor_kwargs, and feed it
# batches of 10 rows with a concurrency of 3.
ds = img_data.map_batches(
    PaliGemmaPredictor,
    fn_constructor_kwargs=dict(image_col="image"),
    batch_size=10,
    concurrency=3,
)

In [18]:
# Writing triggers execution of the pipeline; results land as Parquet files
# under CAPTION_PATH. try_create_dir=False skips the directory-creation attempt.
ds.write_parquet(CAPTION_PATH, try_create_dir=False)

2024-09-22 06:18:42,253	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-09-21_20-17-17_494561_3381/logs/ray-data
2024-09-22 06:18:42,253	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ExpandPaths] -> TaskPoolMapOperator[ReadFiles] -> LimitOperator[limit=100] -> ActorPoolMapOperator[MapBatches(PaliGemmaPredictor)] -> TaskPoolMapOperator[Write]


- ExpandPaths 1: 0 bundle [00:00, ? bundle/s]

- ReadFiles 2: 0 bundle [00:00, ? bundle/s]

- limit=100 3: 0 bundle [00:00, ? bundle/s]

- MapBatches(PaliGemmaPredictor) 4: 0 bundle [00:00, ? bundle/s]

- Write 5: 0 bundle [00:00, ? bundle/s]

Running 0: 0 bundle [00:00, ? bundle/s]