# Multimodal Interpretability

In this notebook, we illustrate how to fine-tune the Vision-and-Language Transformer (ViLT) for visual question answering. This is very similar to fine-tuning BERT: one places a randomly initialized head on top of a pre-trained base and trains both end-to-end. Afterwards, we apply the Layer Integrated Gradients method to interpret the model's predictions.

* Paper: https://arxiv.org/abs/2102.03334
* ViLT docs: https://huggingface.co/docs/transformers/master/en/model_doc/vilt

We will ask you to answer a couple of short-form questions and to complete a code block.

## Set-up environment

First, we install the Transformers library.

In [1]:
!pip install -q transformers

Next, we mount Google Drive if using the Colab environment.

In [2]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [3]:
!wget http://images.cocodataset.org/zips/val2014.zip


--2025-10-02 21:02:47--  http://images.cocodataset.org/zips/val2014.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.217.46.244, 54.231.165.9, 16.182.70.153, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|52.217.46.244|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6645013297 (6.2G) [application/zip]
Saving to: ‘val2014.zip’


2025-10-02 21:10:29 (13.7 MB/s) - ‘val2014.zip’ saved [6645013297/6645013297]



## Load data

Next, we load the data. The data of VQAv2 can be obtained from the [official website](https://visualqa.org/download.html).

For demonstration purposes, we only download the validation dataset. We download:
* the images (stored in a single folder)
* the questions (stored in a JSON file)
* the annotations, a.k.a. the answers to the questions (stored in a JSON file).
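The `wget` cell above only downloads the image archive; it still has to be extracted into the folder that the code below reads from. Here is a minimal standard-library sketch for doing so. The helper name `extract_archive` and the destination path are placeholders of ours, not part of the original notebook; if the archive contains a top-level `val2014/` folder, extracting it into the `VQAv2` directory should produce the `root` path used further down. The question and annotation archives from the official website can be extracted the same way, and in a notebook `!unzip -q val2014.zip -d <destination>` is an equivalent shell one-liner.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path: str, dest_dir: str) -> int:
    """Extract every member of `zip_path` into `dest_dir`; return the number of entries."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)  # create the destination folder if needed
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return len(archive.namelist())


# Example usage (placeholder paths; extracting the 6.2 GB archive takes a while):
# extract_archive("val2014.zip", "/content/drive/MyDrive/ViLT/Datasets/VQAv2")
```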

### Read questions

First, we read the questions.

In [4]:
import json

# Open the questions JSON file and parse it into a dictionary
with open('/content/drive/MyDrive/ViLT/Datasets/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json') as f:
    data_questions = json.load(f)
print(data_questions.keys())


Let's see how many questions there are:

In [None]:
questions = data_questions['questions']
print("Number of questions:", len(questions))

That's quite a lot! Let's take a look at the first one:

In [None]:
questions[0]

As we can see, this question is related to an image with a certain ID. How do we find out which image this is? The function below extracts the ID from a filename; we'll use it to map between image IDs and their corresponding filenames.

In [None]:
import re
from typing import Optional

filename_re = re.compile(r".*(\d{12})\.((jpg)|(png))")

def id_from_filename(filename: str) -> Optional[int]:
    match = filename_re.fullmatch(filename)
    if match is None:
        return None
    return int(match.group(1))

In [None]:
from os import listdir
from os.path import isfile, join
from tqdm.auto import tqdm

# root at which all images are stored
root = '/content/drive/MyDrive/ViLT/Datasets/VQAv2/val2014'
file_names = [f for f in tqdm(listdir(root)) if isfile(join(root, f))]

We can map a filename to its ID as follows:

In [None]:
id_from_filename('COCO_val2014_000000501080.jpg')

We create 2 dictionaries, one that maps filenames to their IDs and one the other way around:

In [None]:
filename_to_id = {root + "/" + file: id_from_filename(file) for file in file_names}
id_to_filename = {v:k for k,v in filename_to_id.items()}
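Since COCO 2014 filenames embed the image ID zero-padded to 12 digits (as in `COCO_val2014_000000501080.jpg`), the mapping can also be inverted arithmetically, without listing the directory. The `filename_from_id` helper below is a hypothetical addition, not part of the original notebook, and it assumes every file follows the `COCO_<split>_<12-digit id>.jpg` pattern; the dictionary approach above makes no such assumption.

```python
import re
from typing import Optional

# Same pattern as above: 12 digits right before a .jpg or .png extension
filename_re = re.compile(r".*(\d{12})\.((jpg)|(png))")


def id_from_filename(filename: str) -> Optional[int]:
    match = filename_re.fullmatch(filename)
    if match is None:
        return None
    return int(match.group(1))


def filename_from_id(image_id: int, split: str = "val2014") -> str:
    # Hypothetical helper: assumes the COCO_<split>_<12-digit id>.jpg naming scheme
    return f"COCO_{split}_{image_id:012d}.jpg"


# Round trip: filename -> ID -> filename
name = "COCO_val2014_000000501080.jpg"
print(filename_from_id(id_from_filename(name)) == name)
```

Joining the result with `root` (e.g. `join(root, filename_from_id(image_id))`) gives the same path as `id_to_filename[image_id]` whenever the naming assumption holds.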

We can now look up the image that the question 'Where is he looking?' refers to:

In [None]:
from PIL import Image

print(f"Size of map: {len(id_to_filename)}")

path = id_to_filename[questions[0]['image_id']]
image = Image.open(path)
image

### Read annotations

Next, let's read the annotations. As we'll see, every question is annotated with multiple possible answers (VQAv2 collects 10 answers per question).

In [None]:
import json

# Open the annotations JSON file and parse it into a dictionary
with open('/content/drive/MyDrive/ViLT/Datasets/VQAv2/v2_mscoco_val2014_annotations.json') as f:
    data_annotations = json.load(f)
print(data_annotations.keys())

As the next cells show, there are 214,354 annotations in total (for the validation split alone!).

In [None]:
annotations = data_annotations['annotations']

In [None]:
print("Number of annotations:", len(annotations))

Let's take a look at the first one. The example contains several answers, collected from different human annotators. The answer to a question can be a bit subjective: for instance, for the question "where is he looking?", some people answered "down", others "table", another one "skateboard", etc. So there's a bit of ambiguity among the annotators :)

In [None]:
annotations[0]

## Add labels + scores

Due to this ambiguity, most authors treat the VQAv2 dataset as a multi-label classification problem (as multiple answers are possibly valid). Moreover, rather than just creating a one-hot encoded vector, one creates a soft encoding, based on the number of times a certain answer appeared in the annotations.

For instance, in the example above, the answer "down" was chosen far more often than "skateboard", so we want the model to put more emphasis on "down" than on "skateboard". We achieve this by giving a score of 1.0 to answers given at least 3 times, and a score of `count / 3` (less than 1.0) to answers given fewer times.

We'll add 2 keys to each annotation:
* `labels`, a list of integer indices of the labels that apply to a given image + question.
* `scores`, the corresponding scores (between 0 and 1), which indicate the importance of each label.
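To make the scoring rule concrete, here is a small worked example with made-up annotator answers and a hypothetical three-entry vocabulary (`toy_label2id`); the real mapping comes from the model config loaded next. An answer given by at least 3 annotators gets a score of 1.0, one given by fewer gets `count / 3`, and answers outside the vocabulary are dropped.

```python
from collections import Counter


def get_score(count: int) -> float:
    # Soft score: 1.0 once at least 3 annotators agree, otherwise count / 3
    return min(1.0, count / 3)


# Hypothetical vocabulary and 10 hypothetical annotator answers
toy_label2id = {"down": 0, "table": 1, "skateboard": 2}
toy_answers = ["down"] * 6 + ["table"] * 2 + ["skateboard", "ground"]

toy_labels, toy_scores = [], []
for answer, count in Counter(toy_answers).items():
    if answer not in toy_label2id:  # "ground" is not in the vocabulary, so it is skipped
        continue
    toy_labels.append(toy_label2id[answer])
    toy_scores.append(get_score(count))

print(toy_labels)  # [0, 1, 2]
print(toy_scores)  # [1.0, 0.6666666666666666, 0.3333333333333333]
```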

Since we need the `label2id`/`id2label` mappings of the VQA answer vocabulary, we load them from the config of a ViLT checkpoint fine-tuned on VQAv2, hosted on the Hugging Face Hub:

In [None]:
from transformers import ViltConfig

config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

In [None]:
from tqdm.notebook import tqdm

from collections import Counter

def get_score(count: int) -> float:
    return min(1.0, count / 3)

for annotation in tqdm(annotations):
    # count how many annotators gave each answer
    answer_count = Counter(answer["answer"] for answer in annotation['answers'])
    labels = []
    scores = []
    for answer, count in answer_count.items():
        # skip answers outside the model's answer vocabulary
        # (dict membership is O(1); building list(config.label2id.keys()) on every check is very slow)
        if answer not in config.label2id:
            continue
        labels.append(config.label2id[answer])
        scores.append(get_score(count))
    annotation['labels'] = labels
    annotation['scores'] = scores

Let's verify an example:

In [None]:
annotations[0]

Let's verify the labels and corresponding scores:

In [None]:
labels = annotations[0]['labels']
print([config.id2label[label] for label in labels])

In [None]:
scores = annotations[0]['scores']
print(scores)

## Create PyTorch dataset

Next, we create a regular [PyTorch dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). We leverage `ViltProcessor` to prepare each image + text pair for the model, which will automatically:
* leverage `BertTokenizerFast` to tokenize the text and create `input_ids`, `attention_mask` and `token_type_ids`
* leverage `ViltImageProcessor` to resize + normalize the image and create `pixel_values` and `pixel_mask`.

Note that the docs of `ViltProcessor` can be found [here](https://huggingface.co/docs/transformers/master/en/model_doc/vilt#transformers.ViltProcessor).

We also add the labels. This is a PyTorch tensor of shape `(num_labels,)` that contains the soft encoded vector.
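In plain Python, converting the sparse `labels`/`scores` pair into a dense soft target, and recovering the labels again, looks roughly like the sketch below. `num_labels=6` is a hypothetical value; in the notebook it is `len(config.id2label)`, and the notebook performs the same steps with `torch.zeros` in `VQADataset` and `torch.nonzero` when inspecting examples.

```python
from typing import List


def to_soft_target(labels: List[int], scores: List[float], num_labels: int) -> List[float]:
    # Dense vector of zeros with each annotated label set to its soft score
    target = [0.0] * num_labels
    for label, score in zip(labels, scores):
        target[label] = score
    return target


def nonzero_labels(target: List[float]) -> List[int]:
    # Indices of the labels that carry a non-zero score
    return [i for i, value in enumerate(target) if value != 0.0]


toy_target = to_soft_target(labels=[4, 1], scores=[1.0, 1 / 3], num_labels=6)
print(toy_target)
print(nonzero_labels(toy_target))  # [1, 4]
```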

In [None]:
import torch
from PIL import Image

class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, questions, annotations, processor):
        self.questions = questions
        self.annotations = annotations
        self.processor = processor

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        # get image + text (assumes questions and annotations are stored in the same order)
        annotation = self.annotations[idx]
        question = self.questions[idx]
        # convert to RGB, since some COCO images are grayscale
        image = Image.open(id_to_filename[annotation['image_id']]).convert("RGB")
        text = question['question']

        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        # remove batch dimension
        for k, v in encoding.items():
            encoding[k] = v.squeeze()
        # add labels as a soft-encoded target vector
        labels = annotation['labels']
        scores = annotation['scores']
        targets = torch.zeros(len(config.id2label))
        for label, score in zip(labels, scores):
            targets[label] = score
        encoding["labels"] = targets

        return encoding

In [None]:
from transformers import ViltProcessor

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

dataset = VQADataset(questions=questions[:100],
                     annotations=annotations[:100],
                     processor=processor)

In [None]:
dataset[0].keys()

In [None]:
processor.decode(dataset[0]['input_ids'])

In [None]:
# flatten (rather than squeeze) so that a single label still yields a list
labels = torch.nonzero(dataset[0]['labels']).flatten().tolist()

In [None]:
[config.id2label[label] for label in labels]

## Define model

Here we define a `ViltForQuestionAnswering` model, with the weights of the body initialized from `dandelin/vilt-b32-mlm` and a randomly initialized classification head. We also move it to the GPU, if one is available.

In [None]:
from transformers import ViltForQuestionAnswering

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-mlm",
                                                 id2label=config.id2label,
                                                 label2id=config.label2id)
model.to(device)

Next, we create a corresponding PyTorch DataLoader, which allows us to iterate over the dataset in batches.

Because the processor resizes images while roughly preserving their aspect ratio, images in a batch do not necessarily end up the same size. We therefore use the `pad` method of the processor's image processor (`processor.image_processor.pad`) to pad the pixel values of a batch to a common size and create a corresponding pixel mask: a tensor of shape `(batch_size, height, width)` indicating which pixels are real (1) and which are padding (0).
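The padding idea can be illustrated with a minimal pure-Python sketch on single-channel "images" stored as nested lists. `pad_and_mask` is a hypothetical stand-in, not the image processor's actual implementation: it pads each image with zeros at the bottom and right up to the largest height and width in the batch, and builds a mask with 1 for real pixels and 0 for padding.

```python
from typing import List, Tuple

Image2D = List[List[float]]


def pad_and_mask(images: List[Image2D]) -> Tuple[List[Image2D], List[List[List[int]]]]:
    # Target size: the largest height and width found in the batch
    max_h = max(len(img) for img in images)
    max_w = max(len(row) for img in images for row in img)
    padded, masks = [], []
    for img in images:
        h, w = len(img), len(img[0])
        # Pad each row on the right with zeros, then add all-zero rows at the bottom
        new_img = [row + [0.0] * (max_w - w) for row in img]
        new_img += [[0.0] * max_w for _ in range(max_h - h)]
        # 1 where the pixel is real, 0 where it is padding
        mask = [[1 if (r < h and c < w) else 0 for c in range(max_w)] for r in range(max_h)]
        padded.append(new_img)
        masks.append(mask)
    return padded, masks


toy_batch = [
    [[0.1, 0.2, 0.3]],  # 1 x 3 image
    [[0.4], [0.5]],     # 2 x 1 image
]
toy_padded, toy_masks = pad_and_mask(toy_batch)
print(toy_masks[0])  # [[1, 1, 1], [0, 0, 0]]
print(toy_masks[1])  # [[1, 0, 0], [1, 0, 0]]
```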

In [None]:
from torch.utils.data import DataLoader

def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  encoding = processor.image_processor.pad(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  batch['token_type_ids'] = torch.stack(token_type_ids)
  batch['pixel_values'] = encoding['pixel_values']
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = torch.stack(labels)

  return batch

train_dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=4, shuffle=True)

**Question**: What is the purpose of the collation function?

**Answer**: *TODO*

Let's verify a batch:

In [None]:
batch = next(iter(train_dataloader))

In [None]:
for k,v in batch.items():
  print(k, v.shape)

We can verify a given training example by visualizing the image:

In [None]:
from PIL import Image
import numpy as np

image_mean = processor.image_processor.image_mean
image_std = processor.image_processor.image_std

batch_idx = 1

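# invert the normalization: multiply by the std, then add the mean
unnormalized_image = (batch["pixel_values"][batch_idx].numpy() * np.array(image_std)[:, None, None]) + np.array(image_mean)[:, None, None]
# clip to [0, 1] before scaling so the uint8 cast cannot overflow
unnormalized_image = (np.clip(unnormalized_image, 0, 1) * 255).astype(np.uint8)
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)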
Image.fromarray(unnormalized_image)
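Normalization maps each rescaled channel value `x` (in [0, 1]) to `(x - mean) / std`, so undoing it means multiplying by `std` and then adding `mean`, in that order. The stdlib sketch below checks this round trip on made-up per-channel statistics (`toy_mean`, `toy_std`, not ViLT's actual values) and clips to [0, 1] before scaling to 8-bit, mirroring the truncating `uint8` cast used above.

```python
from typing import List


def normalize(pixel: List[float], mean: List[float], std: List[float]) -> List[float]:
    # Per-channel normalization applied after rescaling to [0, 1]
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


def unnormalize(pixel: List[float], mean: List[float], std: List[float]) -> List[float]:
    # Inverse: multiply by std first, then add mean
    return [x * s + m for x, m, s in zip(pixel, mean, std)]


def to_uint8(pixel: List[float]) -> List[int]:
    # Clip to [0, 1], scale to [0, 255] and truncate, like .astype(np.uint8)
    return [int(min(max(x, 0.0), 1.0) * 255) for x in pixel]


toy_mean, toy_std = [0.4, 0.5, 0.6], [0.2, 0.25, 0.3]  # hypothetical statistics
toy_rgb = [0.8, 0.2, 0.6]
toy_restored = unnormalize(normalize(toy_rgb, toy_mean, toy_std), toy_mean, toy_std)
print(to_uint8(toy_restored))
```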

In [None]:
processor.decode(batch["input_ids"][batch_idx])

In [None]:
# flatten (rather than squeeze) so that a single label still yields a list
labels = torch.nonzero(batch['labels'][batch_idx]).flatten().tolist()

In [None]:
[config.id2label[label] for label in labels]

## Train a model

Finally, let's train a model! Note that I haven't done any hyperparameter tuning as this notebook was just created for demo purposes. I'd recommend going over the [ViLT paper](https://arxiv.org/abs/2102.03334) for better training settings. You may also wish to use PyTorch Lightning for real training setups.

I just wanted to illustrate that the model can overfit this small (100-example) dataset.

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(25):  # loop over the dataset multiple times
    print(f"Epoch: {epoch}")
    for batch in tqdm(train_dataloader):
        # move the inputs to the device
        batch = {k: v.to(device) for k, v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()

## Inference

Let's verify whether the model has actually learned something:

In [None]:
example = dataset[0]
print(example.keys())

In [None]:
processor.decode(example['input_ids'])

In [None]:
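# switch off dropout for inference
model.eval()

# add batch dimension + move to GPU
example = {k: v.unsqueeze(0).to(device) for k, v in example.items()}

# forward pass (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**example)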

Note that we need to apply a sigmoid activation on the logits since the model is trained using binary cross-entropy loss (as it frames VQA as a multi-label classification task).
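Concretely, a sigmoid turns each logit into an independent probability, so several answers can score highly at once and the probabilities need not sum to 1, unlike a softmax. The stdlib sketch below uses a made-up four-answer vocabulary, logits, and soft targets to show the activation, a top-k readout like the one in the next cell, and a mean binary cross-entropy against soft targets. It illustrates the idea only; it is not the exact loss code inside `ViltForQuestionAnswering`.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce_with_soft_targets(logits: List[float], targets: List[float]) -> float:
    # Mean binary cross-entropy: each answer is an independent yes/no decision
    total = 0.0
    for z, t in zip(logits, targets):
        p = sigmoid(z)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(logits)


def top_k(probs: List[float], k: int) -> List[Tuple[int, float]]:
    # (index, probability) pairs of the k most probable answers, highest first
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


toy_id2label = {0: "down", 1: "table", 2: "skateboard", 3: "yes"}  # hypothetical
toy_logits = [2.0, 0.5, -1.0, -3.0]
toy_targets = [1.0, 2 / 3, 1 / 3, 0.0]

toy_probs = [sigmoid(z) for z in toy_logits]
print("sum of probabilities:", sum(toy_probs))  # not constrained to 1
for idx, prob in top_k(toy_probs, 2):
    print(f"{prob:.3f} {toy_id2label[idx]}")
print("BCE:", bce_with_soft_targets(toy_logits, toy_targets))
```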

In [None]:
logits = outputs.logits
# independent per-answer probabilities (multi-label setting)
probabilities = torch.sigmoid(logits)

probs, classes = torch.topk(probabilities, 5)
for prob, class_idx in zip(probs.squeeze().tolist(), classes.squeeze().tolist()):
    print(prob, model.config.id2label[class_idx])

As you can see, the model requires more work to be effective. This is expected, since we trained on only 100 examples from the validation split and never used the actual VQAv2 training set.

However, this initial checkpoint is sufficient for validating the end-to-end pipeline.

## Interpretability
This section is adapted from the following tutorial: https://captum.ai/tutorials/Multimodal_VQA_Interpret

In [None]:
!pip install -q captum

In [None]:
import copy, os, sys

# Replace <PROJECT-DIR> placeholder with your project directory path
PROJECT_DIR = '/content'

In [None]:
import threading
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F


import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.colors import LinearSegmentedColormap

from captum.attr import (
    IntegratedGradients,
    LayerIntegratedGradients,
    TokenReferenceBase,
    configure_interpretable_embedding_layer,
    remove_interpretable_embedding_layer,
    visualization
)
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper

In [None]:
model.eval()

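Download the example images from [here](https://github.com/pytorch/captum/blob/master/tutorials/img/vqa/) and place them in an `img/vqa` directory in your Google Drive; the code below expects them under `/content/drive/MyDrive/img/vqa/`.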

To explain text features, Integrated Gradients must attribute to the token embeddings rather than to the token indices. Integrated Gradients is a gradient-based attribution method, and gradients cannot be computed with respect to integer indices.

Hence, we have two options:
1. "Patch" the model's embedding layer and the corresponding inputs. To patch the layer, use the `configure_interpretable_embedding_layer`^ method, which wraps the layer you give it with an identity function that accepts an embedding and outputs an embedding. You can patch the inputs, i.e. obtain the embeddings for a set of indices, with `model.wrapped_layer.indices_to_embeddings(indices)`.
2. Use the equivalent layer attribution algorithm (`LayerIntegratedGradients` in our case) with the utility class `ModelInputWrapper`. `ModelInputWrapper` wraps your model and feeds each of its inputs through a separate identity layer, allowing you to use layer attribution methods on inputs. You can access the layer associated with an input named `"foo"` via the `ModuleDict` `wrapped_model.input_maps["foo"]`.

^ NOTE: For option (1), after finishing interpretation it is important to call `remove_interpretable_embedding_layer` which removes the Interpretable Embedding Layer that we added for interpretation purposes and sets the original embedding layer back in the model.

We recommend option (2) because it is more flexible and easier to use: it allows your model to apply any preprocessing to the indices tensor, and you don't have to modify your inputs.

See documentation on `LayerIntegratedGradients`:
https://captum.ai/api/layer.html#layer-integrated-gradients
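For intuition, Integrated Gradients assigns each input feature `(x_i - x'_i)` times the average gradient of the model output along the straight line from a baseline `x'` to the input `x`. Captum approximates that integral with `n_steps` points (its default scheme is Gauss-Legendre); the stdlib sketch below uses a simple midpoint Riemann sum instead, on a hypothetical two-feature function `toy_f` whose gradient is written by hand. The attributions should sum to roughly `toy_f(x) - toy_f(x')` (the completeness property), which is what Captum's convergence delta checks.

```python
from typing import Callable, List


def toy_f(x: List[float]) -> float:
    # Toy "model": a smooth function of two features
    return x[0] ** 2 + 3.0 * x[0] * x[1]


def toy_grad_f(x: List[float]) -> List[float]:
    # Hand-written gradient of toy_f
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]


def integrated_gradients(
    grad: Callable[[List[float]], List[float]],
    x: List[float],
    baseline: List[float],
    n_steps: int = 50,
) -> List[float]:
    # Midpoint Riemann sum of the gradient along baseline + alpha * (x - baseline)
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad(point)):
            avg_grad[i] += g / n_steps
    # Scale the average gradient by the input-baseline difference
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


toy_x, toy_baseline = [1.0, 2.0], [0.0, 0.0]
toy_attr = integrated_gradients(toy_grad_f, toy_x, toy_baseline)
print(toy_attr)
# Completeness check: the two numbers should (nearly) agree
print(sum(toy_attr), toy_f(toy_x) - toy_f(toy_baseline))
```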

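The wrapper below accepts the inputs as separate tensor arguments, builds the keyword-argument dictionary that `ViltForQuestionAnswering` expects, and returns only the logits, so that Captum receives a plain output tensor.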

In [None]:
class ViltModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, pixel_values, attention_mask=None, pixel_mask=None, token_type_ids=None):
        # Construct the dictionary expected by the original model's forward method
        inputs = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "attention_mask": attention_mask,
            "pixel_mask": pixel_mask,
            "token_type_ids": token_type_ids,
        }
        # Filter out None values
        inputs = {k: v for k, v in inputs.items() if v is not None}
        outputs = self.model(**inputs)
        # Return logits directly
        return outputs.logits

# model wrapper to handle dictionary args
wrap_model = ViltModelWrapper(model)

# wrap the inputs into layers for attribution
m = ModelInputWrapper(wrap_model)

**Implementation**:

Here you should create a `LayerIntegratedGradients` object using the wrapped module `m`.

Try reading the documentation, the linked tutorial, and inspecting the wrapped module `m` for help.

In [None]:
# TODO: create layer integrated gradients object
# attr = LayerIntegratedGradients(...)

Define the default colormap used for the image visualizations:

In [None]:
default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                 [(0, '#ffffff'),
                                                  (0.25, '#252b36'),
                                                  (1, '#000000')], N=256)

Define a few test images for model interpretation purposes:

In [None]:
images = ['/content/drive/MyDrive/img/vqa/siamese.jpg',
          '/content/drive/MyDrive/img/vqa/elephant.jpg',
          '/content/drive/MyDrive/img/vqa/zebra.jpg']

In [None]:
def vilt_interpret(image_filename, questions, targets):
    img = Image.open(image_filename)

    for question, target in zip(questions, targets):
        inputs = processor(img, question, padding="max_length", truncation=True, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        additional_args = (
            inputs.get('attention_mask').to(device),
            inputs.get('pixel_mask').to(device),
            inputs.get('token_type_ids').to(device)
        )

        image_baseline = torch.zeros_like(inputs["pixel_values"])
        text_baseline = torch.full_like(inputs["input_ids"], processor.tokenizer.pad_token_id)

        logits = m(**inputs)
        predictions = torch.sigmoid(logits)
        probs, classes = predictions.topk(1, dim=1)
        predicted_class = classes.item()
        predicted_prob = probs.item()
        true_class = config.label2id[target]

        attributions = attr.attribute(inputs=(inputs["input_ids"], inputs["pixel_values"]),
                                      baselines=(text_baseline, image_baseline),
                                      target=true_class,   # Use true class as target
                                      additional_forward_args=additional_args,
                                      n_steps=30)

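        # Attributions follow the order of the layers given to LayerIntegratedGradients;
        # this code assumes index 0 is the image layer and index 1 the text (embedding) layer.
        # Sum over the hidden dimension to get one score per token, then L2-normalize.
        text_attributions = attributions[1].sum(dim=2).squeeze(0)
        text_attributions_norm = text_attributions / text_attributions.norm()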
        tokens = processor.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

        # Visualize text attributions
        vis_data_record = visualization.VisualizationDataRecord(
            word_attributions=text_attributions_norm.tolist(),
            pred_prob=predicted_prob,
            pred_class=config.id2label[predicted_class],
            true_class=config.id2label[true_class],  # the expected answer passed in `targets`
            attr_class=target,  # The class for which attributions were computed
            attr_score=attributions[1].sum().item(),
            raw_input_ids=tokens,
            convergence_score=0.0
        )
        visualization.visualize_text([vis_data_record])

        # Visualize image attributions
        original_im_mat = np.transpose(inputs['pixel_values'].squeeze(0).cpu().detach().numpy(), (1, 2, 0))

        # Get mean and std for image processor
        image_mean = np.array(processor.image_processor.image_mean)
        image_std = np.array(processor.image_processor.image_std)

        # Denormalize using the provided mean and std
        original_im_mat = (original_im_mat * image_std) + image_mean  # Apply denormalization
        original_im_mat = np.clip(original_im_mat, 0, 1)  # Ensure values are in [0, 1] range
        original_im_mat = (original_im_mat * 255).astype(np.uint8)  # Scale to [0, 255] range and convert to integers

        # Reshape attributions tensor
        attributions_img = np.transpose(attributions[0].squeeze(0).cpu().detach().numpy(), (1, 2, 0))

        visualization.visualize_image_attr_multiple(attributions_img, original_im_mat,
                                                    ["original_image", "heat_map"], ["all", "absolute_value"],
                                                    titles=["Original Image", "Attribution Magnitude"],
                                                    cmap=default_cmap,
                                                    show_colorbar=True)

        print('Image Contributions: ', attributions[0].sum().item())
        print('Text Contributions: ', attributions[1].sum().item())
        print('Total Contribution: ', attributions[0].sum().item() + attributions[1].sum().item())

**Question**: What is the purpose of the baseline inputs?

**Answer**: *TODO*

Now generate visualizations for each of the example images.

In [None]:
# Index of the image in the test set; change it to play with different test images/samples.
image_idx = 1 # elephant
vilt_interpret(images[image_idx], [
    "what is on the picture",
    "what color is the elephant",
    "where is the elephant"
], ['elephant', 'gray', 'zoo'])

In [None]:
image_idx = 0 # cat

vilt_interpret(images[image_idx], [
    "what is on the picture",
    "what color are the cat's eyes",
    "is the animal in the picture a cat or a fox",
    "what color is the cat",
    "how many ears does the cat have",
], ['cat', 'blue', 'cat', 'white and brown', '2'])

In [None]:
image_idx = 2 # zebra

vilt_interpret(images[image_idx], [
    "what is on the picture",
    "what color are the zebras",
    "how many zebras are on the picture",
    "where are the zebras"
], ['zebra', 'black and white', '2', 'zoo'])