Visual Question Answering (VQA) is the task of answering open-ended questions based on an image. The input to models supporting this task is typically a combination of an image and a question, and the output is an answer expressed in natural language.

Noteworthy use cases for VQA include:
- Accessibility applications for visually impaired individuals.
- Education: posing questions about visual materials presented in lectures or textbooks. VQA can also be utilized in interactive museum exhibits or historical sites.
- Customer service and e-commerce: VQA can enhance user experience by letting users ask questions about products.
- Image retrieval: VQA models can be used to retrieve images with specific characteristics. For example, the user can ask “Is there a dog?” to find all images with dogs from a set of images.

In this guide, you will:
1. Fine-tune a classification VQA model, specifically ViLT, on the Graphcore/vqa dataset.
2. Use your fine-tuned ViLT for inference.
3. Run zero-shot VQA inference with a generative model, like BLIP-2.

A note on ViLT versus some recent VQA models: the ViLT model incorporates text embeddings into a Vision Transformer (ViT), which gives it a minimal design for Vision-and-Language Pre-training (VLP). The model can be used for several downstream tasks. For VQA, a randomly initialized classifier head (a linear layer on top of the final hidden state of the [CLS] token) is placed on top, so Visual Question Answering is treated as a classification problem. More recent models, such as BLIP, BLIP-2, and InstructBLIP, treat VQA as a generative task. Later in this guide we illustrate how to use them for zero-shot VQA inference.
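To make the "VQA as classification" idea concrete, here is a minimal pure-Python sketch of a linear head that turns one [CLS] vector into one score per candidate answer. The hidden size, the answer list, the [CLS] values, and the `linear_head` helper are all made up for illustration; this is not ViLT's actual implementation.

```python
import random

def linear_head(cls_state, weights, bias):
    """Apply a linear layer: one logit per candidate answer."""
    return [
        sum(w * x for w, x in zip(row, cls_state)) + b
        for row, b in zip(weights, bias)
    ]

random.seed(0)
hidden_size = 4                       # toy value, far smaller than a real model
answers = ["yes", "no", "2"]          # hypothetical answer vocabulary

# A randomly initialized head: one weight row per answer
weights = [[random.gauss(0.0, 0.02) for _ in range(hidden_size)] for _ in answers]
bias = [0.0] * len(answers)

# Stand-in for the final hidden state of the [CLS] token
cls_state = [0.5, -1.0, 0.25, 2.0]

logits = linear_head(cls_state, weights, bias)
# The predicted answer is the class with the highest logit
predicted = answers[max(range(len(logits)), key=logits.__getitem__)]
print(predicted)
```

Because the head is randomly initialized, its predictions are meaningless until it is fine-tuned, which is why the guide fine-tunes before running inference.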



# Libraries

In [None]:
%pip install -q transformers datasets pillow

In [None]:
import itertools
from PIL import Image
from datasets import load_dataset
from transformers import ViltProcessor

In [None]:
# Global vars
MODEL_CHECKPOINT = "dandelin/vilt-b32-mlm"

# Load Data

In [None]:
# We'll use a very small sample of the annotated visual question answering Graphcore/vqa dataset
dataset = load_dataset("Graphcore/vqa", split="validation[:200]")
dataset

In [None]:
# Inspect an example
# The features relevant to the task include:
# question: the question to be answered from the image
# image_id: the path to the image the question refers to
# label: the annotations (contains several answers to the same question because answers can be subjective)
dataset[0]

In [None]:
# Here is the image corresponding to the example above. What label would you have given for the question?
image = Image.open(dataset[0]['image_id'])
image

In [None]:
# Remove the rest of the features as they won't be necessary for this task
dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type'])

In [None]:
# Because many questions admit more than one valid answer, datasets like this are treated
# as a multi-label classification problem. Rather than a one-hot encoded vector, each example
# gets a soft encoding based on how many times each answer appeared in the annotations.
labels = [item['ids'] for item in dataset['label']]
flattened_labels = list(itertools.chain(*labels))
# Sort so the label-to-id mapping is reproducible across runs (set order is not guaranteed)
unique_labels = sorted(set(flattened_labels))

# To later instantiate the model with an appropriate classification head, create two dictionaries
# One dictionary maps the label name to an integer, and the other reverses this mapping
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}
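As a rough sketch of what a soft encoding looks like, the snippet below builds a target vector for one annotation. It assumes each annotation carries a list of answers and a matching list of weights; the `weights` field, the toy `label2id` mapping, and the `soft_encode` helper are illustrative assumptions, not part of the code above.

```python
def soft_encode(answer_ids, answer_weights, num_labels):
    """Return a vector of length num_labels holding each answer's weight at its id."""
    target = [0.0] * num_labels
    for idx, weight in zip(answer_ids, answer_weights):
        target[idx] = weight
    return target

# Toy mapping and annotation (hypothetical values)
label2id = {"yes": 0, "no": 1, "maybe": 2, "down": 3}
annotation = {"ids": ["yes", "maybe"], "weights": [1.0, 0.3]}

ids = [label2id[answer] for answer in annotation["ids"]]
target = soft_encode(ids, annotation["weights"], len(label2id))
print(target)  # [1.0, 0.0, 0.3, 0.0]
```

Unlike a one-hot vector, several positions can be non-zero at once, and an answer given by only a few annotators contributes a smaller value than one most annotators agreed on.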

In [None]:
# Now that we have the mappings, we can replace the string answers with their ids
def replace_ids(inputs):
    inputs["label"]["ids"] = [label2id[x] for x in inputs["label"]["ids"]]
    return inputs

dataset = dataset.map(replace_ids)
flat_dataset = dataset.flatten()
flat_dataset.features
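For intuition, flattening turns each nested field into its own top-level column whose name joins the parent and child keys with a dot. The pure-Python sketch below mimics that on a single record; `flatten_record` and the example values are hypothetical, and this is not how the datasets library implements it.

```python
def flatten_record(record, parent_key=""):
    """Recursively lift nested dict fields to top-level keys joined by '.'."""
    flat = {}
    for key, value in record.items():
        name = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(flatten_record(value, name))
        else:
            flat[name] = value
    return flat

# Hypothetical record shaped like one example after the id replacement
example = {
    "question": "Is there a dog?",
    "image_id": "images/0001.jpg",
    "label": {"ids": [0, 2], "weights": [1.0, 0.3]},
}

flat = flatten_record(example)
print(list(flat))  # ['question', 'image_id', 'label.ids', 'label.weights']
```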

# Preprocessing

In [None]:
# Load a ViLT processor to prepare the image and text data
# ViltProcessor wraps a BERT tokenizer and ViLT image processor into a convenient single processor
processor = ViltProcessor.from_pretrained(MODEL_CHECKPOINT)