In [1]:
pip install evaluate

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image
from datasets import load_dataset
import tqdm
import evaluate

# Evaluate a fine-tuned BLIP VQA checkpoint on one split of the VQA-RAD dataset.
# Reports exact-match accuracy plus BLEU and ROUGE between generated and
# reference answers.

# --- Configuration (everything a reader might want to tune) -------------------
MODEL_NAME = "Salesforce/blip-vqa-base"
DATASET_NAME = "flaviagiammarino/vqa-rad"
EVAL_SPLIT_NAME = "test"  # split that is actually evaluated and reported below
MAX_NEW_TOKENS = 20       # generation budget per answer
checkpoint_path = "BLIP_Augument.pt"

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model --------------------------------------------------------------------
processor = BlipProcessor.from_pretrained(MODEL_NAME)
model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME)

# NOTE(review): torch.load unpickles the file and can execute arbitrary code.
# Only load checkpoints from a trusted source; consider `weights_only=True`
# if the installed torch version supports it and the file is a plain state dict.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Make inference mode explicit (disables dropout) regardless of how the model
# was constructed or what ran earlier in the kernel.
model.eval()

# --- Data and metrics ---------------------------------------------------------
dataset = load_dataset(DATASET_NAME)
eval_split = dataset[EVAL_SPLIT_NAME]
if len(eval_split) == 0:
    # Fail early with a clear message instead of a ZeroDivisionError later on.
    raise ValueError(f"Split '{EVAL_SPLIT_NAME}' of '{DATASET_NAME}' is empty; nothing to evaluate.")

bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")


def _normalize_answer(text):
    """Return `text` with surrounding whitespace removed and lower-cased."""
    return text.strip().lower()


# --- Inference loop -----------------------------------------------------------
predictions = []
references = []
exact_match = 0

for sample in tqdm.tqdm(eval_split, desc="Evaluating"):
    image = sample["image"]
    question = sample["question"]
    ref_answer = sample["answer"]

    inputs = processor(image, question, return_tensors="pt").to(device)

    out = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    pred_answer = processor.decode(out[0], skip_special_tokens=True).strip()

    predictions.append(pred_answer)
    references.append(ref_answer)

    # Normalize BOTH sides the same way so stray whitespace or casing in the
    # reference cannot turn a correct answer into a miss.
    if _normalize_answer(pred_answer) == _normalize_answer(ref_answer):
        exact_match += 1

# --- Aggregate metrics --------------------------------------------------------
accuracy = exact_match / len(eval_split)
bleu_result = bleu_metric.compute(predictions=predictions, references=references)
rouge_result = rouge_metric.compute(predictions=predictions, references=references)

# The header names the split that was really used (previously it said
# "Validation" while evaluating the test split).
print(f"Evaluation Results on RadVQA ({EVAL_SPLIT_NAME.capitalize()}):")
print(f"Accuracy (Exact Match): {accuracy:.4f}")
print("BLEU:", bleu_result)
print("ROUGE:", rouge_result)


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
Evaluating: 100%|████████████████████████████████████████████████████████████████████| 451/451 [00:32<00:00, 13.77it/s]

Evaluation Results on RadVQA (Validation):
Accuracy (Exact Match): 0.4102
BLEU: {'bleu': 0.04121902281379124, 'precisions': [0.3034013605442177, 0.03873239436619718, 0.02824858757062147, 0.008695652173913044], 'brevity_penalty': 1.0, 'length_ratio': 1.0208333333333333, 'translation_length': 735, 'reference_length': 720}
ROUGE: {'rouge1': 0.4386121669658255, 'rouge2': 0.010867181598888916, 'rougeL': 0.43667895297607057, 'rougeLsum': 0.4372737511318443}



