*Insert your name here*

For this notebook a GPU environment is recommended. You can use the NVIDIA T4 GPU provided by the Google Colab free plan, keeping in mind its terms of service and usage limits.

## Image captioning

In this exercise we will experiment with something a little different, moving beyond NLP and building a small vision-language model whose goal is to automatically produce image captions. As in sequence-to-sequence machine translation architectures, we rely on an encoder-decoder model that 1) encodes the image and 2) generates the caption. This time, though, the data are multi-modal: images on one side and text on the other. The two sources of information are combined by a cross-attention mechanism inside the decoder.

Once the model is trained we can generate the textual description and visualize the attention weights. Let's begin!



In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from collections import Counter
from tqdm import tqdm

import torch
from torch import Tensor
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from itertools import chain

We use the Flickr 8k Dataset, a popular dataset containing 8k images (taken from Flickr) with 5 captions each.

WARNING: the dataset size is 1.04 GB

In [None]:
!pip install kagglehub

In [None]:
import kagglehub

PATH = kagglehub.dataset_download("adityajn105/flickr8k")
print("Path to dataset files:", PATH)

In the downloaded folder you will find `captions.txt`, which lists the captions together with the corresponding image filenames, and an `Images` subfolder containing the JPEG images.

Let's load the `captions.txt` file first:

In [None]:
data = pd.read_csv(os.path.join(PATH, 'captions.txt'), sep=',', header=0)
data.head()

The 'image' column contains only the image filename, so we need to join it with PATH and the `Images` subfolder to be able to retrieve the image.

We also need to tokenize the text in the 'caption' column.

Let's keep things simple for now and treat each word as a single token.
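As a minimal sketch of this word-level scheme, assuming nothing more than lowercasing and splitting on whitespace, a hypothetical helper `simple_tokenize` could look like this (punctuation such as the final `.` stays as a separate token):

```python
def simple_tokenize(caption: str) -> list:
    """Hypothetical helper: lowercase a caption and split it on whitespace."""
    return caption.lower().split()


example = "A child in a pink dress is climbing up a set of stairs in an entry way ."
tokens = simple_tokenize(example)
print(tokens[:6])  # ['a', 'child', 'in', 'a', 'pink', 'dress']
```

On a DataFrame, the same function could be applied to a whole column with `Series.apply`.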

In [None]:
data['image'] = data['image'].apply(lambda x: os.path.join(PATH, 'Images', x))

# Your code here
# create a new column 'tokens' containing lowercased tokens from the caption



assert(data['tokens'][0][:6] == ['a','child','in','a','pink','dress'])

For simplicity we reduce the dataset by keeping only one caption per image (the first of each group of five consecutive rows).

In [6]:
images = [x for i, x in enumerate(data['image']) if i % 5 == 0]
sentences = [x for i, x in enumerate(data['tokens']) if i % 5 == 0]

Since this is a very closed-ended generation task, we derive our (small) vocabulary directly from the dataset, keeping only words that appear at least 5 times.

With `token2index` we can convert tokens (words) into indices, and back again with `index2token`.
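The frequency threshold and the index layout used by `Vocab` can be illustrated on a toy corpus. This is a sketch under assumptions: the corpus is invented and the threshold is lowered to 2 so that something survives, while the special tokens occupy indices 0 to 3 exactly as in the class.

```python
from collections import Counter

# Toy corpus (invented); the notebook uses the tokenized captions and a threshold of 5
corpus = [["a", "dog", "runs"], ["a", "dog", "jumps"], ["a", "cat", "sleeps"]]
MIN_FREQ = 2  # illustrative threshold for this tiny corpus

freq = Counter(tok for sent in corpus for tok in sent)
words = [w for w, c in freq.items() if c >= MIN_FREQ]

# Special tokens take indices 0-3, regular words start at 4 (same layout as Vocab)
token2index = {"<pad>": 0, "<start>": 1, "<eos>": 2, "<unk>": 3}
token2index.update({w: i for i, w in enumerate(words, 4)})
index2token = {i: w for w, i in token2index.items()}

# Out-of-vocabulary words fall back to <unk>
ids = [token2index.get(t, token2index["<unk>"]) for t in ["a", "cat", "runs"]]
print(ids)                            # [4, 3, 3]
print([index2token[i] for i in ids])  # ['a', '<unk>', '<unk>']
```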

In [None]:
class Vocab:
    def __init__(self, dataset):
        self.word_freq = Counter()
        for tokens in dataset:
            self.word_freq.update(tokens)

        words = [word for word in self.word_freq.keys() if self.word_freq[word] >= 5]
        self.token2index = {word: idx for idx, word in enumerate(words, 4)}
        self.token2index['<pad>'] = 0
        self.token2index['<start>'] = 1
        self.token2index['<eos>'] = 2
        self.token2index['<unk>'] = 3
        self.index2token = {value: word for word, value in self.token2index.items()}
        self.vocab_size = len(self.index2token)

    def __len__(self):
        return self.vocab_size

vocab = Vocab(sentences)
len(vocab)

We limit the sequence length to 50, truncating or padding the input as needed.

In the following code we add the special token `<start>` at the beginning of each sentence and `<eos>` at the end; `<unk>` represents words not included in the vocabulary, and `<pad>` fills the remaining positions.
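A sketch of the same encoding on a toy vocabulary, with a hypothetical `encode` helper and a short `max_len` for readability. Truncating to `max_len - 2` leaves room for `<start>` and `<eos>`, and padding with `max_len - len(ids)` zeros is equivalent to the `MAX_LEN-2 - len(sent)` used in the cell.

```python
MAX_LEN = 8  # small value for illustration; the notebook uses 50
PAD, START, EOS, UNK = 0, 1, 2, 3
token2index = {"a": 4, "dog": 5, "runs": 6}  # toy vocabulary (assumed)


def encode(tokens, max_len=MAX_LEN):
    """Hypothetical helper: <start> + ids + <eos>, then pad to exactly max_len."""
    tokens = tokens[:max_len - 2]  # leave room for <start> and <eos>
    ids = [START] + [token2index.get(t, UNK) for t in tokens] + [EOS]
    return ids + [PAD] * (max_len - len(ids))


print(encode(["a", "dog", "runs", "fast"]))  # [1, 4, 5, 6, 3, 2, 0, 0]
```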

In [8]:
MAX_LEN = 50

captions = []
for sent in sentences:
    sent = sent[:MAX_LEN-2]
    toks = [vocab.token2index['<start>']] + \
            [vocab.token2index.get(x, vocab.token2index['<unk>']) for x in sent] + \
            [vocab.token2index['<eos>']] + \
            [vocab.token2index['<pad>']] * (MAX_LEN-2 - len(sent))
    captions.append(toks)

We can finally split the preprocessed dataset into training, validation and test sets

In [None]:
train_images = images[:-300]
train_captions = captions[:-300]

eval_images = images[-300:-100]
eval_captions = captions[-300:-100]

test_images = images[-100:]
test_captions = captions[-100:]

print('Training set ',len(train_images))
print('Validation set ',len(eval_images))
print('Test set ',len(test_images))

Finally, we create our custom dataset.

With this code we convert RGB images into tensors and apply ImageNet normalization. For the textual part we convert tokens into Long tensors and create the padding mask and the labels used to train the decoder. Here the labels are just a copy of the input tokens in which padded positions are replaced by -100, the default `ignore_index` of `nn.CrossEntropyLoss`, so they are ignored by the loss.
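The label construction can be illustrated in plain Python on a toy caption (the token ids are invented). In PyTorch the same effect can be obtained with `caption.masked_fill(mask, -100)`.

```python
PAD_ID, IGNORE_INDEX = 0, -100

caption = [1, 4, 5, 6, 2, 0, 0, 0]            # <start> a dog runs <eos> <pad> <pad> <pad>
pad_mask = [tok == PAD_ID for tok in caption]  # True = padding position
labels = [IGNORE_INDEX if is_pad else tok for tok, is_pad in zip(caption, pad_mask)]

print(pad_mask)  # [False, False, False, False, False, True, True, True]
print(labels)    # [1, 4, 5, 6, 2, -100, -100, -100]
```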

In [10]:
transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])

class CaptionDataset(Dataset):
    def __init__(self, images, captions, transform=None):
        self.images = images
        self.captions = captions
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(self.images[idx]).convert('RGB')
        img = np.array(img)
        img = cv2.resize(img, (256, 256))

        assert img.shape == (256, 256, 3)
        assert np.max(img) <= 255.

        img = torch.FloatTensor(img / 255.)
        img = img.permute(2, 0, 1)  # 3, H, W

        if self.transform is not None:
            img = self.transform(img)

        caption = torch.LongTensor(self.captions[idx])

        mask = caption == 0  # True at <pad> positions

        # labels are a copy of the caption with padding replaced by -100 (ignored by the loss)
        labels = caption.masked_fill(mask, -100)

        return img, caption, labels, mask

train_dataset = CaptionDataset(train_images, train_captions, transform)
eval_dataset = CaptionDataset(eval_images, eval_captions, transform)
test_dataset = CaptionDataset(test_images, test_captions, transform)

Let's see an example

In [None]:
def decode_caption(tokens, vocab):
    dec_caption = [vocab.index2token[x] for x in tokens.numpy() if x != 0]
    return " ".join(dec_caption)

for image, caption, _, _ in eval_dataset:
    print(decode_caption(caption, vocab))
    plt.imshow(image.permute(1, 2, 0))
    plt.show()
    break

The normalized image is not displayed correctly because its values fall outside the [0, 1] range that `imshow` expects for float RGB data, so matplotlib clips them. The caption, instead, is enclosed between `<start>` and `<eos>`, with `<unk>` replacing rare words.
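To see why, here is a small NumPy sketch, assuming a synthetic 3x4x4 image with values in [0, 1], of the per-channel normalization `(x - mean) / std` applied by `transforms.Normalize` and of its exact inverse `x * std + mean`, which is what is needed to display the image again:

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)  # per-channel ImageNet mean
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)   # per-channel ImageNet std

img = np.linspace(0.0, 1.0, 48).reshape(3, 4, 4)  # synthetic image in [0, 1], channels first

normalized = (img - MEAN) / STD     # same formula as transforms.Normalize
restored = normalized * STD + MEAN  # inverse transform, back to [0, 1]

print(normalized.min() < 0, normalized.max() > 1)  # True True: values leave [0, 1]
print(np.allclose(restored, img))                  # True
```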

Let's create a configuration class for our model and set the other hyperparameters

In [12]:
class Config:
    embd_pdrop = 0.2
    resid_pdrop = 0.2
    attn_pdrop = 0.2
    head_pdrop = 0.2
    n_layer = 6
    n_head = 8
    n_embd = 1024
    max_len = MAX_LEN

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

config = Config(len(vocab))

BATCH_SIZE = 64
LR = 1e-4
EPOCHS = 4

Now we can create our data loaders. Note that the test loader uses a batch size of 1, since at inference time we generate captions one at a time.

In [13]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)  # no need to shuffle for evaluation
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True)  # use the held-out test split

It's now time to implement the model architecture, starting with the image encoder.

For this purpose we rely on a ResNet-50 CNN pre-trained on ImageNet, fine-tuning only a subset of its layers (convolutional blocks 2 through 4).

The encoder converts the (N, 3, H, W) RGB image into an (N, 14x14, 2048) tensor, where N is the batch size and 2048 is the feature dimension.
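Where does 14x14 come from? A small sketch, assuming the 256x256 inputs produced by `CaptionDataset` and the standard torchvision ResNet-50 layout (stride-2 `conv1` and max-pool, then a stride-2 3x3 convolution at the start of `layer2`, `layer3` and `layer4`), computes the spatial size of the backbone output. The bin rule used for adaptive pooling (start = floor(i * in / out), end = ceil((i + 1) * in / out)) is our reading of how PyTorch splits the input, so treat it as an assumption.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


size = 256
size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1   -> 128
size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool -> 64
for _ in range(3):  # layer2, layer3, layer4 each halve the size -> 32, 16, 8
    size = conv_out(size, kernel=3, stride=2, padding=1)
print(size)  # 8


def adaptive_bins(in_size, out_size):
    """(start, end) input indices averaged by each output cell of adaptive average pooling."""
    return [((i * in_size) // out_size, -(-((i + 1) * in_size) // out_size))
            for i in range(out_size)]


bins = adaptive_bins(size, 14)
print(bins[:4])                # [(0, 1), (0, 2), (1, 2), (1, 3)]
print(len(bins) * len(bins))   # 196 image positions seen by the decoder
```

With 256x256 inputs the backbone therefore produces an 8x8 map, and pooling it to 14x14 repeats input cells (an upsampling) rather than averaging several of them.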

In [14]:
class ImageEncoder(nn.Module):
    def __init__(self, enc_image_size=14):
        super(ImageEncoder, self).__init__()
        self.enc_image_size = enc_image_size

        # pretrained ImageNet ResNet-50
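        # `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the equivalent weight set
        resnet = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)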

        # remove linear and pool layers
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        # resize image to fixed size using adaptive pool to allow input images of variable size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((enc_image_size, enc_image_size))
        self.fine_tune()

    def fine_tune(self):
        for param in self.resnet.parameters():
            param.requires_grad = False

        # if fine-tuning, fine-tune only convolutional blocks 2 through 4
        for child in list(self.resnet.children())[5:]:
            for param in child.parameters():
                param.requires_grad = True

    def forward(self, images):
        out = self.resnet(images)
        out = self.adaptive_pool(out)  # N, 2048, 14, 14

        batch, channel = out.size(0), out.size(1)
        out = out.view(batch, channel, -1).permute(0, 2, 1)  # N, 14x14, 2048
        return out

Unlike the Transformer encoder, the decoder must 1) apply causal masking, so that attention cannot look at future tokens, and 2) add a cross-attention mechanism to combine images and text.

Refer to the official documentation to complete this part: [https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
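To make the role of cross-attention concrete, here is a minimal single-head NumPy sketch. Assumptions: toy dimensions stand in for `n_embd` and for the image feature size, the projections are random, and there are no biases, multiple heads or output projection as in `nn.MultiheadAttention`. Queries come from the text, keys and values from the image, and the boolean `mask` follows the PyTorch convention in which `True` marks a position that may not be attended.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v, mask=None):
    """Single-head scaled dot-product attention; mask=True marks blocked positions."""
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (L_q, L_k)
    if mask is not None:
        scores = np.where(mask, -np.inf, scores)
    weights = softmax(scores)  # each row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
L, S = 5, 196            # caption tokens, image regions (14x14)
d_model, d_img = 16, 32  # toy stand-ins for n_embd and the image feature size

text = rng.standard_normal((L, d_model))  # decoder hidden states
image = rng.standard_normal((S, d_img))   # encoder output for one image

# Queries are projected from the text, keys/values from the image features
W_q = rng.standard_normal((d_model, d_model))
W_k = rng.standard_normal((d_img, d_model))
W_v = rng.standard_normal((d_img, d_model))

out, weights = attention(text @ W_q, image @ W_k, image @ W_v)
print(out.shape, weights.shape)  # (5, 16) (5, 196)
```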

In [None]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.n_embd)  # after self-attention
        self.norm2 = nn.LayerNorm(config.n_embd)  # after cross-attention
        self.norm3 = nn.LayerNorm(config.n_embd)  # after the feed-forward network

        # masked self-attention over the caption tokens
        self.self_attn = nn.MultiheadAttention(config.n_embd, config.n_head, batch_first=True)
        self.self_attn_drop = nn.Dropout(config.attn_pdrop)
        # cross-attention: queries from the text, keys/values from the image features
        self.cross_attn = nn.MultiheadAttention(config.n_embd, config.n_head, kdim=2048, vdim=2048, batch_first=True)
        self.cross_attn_drop = nn.Dropout(config.attn_pdrop)

        self.linear1 = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.linear_drop = nn.Dropout(config.resid_pdrop)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(4 * config.n_embd, config.n_embd)
        self.linear_drop2 = nn.Dropout(config.resid_pdrop)

    def forward(self, x, mem, att_msk, pad_mask):
        '''
        :param x: textual embedding (N, L, n_embd)
        :param mem: visual embedding (N, 14x14, 2048)
        :param att_msk: causal mask (L, L), True = masked
        :param pad_mask: padding mask (N, L), True at padded positions
        '''

        # post-norm residual sub-blocks: self-attention, cross-attention, feed-forward
        x = self.norm1(x + self._sa_block(x, att_msk, pad_mask))

        enc_att, att_weight = self._mha_block(x, mem)
        x = self.norm2(x + enc_att)

        x = self.norm3(x + self._ff_block(x))
        return x, att_weight

    def _sa_block(self, x, att_mask, pad_mask):
        '''
        self-attention
        '''

        # Your code here
        # call self_attn with the correct parameters, keeping in mind that it's a self-attention
        # don't forget to pass the causal attention mask and the padding mask, and to set is_causal correctly
        

        return self.self_attn_drop(x)

    def _mha_block(self, x, mem):
        '''
        cross-attention
        '''

        # Your code here
        # call cross_attn with the correct parameters
        # please return the attention weights (att_weight) too
        
        
        return self.cross_attn_drop(x), att_weight

    def _ff_block(self, x: Tensor) -> Tensor:
        '''
        Feed forward network
        '''

        x = self.linear2(self.linear_drop(self.activation(self.linear1(x))))
        return self.linear_drop2(x)

**Why do we need to specify `kdim` and `vdim` instead of relying on the default values (`kdim = vdim = embed_dim`)?**

*Your answer here*

In this block we assumed that the causal attention mask, which prevents attention from seeing future tokens in the sequence, was already available; it's now time to create it.

As we have seen, causal attention prevents each token from attending to future tokens in the sequence, so it can only look at the previous ones (and at itself).
We can parallelize the computation with a triangular table like this:

<table>
<tr><td>THIS</td><td></td><td></td><td></td></tr>
<tr><td>THIS</td><td>IS</td><td></td><td></td></tr>
<tr><td>THIS</td><td>IS</td><td>A</td><td></td></tr>
<tr><td>THIS</td><td>IS</td><td>A</td><td>TEST</td></tr>
</table>

and the corresponding attention mask:

<table>
<tr><td>0</td><td>1</td><td>1</td><td>1</td></tr>
<tr><td>0</td><td>0</td><td>1</td><td>1</td></tr>
<tr><td>0</td><td>0</td><td>0</td><td>1</td></tr>
<tr><td>0</td><td>0</td><td>0</td><td>0</td></tr>
</table>

in which ones mark the (masked) positions that must be ignored when computing the attention.
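The same mask can be built in one call with NumPy's `np.triu` (PyTorch provides an analogous `torch.triu`). The sketch below also applies it to a matrix of uniform attention scores: blocked positions are set to `-inf` before the softmax, so each row only spreads its weight over the current and previous tokens.

```python
import numpy as np

L = 4
# True strictly above the diagonal = future tokens that must not be attended
causal_mask = np.triu(np.ones((L, L), dtype=bool), k=1)
print(causal_mask.astype(int))
# [[0 1 1 1]
#  [0 0 1 1]
#  [0 0 0 1]
#  [0 0 0 0]]

scores = np.zeros((L, L))                        # uniform raw scores, for clarity
scores = np.where(causal_mask, -np.inf, scores)  # blocked positions get -inf
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
# row i now spreads its weight evenly over positions 0..i
print(weights.round(2))
```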

Let's start with an example

In [None]:
H, W = 4, 4

# Your code here
# 1) create a HxW triangular matrix like in the example
# 2) convert the matrix into a boolean matrix (True = masked token)


print(mask)
assert(torch.equal(mask,torch.tensor([[False,  True,  True,  True],[False, False,  True,  True],[False, False, False,  True],[False, False, False, False]])))

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])


We now have everything we need for the Decoder.

One last thing, though! Instead of a classification head, this time we need a language-model head that produces scores over the whole vocabulary. Rather than choosing among a few classes, we want to predict any word in the vocabulary, so the final tensor must have shape (N, L, V), where N is the batch size, L the sequence length and V the vocabulary size (the model uses `batch_first=True`).

Let's project the (N, L, E) embedding (with E the embedding size) into (N, L, V) with a Linear layer without bias.
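Shape-wise, a bias-free linear layer computes `x @ W.T` on the last dimension only, with `W` stored as (out_features, in_features) like the weight of `nn.Linear`. A NumPy sketch with toy sizes (all assumed) shows that the batch and sequence dimensions pass through unchanged:

```python
import numpy as np

N, L, E, V = 2, 7, 16, 30  # toy batch size, sequence length, embedding size, vocabulary size

rng = np.random.default_rng(0)
hidden = rng.standard_normal((N, L, E))  # decoder output (batch_first layout)
W = rng.standard_normal((V, E))          # weight stored as (out_features, in_features), no bias

logits = hidden @ W.T                    # applied to the last dimension only
print(logits.shape)                      # (2, 7, 30)

next_word_scores = logits[:, -1, :]      # scores for the word after the last position
print(next_word_scores.argmax(axis=-1).shape)  # (2,)
```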

In [None]:
# Your code here
# create the language model head with a linear layer without the bias term
# use the config class we created before to get the correct parameters for the network

lm_head = 

test = torch.rand(100, 4, 1024)
out = lm_head(test)
assert(tuple(out.shape) == (100, 4, 1205))
assert('bias' not in {x for x, _ in lm_head.named_parameters()})

In [None]:
class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        # embedding layer
        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        # positional embedding layer
        self.pos_emb = nn.Embedding(config.max_len, config.n_embd)

        self.norm_emb = nn.LayerNorm(config.n_embd)
        self.drop_emb = nn.Dropout(config.embd_pdrop)

        self.blocks = nn.ModuleList(Block(config) for _ in range(config.n_layer))

        # Your code here
        # add your language model head (same as before)
        self.lm_head = 

        self.drop_head = nn.Dropout(config.head_pdrop)

        self._init_weights()

    def _init_weights(self):
        self.emb.weight.data.uniform_(-0.1, 0.1)
        self.lm_head.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input_ids, img, pad_msk):
        device = input_ids.device

        # Your code here
        # create the boolean attention mask with LxL shape (where L is the input sequence length)
        # don't forget to load it on the same device as the data and the model!
        # Hint: take the correct dimension directly from the input_ids shape

        mask = 

        x = self.emb(input_ids)  # N, L, E
        position_ids = torch.arange(x.size(1), dtype=torch.long).unsqueeze(0).to(device)  # 1, L
        pos_emb = self.pos_emb(position_ids)  # 1, L, E
        x = x + pos_emb  # N, L, E
        x = self.drop_emb(self.norm_emb(x))

        for block in self.blocks:
            x, att_weights = block(x, img, mask, pad_msk)

        lm_logits = self.lm_head(self.drop_head(x))  # N, L, V
        return lm_logits, att_weights

We are finally ready to initialize the model, optimizer and loss

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

encoder = ImageEncoder()
encoder = encoder.to(device)

decoder = Decoder(config)
decoder = decoder.to(device)

params = chain(encoder.parameters(), decoder.parameters())
optim = AdamW([x for x in params if x.requires_grad], lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=-100)  # skip padded label positions (-100 is also the default)

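We train the model by comparing the generated sequence with the true labels shifted by one position: the prediction made at position t is compared with the token at position t+1.
<table>
<tr><td>INPUT</td><td>&lt;start&gt;</td><td>this</td><td>is</td><td>a</td><td>...</td></tr>
<tr><td>TARGET</td><td>this</td><td>is</td><td>a</td><td>test</td><td>...</td></tr>
</table>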


In other words, the model always predicts the next word in the sequence, so to check whether a prediction is correct it is sufficient to compare it with a copy of the input shifted by one position. Keep in mind that until now `labels` was just a clone of the input tokens (with padding set to -100): we drop the last logit and the first label to apply the shift, then flatten the batch so that every position of every sequence is scored.

The loss is a simple cross-entropy between predictions and labels.
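The shift and the ignored padding can be checked with a NumPy re-implementation. This is a sketch, not the PyTorch code: `shifted_cross_entropy` is a hypothetical helper that mirrors what `train` does with `nn.CrossEntropyLoss` (mean over the non-ignored targets). With uniform logits every valid position costs log(V), which gives an easy sanity check.

```python
import numpy as np

IGNORE_INDEX = -100


def shifted_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Next-token loss: position t is scored against label t+1; ignored labels are skipped."""
    logits = logits[:, :-1, :].reshape(-1, logits.shape[-1])  # (N*(L-1), V)
    labels = labels[:, 1:].reshape(-1)                         # (N*(L-1),)
    keep = labels != ignore_index
    logits, labels = logits[keep], labels[keep]
    # log-softmax with max subtraction for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


# One caption: <start> a dog <eos> <pad>, with the padding already set to -100
labels = np.array([[1, 4, 5, 2, IGNORE_INDEX]])
V = 6
logits = np.zeros((1, 5, V))  # uniform predictions at every position
loss = shifted_cross_entropy(logits, labels)
print(np.isclose(loss, np.log(V)))  # True: uniform scores give a loss of log(V)
```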

Here are our training and validation loops:

In [29]:
def train(data_loader, encoder, decoder, optimizer, criterion, device):
    encoder = encoder.train()
    decoder = decoder.train()
    loss_epoch = 0

    for img, cap, lab, msk in tqdm(data_loader):
        img = img.to(device)
        cap = cap.to(device)
        lab = lab.to(device)
        msk = msk.to(device)
        enc_img = encoder(img)
        lm_logits, _ = decoder(cap, enc_img, msk)

        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        shift_labels = lab[..., 1:].contiguous()
        shift_labels = shift_labels.view(-1)
        loss = criterion(shift_logits, shift_labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_epoch += loss.item()

    return loss_epoch / len(data_loader)


def validate(data_loader, encoder, decoder, criterion, device):
    encoder = encoder.eval()
    decoder = decoder.eval()
    loss_epoch = 0

    with torch.no_grad():
        for img, cap, lab, msk in tqdm(data_loader):
            img = img.to(device)
            cap = cap.to(device)
            lab = lab.to(device)
            msk = msk.to(device)
            enc_img = encoder(img)
            lm_logits, _ = decoder(cap, enc_img, msk)

            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = lab[..., 1:].contiguous()
            shift_labels = shift_labels.view(-1)
            loss = criterion(shift_logits, shift_labels)
            loss_epoch += loss.item()

    return loss_epoch / len(data_loader)

Finally we can launch the training, but it could take a while: around 4 min per epoch

In [None]:
train_loss, valid_loss = [], []
least_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"[INFO]: Epoch {epoch + 1} of {EPOCHS}")
    train_epoch_loss = train(train_loader, encoder, decoder, optim, criterion, device)
    valid_epoch_loss = validate(eval_loader, encoder, decoder, criterion, device)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    print(f"Training loss: {train_epoch_loss}")
    print(f"Validation loss: {valid_epoch_loss}")
    print('-' * 50)

Let's plot the loss

In [None]:
epochs = list(range(1, EPOCHS + 1))
plt.plot(epochs, train_loss, label="train")
plt.plot(epochs, valid_loss, label="eval")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()

Now that the model is trained, let's use it for inference, that is, to generate new captions for the test set.

We start from an empty sentence (containing only the `<start>` token) and greedily select the most probable next word at every step, stopping when we generate the `<eos>` token or when we reach the `MAX_LEN` sequence length.
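A plain-Python sketch of the greedy loop, with a hypothetical `toy_next_token_scores` function standing in for the decoder (it always prefers a scripted sequence of token ids). `greedy_decode` stops at `<eos>` or once the sequence reaches `max_len` tokens:

```python
START, EOS = 1, 2
MAX_LEN = 10


def toy_next_token_scores(prefix):
    """Hypothetical stand-in for the decoder: scores over a 6-word vocabulary."""
    scripted = [4, 5, 3, EOS]  # the 'caption' this toy model prefers
    step = len(prefix) - 1     # tokens generated so far
    target = scripted[step] if step < len(scripted) else EOS
    return [1.0 if i == target else 0.0 for i in range(6)]


def greedy_decode(next_token_scores, max_len=MAX_LEN):
    ids = [START]
    while len(ids) < max_len:
        scores = next_token_scores(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax = greedy choice
        ids.append(next_id)
        if next_id == EOS:  # stop as soon as <eos> is produced
            break
    return ids


print(greedy_decode(toy_next_token_scores))  # [1, 4, 5, 3, 2]
```

Greedy decoding is the simplest strategy; beam search would keep several candidate captions per step, at a higher computational cost.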

In [None]:
encoder.eval()
decoder.eval()
with torch.no_grad():
    for img, _, _, _ in tqdm(test_loader):
        img = img.to(device)
        enc_img = encoder(img)

        # Your code here
        # create a Tensor containing the token index corresponding to <start>
        # the tensor, with Long type, should have shape (1, S), where 1 is the batch size and S the sequence length (1 in this case)
        # keep in mind the tensor must be loaded onto the same device as the model
        # Hint: use vocab to make the conversion from string to token id
        input_ids = 

        for _ in range(config.max_len):
            lm_logits, att_weights = decoder(input_ids, enc_img, None)

            # lm_logits shape is (1, S, V) with S the sequence length and V the vocabulary size
            next_item = lm_logits[0, -1].topk(1)[1] # index of the highest-scoring token (greedy choice of the next word)
            next_item = next_item.unsqueeze(0) # (1, 1)

            # Your code here
            # keep generating the sequence by concatenating the predicted word to the input sequence (overwriting input_ids)
            input_ids = 

            # check if the predicted token is the end of the sequence
            if next_item.item() == vocab.token2index['<eos>']:
                break

        tokens = input_ids[0].cpu().tolist()

        # Your code here
        # 1. convert the list of token indices into token strings using the vocabulary
        # 2. join the tokens to form a sentence and print the result
        
        

        print(sentence)
        break

We can also plot the attention weights returned by the decoder over the original image. To display the image, we use the following function to reverse the normalization process.

In [None]:
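def unormalize(tensor, mean=None, std=None):
    '''
    Reverse the normalization of a (N, 3, H, W) batch so that it can be displayed.
    With mean/std: invert transforms.Normalize exactly. Without: min-max rescale each image to [0, 1].
    '''
    tensor = tensor.clone()  # work on a copy instead of modifying the input in place
    if mean is not None and std is not None:
        mean = torch.tensor(mean, dtype=tensor.dtype, device=tensor.device).view(1, -1, 1, 1)
        std = torch.tensor(std, dtype=tensor.dtype, device=tensor.device).view(1, -1, 1, 1)
        return torch.clip(tensor * std + mean, min=0, max=1)

    b, c, h, w = tensor.shape
    tensor = tensor.view(b, -1)
    tensor -= tensor.min(1, keepdim=True)[0]
    tensor /= tensor.max(1, keepdim=True)[0]
    return tensor.view(b, c, h, w)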

img = unormalize(img)[0].permute(1, 2, 0).cpu()  # 256, 256, 3
plt.imshow(img, interpolation='nearest')

Let's put everything together and run our inference

In [None]:
encoder.eval()
decoder.eval()
with torch.no_grad():
    for img, _, _, _ in tqdm(test_loader):
        img = img.to(device)
        enc_img = encoder(img)

        # Your code here
        # add here your code for generating the sentence

        

        
        print(sentence)

        img = unormalize(img)[0].permute(1, 2, 0).cpu()  # 256, 256, 3

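        resize = T.Resize(size=(img.size(0), img.size(1)))  # h, w (named to avoid shadowing the dataset `transform`)
        att = att_weights[0].reshape(-1, 14, 14).cpu()  # L, 14, 14: one attention map per generated token
        att = resize(att).permute(1, 2, 0)  # h, w, L
        att = att.mean(dim=2)  # average the attention maps over the generated tokens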

        plt.imshow(img, interpolation='nearest')
        plt.imshow(att, interpolation='bilinear', alpha=0.5)
        plt.show()
        plt.close()

Besides the predicted caption, we plot the attention weights of the last layer superimposed on the original image.
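The overlay can be sketched in NumPy with random attention weights (an assumption; the real ones come from the last cross-attention layer). Here the per-token 14x14 maps are averaged first and then upsampled to 256x256 with nearest-neighbour indexing, whereas the notebook resizes with `T.Resize` before averaging:

```python
import numpy as np

GRID, OUT = 14, 256  # 14x14 attention grid, 256x256 image

rng = np.random.default_rng(0)
L = 6                                        # number of generated caption tokens (assumed)
raw = rng.random((L, GRID * GRID))
att = raw / raw.sum(axis=-1, keepdims=True)  # each row sums to 1, like softmax weights

att_map = att.reshape(L, GRID, GRID).mean(axis=0)  # average over tokens -> (14, 14)

# Nearest-neighbour upsampling: output pixel p reads grid cell p * GRID // OUT
idx = np.arange(OUT) * GRID // OUT
overlay = att_map[np.ix_(idx, idx)]  # (256, 256), ready for plt.imshow(..., alpha=0.5)
print(overlay.shape)  # (256, 256)
```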

Despite only a few epochs of training and a small architecture, the model seems able to describe the images to some extent, focusing its attention on the most relevant regions.