In [205]:
import torch
import torch.nn as nn
import os

# DataLoaders
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Output
import matplotlib.pyplot as plt
from tqdm import tqdm

In [206]:
# Central place for every on-disk location the notebook reads from.
# All paths are resolved relative to the directory the kernel was started in.
cwd = os.getcwd()
DATA_DIR = os.path.join(cwd, "data/coco")
ANNOTATIONS_DIR = os.path.join(DATA_DIR, "annotations")

# split name -> file holding the tokenized captions for that split
annotations_tokenized_index = {
    split: os.path.join(ANNOTATIONS_DIR, f"{split}_tokenized.pt")
    for split in ("train", "val")
}
# split name -> directory holding the images for that split
images_index = {
    split: os.path.join(DATA_DIR, f"{split}2017") for split in ("train", "val")
}

## Define image augmentations.

In [207]:
# Per-channel normalisation statistics shared by both pipelines
# (these are the widely used ImageNet mean/std values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: light random crop / flip / colour augmentation, then tensor + normalise.
train_image_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.9, 1.0), ratio=(0.9, 1.1)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Validation: deterministic resize + centre crop to the same 224x224 input size.
val_image_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# split name -> transform pipeline applied to that split's images
image_transform_index = dict(train=train_image_transform, val=val_image_transform)

## Create DataLoaders.

In [208]:
class CocoDataset(Dataset):
    """Map-style dataset yielding ``(transformed image, caption tensor)`` pairs.

    Each entry of the loaded annotations is indexed with the keys ``"caption"``
    (converted to a ``torch.long`` tensor) and ``"image_id"`` (formatted as a
    zero-padded 12-character file stem, ``<image_id>.jpg``, inside
    ``images_path``).

    Args:
        images_path: Directory containing the ``.jpg`` images.
        annotations_path: File readable by ``torch.load`` that holds the
            indexable collection of annotations.
        image_transform: Callable applied to the RGB PIL image before it is
            returned.
    """

    def __init__(self, images_path, annotations_path, image_transform):
        super().__init__()

        self.images_path = images_path
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load annotation files
        # that you produced yourself or otherwise trust.
        self.annotations = torch.load(annotations_path, weights_only=False)

        self.image_transform = image_transform

    def __len__(self):
        """Return the number of annotations (one sample per caption entry)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image_transformed, caption_tensor)`` where the first item
            is whatever ``image_transform`` returns for the RGB image and the
            second is the caption as a ``torch.long`` tensor.
        """
        annotation = self.annotations[idx]

        caption_tensor = torch.tensor(annotation["caption"], dtype=torch.long)

        image_id = annotation["image_id"]
        image_path = os.path.join(self.images_path, f"{image_id:012}.jpg")

        # Use the image as a context manager so the underlying file handle is
        # released deterministically, including when decoding raises.
        # ``convert`` returns a new in-memory image, so it stays valid after
        # the file is closed.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")

        image_transformed = self.image_transform(image)

        return image_transformed, caption_tensor

In [209]:
# Tokenizer metadata; later cells read its "<PAD>" and "vocab_size" entries.
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only
# load a tokenizer_info.pt file that you trust.
TOKENIZER_INFO_PATH = os.path.join(DATA_DIR, "tokenizer_info.pt")
tokenizer_info = torch.load(TOKENIZER_INFO_PATH, weights_only=False)


def coco_collate_fn(batch):
    """Collate ``(image, caption)`` samples into two batched tensors.

    Images are stacked along a new leading dimension. Captions may differ in
    length, so they are right-padded with the tokenizer's ``<PAD>`` id up to
    the longest caption in the batch (batch dimension first).

    Args:
        batch: Sequence of ``(image, caption)`` pairs.

    Returns:
        tuple: ``(images_batch, captions_batch)``.
    """
    image_list = [image for image, _ in batch]
    caption_list = [caption for _, caption in batch]

    pad_id = tokenizer_info["<PAD>"]

    return (
        torch.stack(image_list, dim=0),
        pad_sequence(caption_list, batch_first=True, padding_value=pad_id),
    )

In [210]:
BATCH_SIZE = 32
dataloaders = {}

# Build one DataLoader per split; only the training split is shuffled.
for split in ("train", "val"):
    dataset = CocoDataset(
        images_path=images_index[split],
        annotations_path=annotations_tokenized_index[split],
        image_transform=image_transform_index[split],
    )
    dataloaders[split] = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=split == "train",
        collate_fn=coco_collate_fn,
        # Chosen for an i5 12400f (6 cores / 12 threads); adjust to your own machine.
        num_workers=10,
        pin_memory=True,
    )

## Define the model.

In [211]:
class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch embeddings with learned positions.

    A convolution whose kernel size equals its stride projects each
    non-overlapping ``patch_size`` x ``patch_size`` patch of a 3-channel image
    to a ``d_model``-dimensional vector. A learnable positional parameter
    (initialised to zeros) with one entry per patch is added to the result.

    Args:
        d_model: Embedding size of each patch.
        image_size: Side length used to size the positional parameter, which
            holds ``(image_size // patch_size) ** 2`` positions.
        patch_size: Side length of each square patch.
    """

    def __init__(self, d_model, image_size, patch_size):
        super().__init__()

        self.image_to_patch_projections = nn.Conv2d(
            in_channels=3,
            out_channels=d_model,
            kernel_size=patch_size,
            stride=patch_size,
        )

        patches_per_side = image_size // patch_size
        self.pos_encoding = nn.Parameter(
            torch.zeros(1, patches_per_side**2, d_model)
        )

    def forward(self, x):
        """Embed a batch of images.

        The convolution output has its last two (spatial) dimensions merged
        into one patch axis, which is then swapped with the channel axis so
        that the embedding dimension comes last.
        """
        patch_grid = self.image_to_patch_projections(x)
        patch_sequence = torch.flatten(patch_grid, start_dim=-2, end_dim=-1).transpose(
            1, 2
        )

        return patch_sequence + self.pos_encoding

In [None]:
class ImageCaptioner(nn.Module):

    def __init__(
        self,
        image_size,
        patch_size,
        vocab_size,
        max_caption_len,
        d_model,
        nhead,
        dim_feedforward,
        num_layers,
        PAD_IDX,
    ):
        super().__init__()

        # Encoder
        self.patch_embedding = PatchEmbedding(d_model, image_size, patch_size)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model, nhead, dim_feedforward, batch_first=True
            ),
            num_layers,
        )

        # Decoder
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.tgt_pos_encoding = nn.Parameter(torch.zeros((1, max_caption_len, d_model)))

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead, dim_feedforward, batch_first=True
            ),
            num_layers,
        )

        self.project_to_vocab_size = nn.Linear(d_model, vocab_size)

        # Store the PAD_IDX for on-the-fly tgt_key_padding_mask creation.
        self.PAD_IDX = PAD_IDX

    def forward(self, images, captions):

        # Encode the images.
        embedded_patches = self.patch_embedding(images)
        encoded_images = self.encoder(embedded_patches)

        # Embed the captions.
        embedded_captions = (
            self.tgt_embedding(captions)
            + self.tgt_pos_encoding[:, : captions.shape[1], :]
        )

        # General tgt_mask which is causal.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            captions.shape[1], captions.device, torch.bool
        )

        # Generate tgt_key_padding_mask according to the batch of captions.
        tgt_key_padding_mask = captions == self.PAD_IDX

        # Encode the captions.
        encoded_captions = self.decoder(
            tgt=embedded_captions,
            memory=encoded_images,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            tgt_is_causal=True,
        )

        # Project the encoded captions to the size of the vocabulary to get a distribution over the vocabulary for each token position.
        distributions = self.project_to_vocab_size(encoded_captions)

        return distributions

## Train the model.

In [213]:
# Model hyperparameters.
image_size = 224
patch_size = 16
max_caption_len = 50
d_model = 512
nhead = 8
dim_feedforward = 2048
num_layers = 6

# Values dictated by the tokenizer.
vocab_size = tokenizer_info["vocab_size"]
PAD_IDX = tokenizer_info["<PAD>"]


# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("You are using device: %s" % device)

You are using device: cuda


In [214]:
# Instantiate the captioning model from the hyperparameters configured above.
model = ImageCaptioner(
    image_size=image_size,
    patch_size=patch_size,
    vocab_size=vocab_size,
    max_caption_len=max_caption_len,
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    num_layers=num_layers,
    PAD_IDX=PAD_IDX,
)

model = model.to(device)

# The collate function right-pads captions with <PAD>, so the label tensors
# contain padding positions. ignore_index excludes those targets from the loss
# (and its mean), so the model is not trained to predict padding.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [221]:
# Smoke test: pull a single batch from the training DataLoader to check that
# the data pipeline (dataset, transforms, collate function) runs end to end.
batches = iter(dataloaders["train"])
images, captions = next(batches)

# The disabled lines below sketch one forward/backward pass: the decoder input
# drops the last token of each caption and the labels drop the first, so each
# position is scored against the token that follows it.

# images = images.to(device)
# captions = captions.to(device)

# x_tgt = captions[:,:-1]
# labels = captions[:,1:]

# distributions = model(images, x_tgt)


# loss = criterion(distributions.flatten(0,-2), labels.flatten())
# loss.backward()