In [1]:
pip install evaluate

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image
from datasets import load_dataset
import tqdm
import evaluate

# Zero-shot evaluation of the pretrained BLIP VQA model on VQA-RAD.
MODEL_NAME = "Salesforce/blip-vqa-base"
DATASET_NAME = "flaviagiammarino/vqa-rad"
EVAL_SPLIT = "test"      # split that is actually evaluated; also used in the results header
MAX_NEW_TOKENS = 20      # generation budget per answer

processor = BlipProcessor.from_pretrained(MODEL_NAME)
model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # make inference mode explicit (no dropout)

dataset = load_dataset(DATASET_NAME)
eval_split = dataset[EVAL_SPLIT]
# Guard the accuracy division below and catch a wrong split name early.
assert len(eval_split) > 0, f"split {EVAL_SPLIT!r} of {DATASET_NAME} is empty"

bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")


def normalize_answer(text):
    """Canonicalize an answer for exact-match comparison.

    Lower-cases, trims, drops a trailing period and collapses internal
    whitespace, so "Yes." and " yes" both compare equal to "yes".
    """
    return " ".join(text.lower().strip().rstrip(".").split())


predictions = []
references = []
exact_match = 0

for sample in tqdm.tqdm(eval_split, desc="Evaluating"):
    # Force 3-channel input in case an image is stored as grayscale/RGBA.
    image = sample["image"].convert("RGB")
    question = sample["question"]
    ref_answer = sample["answer"]

    inputs = processor(image, question, return_tensors="pt").to(device)

    out = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    pred_answer = processor.decode(out[0], skip_special_tokens=True).strip()

    # Raw strings go to BLEU/ROUGE; normalized strings are used for exact match.
    predictions.append(pred_answer)
    references.append(ref_answer)

    if normalize_answer(pred_answer) == normalize_answer(ref_answer):
        exact_match += 1

assert len(predictions) == len(eval_split)
accuracy = exact_match / len(eval_split)
# NOTE: default BLEU is BLEU-4. With answers this short, higher-order n-gram
# precisions are typically 0, so BLEU collapses to 0.0. Exact match and ROUGE-1
# are the more informative numbers here.
bleu_result = bleu_metric.compute(predictions=predictions, references=references)
rouge_result = rouge_metric.compute(predictions=predictions, references=references)

print(f"Evaluation Results on RadVQA ({EVAL_SPLIT} split, n={len(eval_split)}):")
print(f"Accuracy (Exact Match): {accuracy:.4f}")
print("BLEU:", bleu_result)
print("ROUGE:", rouge_result)


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 451/451 [00:26<00:00, 17.21it/s]


Evaluation Results on RadVQA (Validation):
Accuracy (Exact Match): 0.3304
BLEU: {'bleu': 0.0, 'precisions': [0.3133208255159475, 0.0, 0.0, 0.0], 'brevity_penalty': 0.7040933883796072, 'length_ratio': 0.7402777777777778, 'translation_length': 533, 'reference_length': 720}
ROUGE: {'rouge1': 0.3504532976262467, 'rouge2': 0.0033259423503325942, 'rougeL': 0.35034771215480753, 'rougeLsum': 0.3511914360583984}
