Image captioning is the task of predicting a caption for a given image. A common real-world application is helping visually impaired people navigate through different situations. By describing images to people who cannot see them, image captioning improves content accessibility.

This guide shows how to:
1. Fine-tune an image captioning model.
2. Use the fine-tuned model for inference.

# Libraries

In [None]:
pip install transformers datasets evaluate -q
pip install jiwer -q

In [None]:
import torch
import requests
import numpy as np
from PIL import Image
from textwrap import wrap
from evaluate import load
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForCausalLM, TrainingArguments, Trainer


# Load Data

In [None]:
# Fetch the Pokémon BLIP captions dataset, a collection of {image, caption} pairs.
ds = load_dataset(path="lambdalabs/pokemon-blip-captions")

# Show the resulting DatasetDict; every example carries an `image` and a `text` feature.
ds

In [None]:
# Split into train and test sets.
# A fixed seed keeps the split, and therefore the evaluation metrics, reproducible
# across runs.
# NOTE: `ds` is rebound to the split DatasetDict, so re-running this cell on its own
# would split the already-reduced training set again. Re-run the loading cell first.
ds = ds["train"].train_test_split(test_size=0.1, seed=42)
train_ds = ds["train"]
test_ds = ds["test"]

In [None]:
# Visualise examples from the training set
def plot_images(images, captions, wrap_width=12):
    """Show images side by side in a single row, each titled with its caption.

    Args:
        images: sequence of images accepted by ``plt.imshow`` (e.g. numpy arrays).
        captions: caption strings, one per image (extra captions are ignored).
        wrap_width: maximum number of characters per title line (default 12).

    Raises:
        ValueError: if there are fewer captions than images.
    """
    num_images = len(images)
    if len(captions) < num_images:
        raise ValueError(
            f"Expected at least {num_images} captions, got {len(captions)}."
        )

    plt.figure(figsize=(20, 20))
    # subplot positions are 1-based, hence start=1
    for position, (image, caption) in enumerate(zip(images, captions), start=1):
        plt.subplot(1, num_images, position)
        plt.title("\n".join(wrap(caption, wrap_width)))
        plt.imshow(image)
        plt.axis("off")


# Read the first five examples with one batched slice. Indexing train_ds[i] separately
# for the images and again for the captions would decode every row twice.
sample_batch = train_ds[:5]
sample_images_to_visualize = [np.array(image) for image in sample_batch["image"]]
sample_captions = sample_batch["text"]
plot_images(sample_images_to_visualize, sample_captions)

# Preprocessing

In [None]:
# The GIT base checkpoint ships a single processor that prepares both images and text.
checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
# The dataset has two modalities (image and text) that both need preprocessing.
# The transform below lets the processor resize and rescale the images and tokenize
# the captions.
def transforms(example_batch):
    """Preprocess a batch of image-caption examples for caption fine-tuning.

    Args:
        example_batch: batch mapping with ``"image"`` and ``"text"`` columns, as
            passed in by ``Dataset.set_transform``.

    Returns:
        The processor output (image pixel values plus tokenized captions, padded to
        the tokenizer's max length) with an added ``"labels"`` entry.
    """
    inputs = processor(
        images=list(example_batch["image"]),
        text=list(example_batch["text"]),
        padding="max_length",
    )
    # The target sequence for captioning is the caption itself, so the labels reuse
    # the tokenized input ids.
    # NOTE(review): padding positions are not masked to -100 here, so they take part
    # in the loss. Masking them would also require compute_metrics to map -100 back
    # to the pad id before decoding.
    inputs["labels"] = inputs["input_ids"]
    return inputs


# Register the transform so it is applied on the fly whenever examples are accessed.
train_ds.set_transform(transforms)
test_ds.set_transform(transforms)

# Evaluation

In [None]:
# Image captioning models are typically evaluated with the ROUGE score or Word Error Rate (WER).
# In this tutorial, we use WER for evaluation.

wer = load("wer")


def compute_metrics(eval_pred):
    """Compute the teacher-forced Word Error Rate of predicted vs. reference captions.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``. ``logits`` may
            also arrive as a tuple whose first element holds the logits.

    Returns:
        dict with a single ``"wer_score"`` entry.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        # Keep only the logits if the model returned extra outputs alongside them.
        logits = logits[0]

    pad_token_id = processor.tokenizer.pad_token_id
    # Trainer pads ragged eval batches with -100, which the tokenizer cannot decode.
    labels = np.where(labels == -100, pad_token_id, labels)

    predicted = logits.argmax(-1)
    # The logits can be longer than the labels because extra positions precede the
    # text (assumed to be the image tokens GIT prepends; confirm for other
    # checkpoints). Drop that prefix, then shift by one, because in a causal LM the
    # logits at position i predict token i + 1.
    num_prefix_positions = predicted.shape[1] - labels.shape[1]
    predicted = predicted[:, num_prefix_positions:-1]
    target = labels[:, 1:]
    # Do not score predictions made where the reference is only padding.
    predicted = np.where(target == pad_token_id, pad_token_id, predicted)

    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}

# Training

In [None]:
# Start fine-tuning from the pre-trained causal-LM weights of the same checkpoint.
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
# Define training arguments.
# Take the last path component so that hub ids ("org/name"), bare names and local
# paths all yield a usable model name (indexing [1] fails when there is no "/").
model_name = checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-5,
    num_train_epochs=50,
    # fp16 mixed precision is only supported on GPU-class devices; fall back to
    # full precision when CUDA is unavailable instead of failing at construction.
    fp16=torch.cuda.is_available(),
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    # Keep the raw "image"/"text" columns: the set_transform preprocessing needs them.
    remove_unused_columns=False,
    label_names=["labels"],
    load_best_model_at_end=True,
)

# Bundle the model, training arguments, datasets and metric function into a Trainer.
# NOTE(review): Trainer accumulates full-vocabulary logits for compute_metrics during
# evaluation, which can use a lot of host memory on large eval sets.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

# Launch fine-tuning; the returned training summary is displayed as the cell output.
trainer.train()

# Inference

In [None]:
# Load an image for captioning.
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
# Use a timeout so a stalled download cannot hang the notebook. Check the status so
# that an HTTP error page is never handed to PIL as if it were image data.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
image

In [None]:
# Process the image and move the tensors to the model's device. Trainer may have placed
# the model on a GPU, so the inputs must live on the same device as its weights.
inputs = processor(images=image, return_tensors="pt").to(model.device)
pixel_values = inputs.pixel_values

In [None]:
# Call generate() on the model to predict a caption.
# The model may still be in training mode after trainer.train(). Switch to eval mode
# so that dropout is disabled and the generated caption is deterministic.
model.eval()
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)