In [None]:
import os
import csv
import torch
import pandas as pd
from PIL import Image
from tqdm import tqdm
from transformers import BlipProcessor, BlipForQuestionAnswering

# ---- Configuration ---------------------------------------------------------
# File locations, relative to the working directory the notebook runs from.
SRC_PATH = "../data/vqa.csv"                # input rows: filename, question, answer
IMAGE_SRC_DIR = "../data/curated_images"    # directory the `filename` column is resolved against
DEST_PATH = "../data/preds_blip.csv"        # output CSV, one prediction per sampled row
MODEL_PATH = "../models/blip_ft.pth"        # optional fine-tuned weights (a state_dict)

# Sampling: a fixed seed keeps the evaluated subset reproducible across runs.
SEED = 7
SAMPLE_SIZE = 10_000

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the VQA rows and draw a reproducible random subset of them.
df = pd.read_csv(SRC_PATH)

# DataFrame.sample(n=...) samples without replacement and raises ValueError
# when n exceeds the number of rows, so cap the request at the dataset size.
n_sample = min(SAMPLE_SIZE, len(df))
if n_sample < SAMPLE_SIZE:
    print(f"Dataset has only {len(df)} rows; using all of them instead of {SAMPLE_SIZE}.")
df_sample = df.sample(n=n_sample, random_state=SEED).reset_index(drop=True)

# Initialize processor and model from the Hugging Face checkpoint.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# If a fine-tuned checkpoint exists locally, overlay its weights on the model.
if os.path.exists(MODEL_PATH):
    print(f"Loading fine-tuned model from {MODEL_PATH}")
    # Map to the CPU: the model's parameters are still on the CPU at this point.
    # Mapping straight to DEVICE would park a full extra copy of the weights on
    # the GPU, and it would stay alive after model.to(DEVICE).
    # weights_only=True limits unpickling to tensors/containers (a plain
    # state_dict is all load_state_dict needs) instead of arbitrary pickled
    # objects. Requires torch >= 1.13.
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    del state_dict  # weights now live in `model`; drop the redundant copy
else:
    print("No local checkpoint found. Using the pre-trained Hugging Face model.")

# Move model to device and set to eval mode (e.g. disables dropout).
model.to(DEVICE)
model.eval()

# Run VQA inference over the sampled rows and stream results to DEST_PATH.
MAX_NEW_TOKENS = 5          # cap on generated answer length, in tokens
MAX_REPORTED_ERRORS = 10    # print at most this many individual row failures

# Make sure the output directory exists before opening the CSV for writing.
os.makedirs(os.path.dirname(DEST_PATH) or ".", exist_ok=True)
n_failed = 0

# no_grad is hoisted around the whole loop: nothing here needs autograd.
with open(DEST_PATH, mode="w", newline="", encoding="utf-8") as out_file, torch.no_grad():
    writer = csv.writer(out_file)
    writer.writerow(["filename", "question", "answer", "prediction"])

    for row in tqdm(df_sample.itertuples(index=False), total=len(df_sample), desc="BLIP VQA Inference", unit="it"):
        filename, question, answer = row.filename, row.question, row.answer
        try:
            # Inside the try: a missing (NaN) filename makes os.path.join raise.
            img_path = os.path.join(IMAGE_SRC_DIR, filename)
            # The context manager closes the file handle. convert() returns an
            # independent, fully-loaded image, so it stays usable afterwards.
            with Image.open(img_path) as src_img:
                img = src_img.convert("RGB")
            inputs = processor(images=img, text=question, return_tensors="pt").to(DEVICE)
            out_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
            prediction = processor.decode(out_ids[0], skip_special_tokens=True).strip().lower()
        except Exception as e:
            # Best-effort: keep the row with an empty prediction so the output
            # stays aligned with the sample, but don't hide the failure.
            n_failed += 1
            if n_failed <= MAX_REPORTED_ERRORS:
                tqdm.write(f"[skip] {filename!r}: {type(e).__name__}: {e}")
            prediction = ""
        writer.writerow([filename, question, answer, prediction])

print(f"Saved predictions to {DEST_PATH}")
if n_failed:
    print(f"{n_failed}/{len(df_sample)} rows failed and were written with an empty prediction.")