# In [None]:
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

# Dataset layout: images live in <dataset_dir>/resized, captions and the
# generated predictions are CSV files directly under <dataset_dir>.
dataset_dir = 'RISCM'
images_dir = os.path.join(dataset_dir, 'resized')
captions_path = os.path.join(dataset_dir, 'captions.csv')
output_csv = os.path.join(dataset_dir, 'paligemma_predictions_test.csv')
model_name = "google/paligemma-3b-pt-224"
hf_token = " " # Enter huggingface token for permission
# Never send the blank placeholder as a credential: normalise it to None so
# that from_pretrained falls back to the library's default credential lookup
# (e.g. a login stored by `huggingface-cli login`).
hf_token = hf_token.strip() or None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the caption table; header names may carry stray whitespace.
df = pd.read_csv(captions_path)
df.columns = df.columns.str.strip()

# Keep only the test split. Values are stripped as well as lower-cased so
# that padded cells such as " test" are matched like the padded headers are.
test_df = df[df["split"].str.strip().str.lower() == "test"].reset_index(drop=True)

# Processor (tokenizer + image preprocessing) and model weights for
# `model_name`; the model is moved to the selected device.
processor = AutoProcessor.from_pretrained(model_name, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(model_name, token=hf_token).to(device)

def load_image(image_filename):
    """Open an image from ``images_dir`` and return it converted to RGB.

    Args:
        image_filename: File name relative to ``images_dir``. Leading and
            trailing whitespace is stripped before the path is built.

    Returns:
        The image produced by ``convert('RGB')`` on the opened file.

    Raises:
        Whatever ``Image.open`` raises for a missing or unreadable file.
    """
    path = os.path.join(images_dir, image_filename.strip())
    # convert() yields a separate, fully loaded image, so the source file can
    # be closed deterministically as soon as the conversion has finished.
    with Image.open(path) as source:
        return source.convert('RGB')

def generate_caption(img):
    """Generate a caption for a single image with the loaded model.

    The fixed prompt "Describe this image in detail" is sent together with
    the image; at most 64 new tokens are generated.

    Args:
        img: The image handed to the processor (callers in this file pass
            the RGB image returned by ``load_image``).

    Returns:
        The decoded text of the newly generated tokens only. The prompt
        tokens that ``generate`` echoes at the start of its output are
        sliced off before decoding, so the result does not begin with the
        prompt text.
    """
    input_text = "<image>Describe this image in detail\n"
    inputs = processor(text=input_text, images=img, return_tensors="pt", padding="longest", do_convert_rgb=True).to(device)
    # Match the floating-point inputs (pixel values) to the model's dtype.
    inputs = inputs.to(dtype=model.dtype)
    # Number of prompt tokens; everything after this offset is new text.
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=64)
    # Decode only the continuation instead of relying on string matching
    # later to remove the echoed prompt.
    return processor.decode(output[0][prompt_len:], skip_special_tokens=True)

# Inference on the test set
generated_captions = []

# Only the file name is needed per row, so iterate the 'image' column
# directly instead of materialising a Series for every row via iterrows().
for image_name in tqdm(test_df['image'], total=len(test_df)):
    try:
        pred_caption = generate_caption(load_image(image_name))
    except Exception as e:
        # Best effort: report the failure and keep an empty caption so the
        # predictions stay aligned one-to-one with the rows of test_df.
        print(f"Error on image {image_name}: {e}")
        pred_caption = ""
    generated_captions.append(pred_caption)

# Filter the heading and save predictions
# The echoed prompt can only appear at the very start of a decoded caption,
# so remove it as a prefix rather than replacing it anywhere in the text.
heading = 'Describe this image in detail\n\n'
cleaned_captions = [
    caption[len(heading):] if caption.startswith(heading) else caption
    for caption in generated_captions
]
for caption in cleaned_captions:
    print(caption)

# Store the predictions next to the original test rows and write the CSV.
test_df["predicted_caption"] = cleaned_captions
test_df.to_csv(output_csv, index=False)
print(f"\nPredictions saved to: {output_csv}")