In [3]:
import torch
from transformers import BertTokenizer, BertForQuestionAnswering
from PIL import Image
from torchvision.transforms import functional as F

def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    return image

def preprocess_image(image):
    # Resize image to match the model's expected sizing
    image = F.resize(image, (256, 256))
    # Convert PIL image to PyTorch tensor
    image = F.to_tensor(image)
    # Normalize the image
    image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Add batch dimension
    image = image.unsqueeze(0)
    return image

def ask_question(question, model, tokenizer, image_tensor):
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    answer_start_scores = outputs.start_logits
    answer_end_scores = outputs.end_logits

    # Get the most likely answer
    answer_start = torch.argmax(answer_start_scores)
    answer_end = torch.argmax(answer_end_scores) + 1

    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0, answer_start:answer_end]))
    return answer

def main():
    # Load pre-trained VQA model
    model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Load and preprocess image
    image_path = "sample_graph.webp"
    image = load_image(image_path)
    image_tensor = preprocess_image(image)

    # Ask a question
    question = "What is in the image?"
    answer = ask_question(question, model, tokenizer, image_tensor)

    # Print the answer
    print(f"Question: {question}")
    print(f"Answer: {answer}")

In [4]:
main()

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Question: What is in the image?
Answer: [CLS] what is in the
