In [1]:
pip install evaluate

Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image
from datasets import load_dataset
import tqdm
import evaluate

# --- Configuration ----------------------------------------------------------
MODEL_NAME = "Salesforce/blip-vqa-base"          # base architecture + processor
DATASET_NAME = "mdwiratathya/SLAKE-vqa-english"  # HF dataset with a "validation" split
MAX_NEW_TOKENS = 20                              # cap on generated answer length
checkpoint_path = "BLIP_Augument.pt"             # fine-tuned weights to evaluate

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model ------------------------------------------------------------------
processor = BlipProcessor.from_pretrained(MODEL_NAME)
model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME)

# The checkpoint is fed straight into load_state_dict, so it is a plain
# tensor mapping; weights_only=True refuses to unpickle arbitrary objects,
# so a tampered checkpoint file cannot execute code on load.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # explicit inference mode (disables dropout)

# --- Data & metrics ---------------------------------------------------------
dataset = load_dataset(DATASET_NAME)
eval_split = dataset["validation"]

bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")


def normalize_answer(text):
    """Return the canonical form used for exact match: trimmed and lower-cased.

    Applied to BOTH prediction and reference so whitespace/case differences
    on either side do not count as a miss.
    """
    return str(text).strip().lower()


# --- Generation loop ----------------------------------------------------------
predictions = []
references = []
exact_match = 0

# inference_mode: no autograd state is recorded while generating.
with torch.inference_mode():
    for sample in tqdm.tqdm(eval_split, desc="Evaluating"):
        image = sample["image"]
        question = sample["question"]
        ref_answer = sample["answer"]

        inputs = processor(image, question, return_tensors="pt").to(device)

        out = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
        pred_answer = processor.decode(out[0], skip_special_tokens=True).strip()

        predictions.append(pred_answer)
        references.append(ref_answer)

        if normalize_answer(pred_answer) == normalize_answer(ref_answer):
            exact_match += 1

# --- Scores -------------------------------------------------------------------
if not predictions:
    raise ValueError("Validation split is empty; nothing to evaluate.")

accuracy = exact_match / len(predictions)
# BLEU accepts several references per prediction; wrap each single reference
# in a list to make that format explicit.
# NOTE(review): many SLAKE answers look very short, so corpus 4-gram BLEU may
# be a weak signal here — exact match is likely the more meaningful number.
bleu_result = bleu_metric.compute(
    predictions=predictions, references=[[ref] for ref in references]
)
rouge_result = rouge_metric.compute(predictions=predictions, references=references)

print("Evaluation Results on SLAKE (Validation):")
print(f"Accuracy (Exact Match): {accuracy:.4f}")
print("BLEU:", bleu_result)
print("ROUGE:", rouge_result)


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
Evaluating: 100%|██████████████████████████████████████████████████████████████████| 1053/1053 [01:09<00:00, 15.07it/s]


Evaluation Results on SLAKE (Validation):
Accuracy (Exact Match): 0.5309
BLEU: {'bleu': 0.15052548527053278, 'precisions': [0.1331877729257642, 0.14890885750962773, 0.15917602996254682, 0.16262135922330098], 'brevity_penalty': 1.0, 'length_ratio': 1.1435705368289637, 'translation_length': 1832, 'reference_length': 1602}
ROUGE: {'rouge1': 0.598771458364437, 'rouge2': 0.10997293304985614, 'rougeL': 0.5982606041139934, 'rougeLsum': 0.5986836886291129}
