In [1]:
pip install evaluate

Note: you may need to restart the kernel to use updated packages.


In [4]:
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image
from datasets import load_dataset
import tqdm
import evaluate

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

dataset = load_dataset("flaviagiammarino/path-vqa")
eval_split = dataset["validation"]

bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")

predictions = []
references = []
exact_match = 0

for sample in tqdm.tqdm(eval_split, desc="Evaluating"):
    image = sample["image"]
    question = sample["question"]
    ref_answer = sample["answer"]

    inputs = processor(image, question, return_tensors="pt").to(device)

    out = model.generate(**inputs, max_new_tokens=20)
    pred_answer = processor.decode(out[0], skip_special_tokens=True).strip()

    predictions.append(pred_answer)
    references.append(ref_answer)

    if pred_answer.lower() == ref_answer.lower():
        exact_match += 1

accuracy = exact_match / len(eval_split)
bleu_result = bleu_metric.compute(predictions=predictions, references=references)
rouge_result = rouge_metric.compute(predictions=predictions, references=references)

print("Evaluation Results on PathVQA (Validation):")
print(f"Accuracy (Exact Match): {accuracy:.4f}")
print("BLEU:", bleu_result)
print("ROUGE:", rouge_result)


Evaluating: 100%|██████████████████████████████████████████████████████████████████| 6259/6259 [06:09<00:00, 16.96it/s]


Evaluation Results on PathVQA (Validation):
Accuracy (Exact Match): 0.2620
BLEU: {'bleu': 0.0, 'precisions': [0.23372125242449432, 0.0010427528675703858, 0.0, 0.0], 'brevity_penalty': 0.5712153948239251, 'length_ratio': 0.6410301953818828, 'translation_length': 7218, 'reference_length': 11260}
ROUGE: {'rouge1': 0.26384421878486897, 'rouge2': 3.106637553034741e-05, 'rougeL': 0.2639046012731435, 'rougeLsum': 0.2637448762387675}
