In [None]:
import os
import torch
from PIL import Image
import pandas as pd
from tqdm import tqdm
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
from google.colab import drive, userdata
from concurrent.futures import ThreadPoolExecutor
from functools import partial

In [None]:
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

hf_token = userdata.get('HF_TOKEN')
model_name = "google/paligemma-3b-pt-224"
lora_checkpoint = "./qlora_output/checkpoint-7500"
dataset_dir = 'RISCM'
images_dir = os.path.join(dataset_dir, 'resized')
captions_path = os.path.join(dataset_dir, 'captions_cleaned.csv')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# 4-bit NF4 weight quantization with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# device_map="auto" places the quantized weights on the available device(s) at load time.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    token=hf_token
)

# Attach the fine-tuned LoRA adapter in inference-only mode.
model = PeftModel.from_pretrained(model, lora_checkpoint, is_trainable=False)
# NOTE: there is deliberately no `model.to(device)` here. The bitsandbytes-quantized
# model has already been placed by `device_map="auto"`, and Transformers documents
# `.to()` as unsupported for 4-bit/8-bit models, so the extra move was redundant at best.
model.eval()

# The processor (tokenizer + image preprocessing) comes from the base model repo.
processor = PaliGemmaProcessor.from_pretrained(model_name, token=hf_token)

In [None]:
# Running single inference

df = pd.read_csv(captions_path, usecols=['image', 'split', 'training_caption'])
df.columns = df.columns.str.strip()
test_df = df[df["split"].str.lower() == "test"].reset_index(drop=True)

# Position (within the test split) of the example to caption; change to inspect another row.
sample_index = 56
sample_row = test_df.iloc[sample_index]
image_file = sample_row['image'].strip()
image_path = os.path.join(images_dir, image_file)
prompt = "<image> <bos> Describe this image in detail:"

# Open inside a context manager so the file handle is released as soon as the
# RGB copy exists (same pattern as load_single_image in the batch cell).
with Image.open(image_path) as img:
    image = img.convert("RGB")

@torch.inference_mode()
def generate_caption(image, max_new_tokens=16):
    """Generate a caption for a single image using the module-level `prompt`.

    Args:
        image: Image handed to the processor's `images` argument (this notebook
            passes an RGB PIL image).
        max_new_tokens: Upper bound on the number of generated tokens. Defaults
            to 16, the value that was previously hard-coded.

    Returns:
        str: The decoded continuation with surrounding whitespace stripped.
    """
    inputs = processor(
        text=prompt,
        images=image,
        padding="max_length",
        return_tensors="pt",
        do_convert_rgb=True
    ).to(device, dtype=model.dtype)

    # Greedy decoding. Alternative settings kept for reference:
    #   num_beams=4, do_sample=True, top_p=0.9, top_k=50, temperature=0.8,
    #   early_stopping=True, length_penalty=1.0, no_repeat_ngram_size=2
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=1,
        do_sample=False
    )

    # The generated sequence starts with the prompt tokens. Drop them by position
    # instead of string-replacing the prompt text after decoding, which would also
    # delete the phrase if the model happened to generate it.
    prompt_len = inputs["input_ids"].shape[-1]
    caption = processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)[0]
    return caption.strip()

# Caption the sample image and show it next to its filename.
generated_caption = generate_caption(image)
for label, value in (("Image file:", image_file), ("Generated Caption:", generated_caption)):
    print(label, value)

In [None]:
# Running inference on the test set

# Tunables for the batch run.
batch_size = 32
fraction = 1.0  # share of the test split to sample
output_csv = os.path.join(dataset_dir, 'predictions.csv')

# Reload the captions table and normalise its column names.
df = pd.read_csv(captions_path, usecols=['image', 'split', 'training_caption'])
df.columns = df.columns.str.strip()

# Keep only test rows (case-insensitive match), then sample with a fixed seed.
is_test_row = df["split"].str.lower().eq("test")
test_df = (
    df.loc[is_test_row]
    .sample(frac=fraction, random_state=42)
    .reset_index(drop=True)
)
image_names = test_df['image'].to_list()

def load_single_image(fname, images_dir):
    """Load one image as RGB, reporting failure as ``(fname, None)``.

    Callers that filter on ``img is not None`` can then skip unreadable files
    instead of having the whole batch abort on the first bad one.

    Args:
        fname: Image filename as stored in the captions table; surrounding
            whitespace is stripped only when building the path.
        images_dir: Directory that contains the image files.

    Returns:
        tuple: ``(fname, image)`` with the unmodified ``fname`` and an RGB copy
        of the image, or ``(fname, None)`` when the file cannot be opened or
        decoded.
    """
    img_path = os.path.join(images_dir, fname.strip())
    try:
        with Image.open(img_path) as img:
            # convert() returns a new image, so it stays valid after the file closes.
            return fname, img.convert("RGB")
    except OSError as err:
        # OSError covers missing files; Pillow's unreadable-image errors derive from it too.
        print(f"Skipping unreadable image {img_path}: {err}")
        return fname, None

def load_images_parallel(batch_filenames, num_workers=4):
    """Load a batch of images concurrently with a thread pool.

    Each filename is passed to ``load_single_image`` together with the
    module-level ``images_dir``. Results whose image is ``None`` are dropped.

    Args:
        batch_filenames: Filenames to load, in the desired output order.
        num_workers: Maximum number of worker threads.

    Returns:
        A pair ``(filenames, images)`` of equally long tuples for the kept
        results, or ``([], [])`` when nothing was kept.
    """
    loader = partial(load_single_image, images_dir=images_dir)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        loaded = list(pool.map(loader, batch_filenames))

    kept_names, kept_images = [], []
    for name, picture in loaded:
        if picture is None:
            continue
        kept_names.append(name)
        kept_images.append(picture)

    if not kept_names:
        return [], []
    return tuple(kept_names), tuple(kept_images)

@torch.inference_mode()
def generate_captions(images, max_new_tokens=16):
    """Generate one caption per image in a batch.

    Args:
        images: Images handed to the processor's `images` argument (this
            notebook passes a list of RGB PIL images).
        max_new_tokens: Upper bound on the number of generated tokens per
            caption. Defaults to 16, the value that was previously hard-coded.

    Returns:
        list[str]: Decoded continuations, one per input image and in the same
        order, each with surrounding whitespace stripped.
    """
    # Every image gets the same instruction prompt.
    input_texts = ["<image> <bos> Describe this image in detail:" for _ in images]

    inputs = processor(
        text=input_texts,
        images=images,
        padding="max_length",
        return_tensors="pt",
        do_convert_rgb=True
    ).to(device, dtype=model.dtype)

    # Greedy decoding. Alternative settings kept for reference:
    #   num_beams=4, do_sample=True, top_p=0.9, top_k=50, temperature=0.8,
    #   early_stopping=True, length_penalty=1.0, no_repeat_ngram_size=2
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=1,
        do_sample=False
    )

    # Each generated row starts with the prompt tokens. Drop them by position
    # instead of string-replacing the prompt text after decoding, which would also
    # delete the phrase if the model happened to generate it.
    prompt_len = inputs["input_ids"].shape[-1]
    captions = processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
    return [caption.strip() for caption in captions]

# One slot per test image; a slot stays "" when no caption is produced for it.
generated_captions = ["" for _ in image_names]

batch_starts = range(0, len(image_names), batch_size)
for start in tqdm(batch_starts, desc="Processing images"):
    requested = image_names[start:start + batch_size]
    loaded_names, loaded_images = load_images_parallel(requested)

    # Nothing came back from the loader for this batch: leave its slots empty.
    if not loaded_images:
        continue

    batch_captions = generate_captions(list(loaded_images))

    # Map captions back by filename so each lands in the slot of the row that
    # requested it, even when some files in the batch were dropped by the loader.
    caption_for = dict(zip(loaded_names, batch_captions))
    for offset, name in enumerate(requested):
        if name in caption_for:
            generated_captions[start + offset] = caption_for[name]

# Attach the predictions to the test rows and persist them.
test_df["predicted_caption"] = generated_captions
test_df.to_csv(output_csv, index=False)

print(f"\nPredictions saved to: {output_csv}")