In [None]:
!pip install -qU transformers datasets evaluate jiwer

# Image Captioning

**Image captioning** is the task of predicting a caption for a given image. It can help to improve content accessibility for people by describing images to them.

## Load the Pokemon BLIP captions dataset

Load a dataset that consists of `{image-caption}` pairs.

In [None]:
from datasets import load_dataset

# Fetch the image/caption pairs; the result is indexed by split name further down.
ds = load_dataset("lambdalabs/pokemon-blip-captions")

In [None]:
# Inspect the loaded dataset object (displayed as the cell's output).
ds

The dataset has two features, `image` and `text`.

Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample one caption from the available ones during training.

In [None]:
# Hold out 10% of the examples for evaluation. The fixed seed keeps the split
# identical across kernel restarts, so evaluation numbers from different runs
# are computed on the same held-out examples.
# NOTE: this cell rebinds `ds`; re-run the loading cell before re-running this one,
# otherwise the already-reduced train split is split again.
ds = ds['train'].train_test_split(test_size=0.1, seed=42)
train_ds = ds['train']
test_ds = ds['test']

In [None]:
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, captions):
    """Show images side by side in a single row, each titled with its caption.

    Args:
        images: sequence of images accepted by ``plt.imshow``.
        captions: sequence of strings; ``captions[k]`` is the title of ``images[k]``.

    Captions are wrapped at 12 characters per line so that long titles do not
    run into the neighbouring panel. Axes are hidden for every panel.
    """
    n_panels = len(images)
    plt.figure(figsize=(20, 20))
    for idx, image in enumerate(images):
        # Subplot positions are 1-based.
        plt.subplot(1, n_panels, idx + 1)
        wrapped_title = '\n'.join(wrap(captions[idx], 12))
        plt.title(wrapped_title)
        plt.imshow(image)
        plt.axis('off')
    plt.show()


# Preview the first five training examples: images as arrays, captions as text.
sample_images_to_viz = [
    np.array(train_ds[idx]['image'])
    for idx in range(5)
]
sample_captions = [
    train_ds[idx]['text']
    for idx in range(5)
]

plot_images(images=sample_images_to_viz, captions=sample_captions)

## Preprocess the dataset

Since the dataset has two modalities (text and image), the preprocessing pipeline will preprocess the images and the captions.

In [None]:
from transformers import AutoProcessor

# GIT stands for GenerativeImage2Text; the same checkpoint id is reused for the model.
checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)

This processor will internally preprocess the image (which includes resizing and pixel scaling) and tokenize the caption.

In [None]:
def transforms(example_batch):
    """Turn a batch of raw examples into model inputs.

    Args:
        example_batch: mapping with an ``'image'`` entry and a ``'text'`` entry,
            each an iterable with one item per example.

    Returns:
        The processor output for the batch, with an extra ``'labels'`` entry that
        refers to the same values as ``'input_ids'``.
    """
    batch_images = list(example_batch['image'])
    batch_captions = list(example_batch['text'])

    encoded = processor(
        images=batch_images,
        text=batch_captions,
        padding="max_length",
    )
    # The caption token ids double as the training targets.
    encoded.update(labels=encoded['input_ids'])
    return encoded

# Register the preprocessing on both splits so it is applied when examples are accessed.
train_ds.set_transform(transform=transforms)
test_ds.set_transform(transform=transforms)

## Load a base model

Load the `microsoft/git-base` checkpoint into an `AutoModelForCausalLM`:

In [None]:
from transformers import AutoModelForCausalLM

# Load the pretrained weights from the same `checkpoint` id used for the processor.
model = AutoModelForCausalLM.from_pretrained(checkpoint)

## Evaluate

Image captioning models are typically evaluated with the **ROUGE score** or **Word Error Rate (WER)**.
* **ROUGE**: Recall-Oriented Understudy for Gisting Evaluation is a set of metrics for evaluating automatic summarization and machine translation in NLP. The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation.
* **WER**: Word Error Rate is a common metric of the performance of an automatic speech recognition system. It is the metric computed in this notebook; lower values are better.

In [None]:
from evaluate import load
import torch

# Word-error-rate metric from the `evaluate` library.
wer = load('wer')


def compute_metrics(eval_pred):
    """Score predictions against reference captions with word error rate.

    Args:
        eval_pred: a pair that unpacks into ``(logits, labels)``.

    Returns:
        A dict with a single key ``'wer'`` holding the computed score.
    """
    logits, labels = eval_pred
    # Greedy choice: index of the highest score along the last axis
    # (presumably the vocabulary axis).
    predicted_ids = logits.argmax(-1)

    reference_texts = processor.batch_decode(labels, skip_special_tokens=True)
    predicted_texts = processor.batch_decode(predicted_ids, skip_special_tokens=True)

    return {
        'wer': wer.compute(
            predictions=predicted_texts,
            references=reference_texts,
        )
    }

## Train

In [None]:
from transformers import TrainingArguments, Trainer

# Name the run after the part of the checkpoint id that follows the first '/'.
model_name = checkpoint.split('/')[1]
print(model_name)

training_args = TrainingArguments(
    # Output location; nothing is uploaded to the Hub.
    output_dir=f"{model_name}-pokemon",
    push_to_hub=False,
    # Optimisation settings.
    learning_rate=5e-5,
    num_train_epochs=50,
    fp16=True,
    # 32 examples per device per step, accumulated over 2 steps.
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    # Evaluate, checkpoint and log on the same 50-step cadence.
    eval_strategy='steps',
    eval_steps=50,
    save_strategy='steps',
    save_steps=50,
    logging_steps=50,
    # Keep at most three checkpoints and restore the best one when training ends.
    save_total_limit=3,
    load_best_model_at_end=True,
    # Presumably required because the on-the-fly transform reads the raw columns
    # -- TODO confirm.
    remove_unused_columns=False,
    label_names=['labels'],
)

In [None]:
# Wire the model, hyperparameters, both dataset splits and the WER metric together.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()

## Inference

In [None]:
from PIL import Image
import requests

from io import BytesIO

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"

# Download the whole file before decoding. The timeout keeps the cell from hanging
# on a stalled connection, and raise_for_status() reports HTTP failures directly
# instead of letting PIL fail on a non-image response body.
response = requests.get(url, timeout=30)
response.raise_for_status()

image = Image.open(BytesIO(response.content))
image = image.convert('RGB')
image

Prepare the image for the model.

In [None]:
from accelerate.test_utils.testing import get_backend

# get_backend() yields a 3-tuple; only the first entry (the device) is needed here.
device, _, _ = get_backend()

# Pass the image by keyword, as in the training-time preprocessing, so the call
# does not depend on the positional order of the processor's parameters.
inputs = processor(
    images=image,
    return_tensors='pt'
).to(device)
pixel_values = inputs.pixel_values

Call `generate` and decode the predictions

In [None]:
# The model may still be in training mode after `trainer.train()`; switch to
# evaluation mode so that layers such as dropout do not make captions vary
# between runs.
model.eval()

# Generate at most 50 tokens for the image, then turn the ids back into text.
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True
)
generated_caption