Image captioning is the task of generating a textual description (a caption) for a given image. A common real-world application is assisting visually impaired people: describing the images they encounter helps them navigate different situations and makes visual content more accessible.

This guide shows how to:
1. Fine-tune an image captioning model.
2. Use the fine-tuned model for inference.

# Libraries

In [None]:
%pip install -q transformers datasets evaluate
%pip install -q jiwer

In [None]:
from datasets import load_dataset
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np


# Load Data

In [None]:
# Load the Pokémon BLIP captions dataset.
# Each example is an (image, caption) pair.
ds = load_dataset("lambdalabs/pokemon-blip-captions")

# Inspect the dataset; note the two features, "image" and "text"
ds

In [None]:
# Split into train and test sets
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
test_ds = ds["test"]

In [None]:
# Visualise examples from the training set
def plot_images(images, captions):
    """Show the images side by side, each titled with its wrapped caption."""
    plt.figure(figsize=(20, 20))
    for i, (image, caption) in enumerate(zip(images, captions)):
        ax = plt.subplot(1, len(images), i + 1)
        # Wrap long captions at 12 characters so neighbouring titles do not overlap
        ax.set_title("\n".join(wrap(caption, 12)))
        ax.imshow(image)
        ax.axis("off")


sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)
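
The caption wrapping inside `plot_images` can be examined on its own. The sketch below uses a made-up caption (an assumption, not taken from the dataset) to show how `wrap` breaks text into lines of at most 12 characters before they are joined into a multi-line title.

```python
from textwrap import wrap

# Hypothetical caption, used only to illustrate the wrapping step.
caption = "a drawing of a green pokemon with red eyes"

# wrap() breaks on whitespace and keeps each line at or under the given width.
lines = wrap(caption, 12)

# Joining with newlines yields the multi-line string passed to the plot title.
title = "\n".join(lines)

print(lines)
print(title)
```

A narrow width such as 12 keeps each title roughly within its subplot when five images share one row; a wider figure or fewer images would allow a larger value.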