Visual Question Answering (VQA) is the task of answering open-ended questions based on an image. The input to models supporting this task is typically a combination of an image and a question, and the output is an answer expressed in natural language.

Noteworthy use cases for VQA include:
- Accessibility: applications for visually impaired individuals.
- Education: posing questions about visual materials presented in lectures or textbooks. VQA can also be utilized in interactive museum exhibits or historical sites.
- Customer service and e-commerce: VQA can enhance user experience by letting users ask questions about products.
- Image retrieval: VQA models can be used to retrieve images with specific characteristics. For example, the user can ask “Is there a dog?” to find all images with dogs from a set of images.
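The retrieval idea can be sketched in a few lines. This is a minimal illustration, assuming a hypothetical `answer_fn(image, question)` callable that returns a short text answer (for example, a thin wrapper around a VQA pipeline). The image names and the stub model are made up for the example.

```python
from typing import Callable, Iterable, List


def retrieve_images(
    images: Iterable[str],
    question: str,
    answer_fn: Callable[[str, str], str],
    positive: str = "yes",
) -> List[str]:
    """Return the images for which the VQA model answers `positive` to `question`."""
    matches = []
    for image in images:
        # Normalize the model's answer before comparing, e.g. "Yes " -> "yes"
        answer = answer_fn(image, question).strip().lower()
        if answer == positive:
            matches.append(image)
    return matches


# Hypothetical stub standing in for a real VQA model
def fake_answer_fn(image: str, question: str) -> str:
    return "Yes" if "dog" in image else "no"


print(retrieve_images(["dog_park.jpg", "cat_sofa.jpg"], "Is there a dog?", fake_answer_fn))
# ['dog_park.jpg']
```

Comparing normalized strings keeps the sketch model-agnostic; a real system might instead threshold the score the model assigns to "yes".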

In this guide, you will:
1. Fine-tune a classification VQA model, specifically ViLT, on the Graphcore/vqa dataset.
2. Use your fine-tuned ViLT for inference.
3. Run zero-shot VQA inference with a generative model, like BLIP-2.

A note on ViLT versus some recent VQA models: the ViLT model incorporates text embeddings into a Vision Transformer (ViT), which gives it a minimal design for Vision-and-Language Pre-training (VLP). ViLT can be used for several downstream tasks. For VQA, a randomly initialized classifier head (a linear layer on top of the final hidden state of the [CLS] token) is placed on top, so Visual Question Answering is treated as a classification problem over a fixed set of answers. More recent models, such as BLIP, BLIP-2, and InstructBLIP, treat VQA as a generative task. Later in this guide, we illustrate how to use BLIP-2 for zero-shot VQA inference.
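For intuition, here is a dependency-free sketch of the classification view. A single linear layer maps the [CLS] vector to one logit per candidate answer, and the prediction is the highest-scoring answer. The dimensions, weights, and labels are toy values. The head in an actual library implementation may add extra layers (for example, a hidden layer and normalization) before the final projection.

```python
from typing import Dict, List, Sequence


def linear_head(
    cls_hidden: Sequence[float],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[float]:
    """One logit per candidate answer: logits[j] = sum_i weight[j][i] * cls_hidden[i] + bias[j]."""
    return [sum(w * h for w, h in zip(row, cls_hidden)) + b for row, b in zip(weight, bias)]


def predict_answer(
    cls_hidden: Sequence[float],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
    id2label: Dict[int, str],
) -> str:
    logits = linear_head(cls_hidden, weight, bias)
    # Classification: pick the index of the largest logit and map it back to an answer string
    best = max(range(len(logits)), key=logits.__getitem__)
    return id2label[best]


# Toy example: a 2-dimensional [CLS] vector and 3 candidate answers
toy_hidden = [1.0, 2.0]
toy_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_bias = [0.0, 0.0, 0.0]
print(linear_head(toy_hidden, toy_weight, toy_bias))  # [1.0, 2.0, 3.0]
print(predict_answer(toy_hidden, toy_weight, toy_bias, {0: "yes", 1: "no", 2: "2"}))  # 2
```

Because the answer set is fixed, the model can only ever output answers that appear in `id2label`, which is the main practical difference from the generative models discussed later.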



# Libraries

In [None]:
%pip install -q transformers datasets accelerate

In [None]:
import torch
import itertools
from PIL import Image
from datasets import load_dataset
from accelerate.test_utils.testing import get_backend
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from transformers import ViltProcessor, DefaultDataCollator, ViltForQuestionAnswering, TrainingArguments, Trainer, pipeline


In [None]:
# Global vars
MODEL_CHECKPOINT = "dandelin/vilt-b32-mlm"

# Automatically detect the underlying device type (CUDA, CPU, XPU, MPS, etc.)
device, _, _ = get_backend() 

# Load Data

In [None]:
# We'll use a very small sample of the annotated visual question answering Graphcore/vqa dataset
dataset = load_dataset("Graphcore/vqa", split="validation[:200]")
dataset

In [None]:
# Inspect an example
# The features relevant to the task include:
# question: the question to be answered from the image
# image_id: the path to the image the question refers to
# label: the annotations (contains several answers to the same question because answers can be subjective)
dataset[0]

In [None]:
# Here is the image corresponding to the example above. What label would you have given for the question?
image = Image.open(dataset[0]['image_id'])
image

In [None]:
# Remove the rest of the features as they won't be necessary for this task
dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type'])

In [None]:
# Because annotators can give different but valid answers to the same question, VQA datasets like this one are treated as multi-label classification problems
# Rather than a one-hot encoded vector, each target is a soft encoding:
# every annotated answer receives a weight based on how many times it appeared in the annotations
labels = [item['ids'] for item in dataset['label']]
flattened_labels = list(itertools.chain(*labels))
unique_labels = list(set(flattened_labels))

# To later instantiate the model with an appropriate classification head, create two dictionaries
# One dictionary maps the label name to an integer, and the other reverses this mapping
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}
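`list(set(...))` returns the answers in an arbitrary order, and because Python randomizes string hashing per process, the resulting ids can change between runs. That does not break training (the mappings are passed to the model and stored in its config), but sorting makes the mapping reproducible. A minimal sketch, assuming answers are plain strings; the toy answers and the `toy_` variable names are illustrative and chosen so they don't overwrite the notebook's `label2id`/`id2label`:

```python
import itertools
from typing import Dict, List, Tuple


def build_label_maps(label_lists: List[List[str]]) -> Tuple[Dict[str, int], Dict[int, str]]:
    # Flatten the per-question answer lists, deduplicate, and sort for a stable order
    unique_labels = sorted(set(itertools.chain.from_iterable(label_lists)))
    label2id = {label: idx for idx, label in enumerate(unique_labels)}
    id2label = {idx: label for label, idx in label2id.items()}
    return label2id, id2label


toy_label2id, toy_id2label = build_label_maps([["yes", "no"], ["2", "yes"], ["red"]])
print(toy_label2id)  # {'2': 0, 'no': 1, 'red': 2, 'yes': 3}
```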

In [None]:
# Now that we have the mappings, we can replace the string answers with their ids
def replace_ids(inputs):
    inputs["label"]["ids"] = [label2id[x] for x in inputs["label"]["ids"]]
    return inputs

dataset = dataset.map(replace_ids)
flat_dataset = dataset.flatten()
flat_dataset.features
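`flatten()` turns the nested `label` struct into separate top-level columns, which is why the preprocessing below reads `label.ids` and `label.weights`. The following is a plain-Python analogue operating on a single record, not the Arrow-based implementation; the record contents are toy values.

```python
from typing import Any, Dict


def flatten_record(record: Dict[str, Any], parent: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into dot-separated keys, e.g. {"label": {"ids": ...}} -> {"label.ids": ...}."""
    flat: Dict[str, Any] = {}
    for key, value in record.items():
        name = f"{parent}.{key}" if parent else key
        if isinstance(value, dict):
            # Recurse into nested structs, prefixing their keys with the parent name
            flat.update(flatten_record(value, name))
        else:
            flat[name] = value
    return flat


toy_record = {"question": "What color is the bus?", "label": {"ids": [3, 7], "weights": [1.0, 0.3]}}
print(flatten_record(toy_record))
# {'question': 'What color is the bus?', 'label.ids': [3, 7], 'label.weights': [1.0, 0.3]}
```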

# Preprocessing

In [None]:
# Load a ViLT processor to prepare the image and text data
# ViltProcessor wraps a BERT tokenizer and ViLT image processor into a convenient single processor
processor = ViltProcessor.from_pretrained(MODEL_CHECKPOINT)

In [None]:
# Function to prepare the target labels such that each element corresponds to a possible answer (label)
# For annotated answers, the element holds their respective score (weight)
# For all other answers, the element is set to zero
def preprocess_data(examples):
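    image_paths = examples['image_id']
    # Convert to RGB so grayscale images get the three channels the image processor expects
    images = [Image.open(image_path).convert("RGB") for image_path in image_paths]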
    texts = examples['question']

    encoding = processor(images, texts, padding="max_length", truncation=True, return_tensors="pt")

    for k, v in encoding.items():
        encoding[k] = v.squeeze()

    targets = []

    for labels, scores in zip(examples['label.ids'], examples['label.weights']):
        target = torch.zeros(len(id2label))

        for label, score in zip(labels, scores):
            target[label] = score

        targets.append(target)

    encoding["labels"] = targets

    return encoding
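The target construction in `preprocess_data` turns sparse `(ids, weights)` pairs into a dense vector with one slot per answer. The sketch below shows the same idea in plain Python, together with one possible way such weights could be derived from raw annotator answers: `min(count / 3, 1.0)`, a rule similar in spirit to the VQA accuracy metric. That scoring rule is an assumption for illustration only; Graphcore/vqa ships precomputed `weights`, and we have not checked how they were computed.

```python
from collections import Counter
from typing import Dict, List


def soft_scores(annotations: List[str]) -> Dict[str, float]:
    # Assumed scheme: an answer given by 3+ annotators gets full credit, fewer gets partial credit
    counts = Counter(annotations)
    return {answer: min(count / 3, 1.0) for answer, count in counts.items()}


def dense_target(ids: List[int], weights: List[float], num_labels: int) -> List[float]:
    # One slot per candidate answer; answers nobody gave stay at 0.0
    target = [0.0] * num_labels
    for label_id, weight in zip(ids, weights):
        target[label_id] = weight
    return target


print(dense_target([2, 0], [1.0, 0.5], num_labels=4))  # [0.5, 0.0, 1.0, 0.0]
```

Keeping partial credit for less common answers is what makes this a multi-label problem: several entries of the target can be non-zero at once.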

In [None]:
# Apply function to dataset and remove unwanted columns
# question_type, question_id, and answer_type were already removed earlier, so they must not be listed again
cols_to_remove = ['question', 'image_id', 'label.ids', 'label.weights']
processed_dataset = flat_dataset.map(preprocess_data, 
                                     batched=True, 
                                     remove_columns=cols_to_remove)
processed_dataset

In [None]:
# DefaultDataCollator groups processed examples into batches of tensors during training
data_collator = DefaultDataCollator()
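Conceptually, a default collator receives a list of per-example feature dicts and groups the values under each key into one batch. Below is a plain-Python sketch of that grouping. `DefaultDataCollator` additionally converts the grouped values into PyTorch tensors, which this sketch omits. It also assumes every example has the same keys, which holds after the `map` call above.

```python
from typing import Any, Dict, List


def collate_features(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    # Group per-example values key by key; a real collator would also stack them into tensors
    keys = features[0].keys()
    return {key: [feature[key] for feature in features] for key in keys}


toy_batch = collate_features([
    {"input_ids": [101, 2054], "labels": [0.0, 1.0]},
    {"input_ids": [101, 2129], "labels": [0.3, 0.0]},
])
print(toy_batch)
# {'input_ids': [[101, 2054], [101, 2129]], 'labels': [[0.0, 1.0], [0.3, 0.0]]}
```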

# Training

In [None]:
# Load ViLT with ViltForQuestionAnswering
# Specify the number of labels along with the label mappings
model = ViltForQuestionAnswering.from_pretrained(MODEL_CHECKPOINT, 
                                                 num_labels=len(id2label), 
                                                 id2label=id2label, 
                                                 label2id=label2id)

In [None]:
# Define your training hyperparameters in TrainingArguments
training_args = TrainingArguments(
    output_dir="vqa_vilt_finetuned",
    per_device_train_batch_size=4,
    num_train_epochs=20,
    save_steps=200,
    logging_steps=50,
    learning_rate=5e-5,
    save_total_limit=2,
    remove_unused_columns=False
)
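It helps to work out what these hyperparameters imply for the 200-example dataset. A minimal sketch, assuming a single device, no gradient accumulation, and checkpoint rotation that keeps only the most recent `save_total_limit` checkpoints (no best-model tracking); `training_schedule` is a hypothetical helper, not part of Transformers.

```python
import math
from typing import Dict


def training_schedule(
    num_examples: int,
    batch_size: int,
    num_epochs: int,
    save_steps: int,
    logging_steps: int,
    save_total_limit: int,
) -> Dict[str, object]:
    # Assumes one device and no gradient accumulation
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    # A checkpoint is written every `save_steps` optimizer steps
    checkpoints = list(range(save_steps, total_steps + 1, save_steps))
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "log_entries": total_steps // logging_steps,
        "checkpoints_written": checkpoints,
        # Rotation deletes older checkpoints, keeping only the most recent ones
        "checkpoints_kept": checkpoints[-save_total_limit:],
    }


print(training_schedule(200, 4, 20, 200, 50, 2))
# {'steps_per_epoch': 50, 'total_steps': 1000, 'log_entries': 20, 'checkpoints_written': [200, 400, 600, 800, 1000], 'checkpoints_kept': [800, 1000]}
```

So 20 epochs over 200 examples is 1,000 optimizer steps, and only the last two checkpoints survive in `vqa_vilt_finetuned/`.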

In [None]:
# Pass the training arguments to Trainer along with the model, dataset, processor, and data collator
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
    processing_class=processor,
)

In [None]:
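# Call train() to fine-tune your model
trainer.train()

# Save the final model (and processor) to output_dir so the inference cells below can load "vqa_vilt_finetuned"
trainer.save_model()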

# Inference

In [None]:
# Inference using a pipeline
pipe = pipeline("visual-question-answering", model="vqa_vilt_finetuned")

In [None]:
# Check inference on first example
# Note that the model was trained on only 200 examples so performance won't be optimal
example = dataset[0]
image = Image.open(example['image_id'])
question = example['question']
print(question)
pipe(image, question, top_k=1)

In [None]:
# Inference using manual loop
processor = ViltProcessor.from_pretrained("vqa_vilt_finetuned")

image = Image.open(example['image_id'])
question = example['question']

# prepare inputs
inputs = processor(image, question, return_tensors='pt')

model = ViltForQuestionAnswering.from_pretrained("vqa_vilt_finetuned")

# forward pass
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
idx = logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])
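`argmax` returns a single answer. Because the targets are soft, multi-label scores, it can be more informative to turn each logit into an independent probability with a sigmoid and list the top few answers. The sketch below does this in plain Python on toy logits; with the real `outputs.logits` tensor, `outputs.logits.sigmoid().topk(k)` expresses the same computation in PyTorch. `top_k_answers` is a hypothetical helper name.

```python
import math
from typing import Dict, List, Sequence, Tuple


def top_k_answers(logits: Sequence[float], id2label: Dict[int, str], k: int = 3) -> List[Tuple[str, float]]:
    # Independent per-answer probabilities via sigmoid (multi-label), not a softmax across answers
    probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


toy_results = top_k_answers([0.0, 2.0, -1.0], {0: "yes", 1: "no", 2: "2"}, k=2)
print([answer for answer, _ in toy_results])  # ['no', 'yes']
```

Since the probabilities are independent, they don't sum to 1; each one reads as "how plausible is this answer on its own".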

# Zero-shot VQA

The model above treats VQA as a classification task. Some recent models, such as BLIP, BLIP-2, and InstructBLIP, approach VQA as a generative task. Take BLIP-2 as an example: it introduced a vision-language pre-training paradigm in which any combination of a pre-trained vision encoder and a large language model (LLM) can be used, and at the time of its release it achieved state-of-the-art results on multiple vision-language tasks, including visual question answering. The following cells show how to use BLIP-2 for VQA.

In [None]:
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
model.to(device)
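The cell above only loads BLIP-2. A minimal sketch of asking it a question follows, assuming the `Question: ... Answer:` prompt format used for OPT-based BLIP-2 checkpoints in the Transformers documentation, and a GPU (float16 weights are intended for GPU inference). `build_vqa_prompt` and `ask_blip2` are hypothetical helper names. Depending on your Transformers version, the decoded text may or may not repeat the prompt, so the helper keeps only what follows the last `Answer:`.

```python
def build_vqa_prompt(question: str) -> str:
    # Prompt format assumed for OPT-based BLIP-2 checkpoints
    return f"Question: {question} Answer:"


def ask_blip2(model, processor, image, question, device, dtype, max_new_tokens=10):
    """Generate a free-form answer; `dtype` should match the model weights (e.g. torch.float16)."""
    prompt = build_vqa_prompt(question)
    # Move inputs to the model's device and cast floating-point inputs (pixel_values) to `dtype`
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Keep only the generated answer, whether or not the prompt is echoed back
    return text.split("Answer:")[-1].strip()


# Usage in this notebook (downloads the model weights; needs a GPU for float16):
# answer = ask_blip2(model, processor, Image.open(example['image_id']), example['question'], device, torch.float16)
```

Unlike the fine-tuned ViLT, BLIP-2 is used zero-shot here: its answer is free text rather than one of the labels seen during training.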