In [None]:
# Fetch spaCy's English model (requested via the "en" shortcut name) so that
# spacy.load can find it later in the notebook.
# NOTE(review): spaCy v3 replaced shortcut names with full package names such
# as en_core_web_sm — confirm against the installed spaCy version.
!python -m spacy download en

In [None]:
import os  # when loading file paths
import pandas as pd  # for lookup in annotation file
import spacy  # for tokenizer
from PIL import Image  # Load img
import statistics
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models

In [None]:
# Let cuDNN benchmark convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the English pipeline by its full package name. The "en" shortcut only
# resolves on spaCy v2 (spaCy v3 raises OSError for it), whereas the
# en_core_web_sm package installed by `spacy download en` loads on both.
# Only the pipeline's tokenizer is used by this notebook.
spacy_eng = spacy.load("en_core_web_sm")

# Dataset


We want to convert text -> numerical values:
1. We need a Vocabulary mapping each word to an index
2. We need to set up a PyTorch dataset to load the data
3. We need to set up padding for every batch (all examples in a batch should have the same seq_len) and set up the dataloader

In [None]:
class Vocabulary:
    """Two-way mapping between lower-cased word tokens and integer indices.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Regular words receive indices from 4 upwards, in
    the order in which they reach the frequency threshold.

    Args:
        freq_threshold: Number of occurrences a word needs (within one
            ``build_vocabulary`` call) before it is added to the vocabulary.
            Values below 1 are treated as 1, i.e. every word is kept.
    """

    def __init__(self, freq_threshold):
        # itos: index -> string, stoi: string -> index (kept in sync).
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the vocabulary size, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` with the module-level spaCy tokenizer and lower-case each token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list`` to the vocabulary.

        Frequencies are counted per call. The method may be called more than
        once: new words are appended after the existing entries and words that
        are already known keep their index.

        Args:
            sentence_list: Iterable of raw sentences (strings).
        """
        frequencies = {}
        # Continue after the highest index in use instead of restarting at 4,
        # so a repeated call cannot overwrite entries from an earlier call.
        idx = len(self.itos)
        # A threshold below 1 could never be matched by a count that starts
        # at 1, which would silently yield an empty vocabulary.
        threshold = max(1, self.freq_threshold)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # `==` (not `>=`) registers the word exactly once, at the
                # moment its count reaches the threshold.
                if frequencies[word] == threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Tokenize ``text`` and map each token to its index.

        Tokens that are not in the vocabulary map to the ``<UNK>`` index.

        Returns:
            A list of integer indices, one per token.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]


class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV and an image directory.

    The CSV at ``captions_file`` must provide an ``image`` column (file name
    relative to ``root_dir``) and a ``caption`` column. A ``Vocabulary`` is
    built from all captions when the dataset is constructed.

    Args:
        root_dir: Directory that contains the image files.
        captions_file: Path of the CSV file with the captions.
        transform: Optional callable applied to each loaded RGB image.
        freq_threshold: Frequency threshold passed on to ``Vocabulary``.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.transform = transform

        # One row per (image, caption) pair.
        self.df = pd.read_csv(captions_file)
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # The vocabulary is derived from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_tensor)`` for row ``index``.

        The caption tensor holds the token indices of the caption wrapped in
        the ``<SOS>`` and ``<EOS>`` indices.
        """
        caption_text = self.captions[index]
        image_name = self.imgs[index]

        image_path = os.path.join(self.root_dir, image_name)
        img = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        stoi = self.vocab.stoi
        token_ids = [
            stoi["<SOS>"],
            *self.vocab.numericalize(caption_text),
            stoi["<EOS>"],
        ]
        return img, torch.tensor(token_ids)

In [None]:
# Training dataset: resize, take a random 299x299 crop, convert to a tensor
# and normalize each channel with mean 0.5 / std 0.5.
dataset = FlickrDataset(
    root_dir="flickr8k/images",
    captions_file="flickr8k/captions.txt",
    transform=transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    ),
)

In [None]:
class MyCollate:
    """Collate function that batches images and pads captions to equal length.

    Args:
        pad_idx: Value used to fill captions shorter than the longest one in
            the batch.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batch tensors.

        Returns:
            ``(imgs, targets)`` where ``imgs`` stacks the images along a new
            leading dimension and ``targets`` holds the padded captions with
            the sequence dimension first (``batch_first=False``).
        """
        # Stacking adds the batch dimension in front of each image tensor.
        imgs = torch.stack([sample[0] for sample in batch], dim=0)

        captions = [sample[1] for sample in batch]
        targets = pad_sequence(
            captions,
            batch_first=False,
            padding_value=self.pad_idx,
        )
        return imgs, targets

In [None]:
# Captions are padded with the index of the <PAD> token.
pad_idx = dataset.vocab.stoi["<PAD>"]

# Shuffled mini-batches of 32 samples, collated by MyCollate.
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=MyCollate(pad_idx),
    pin_memory=True,
)

In [None]:
# Sanity-check a single batch. Looping over the whole loader would load every
# image in the dataset and print two lines per batch.
imgs, captions = next(iter(loader))
print(imgs.shape)      # images, batch dimension first
print(captions.shape)  # padded captions, sequence dimension first

# Modelling

In [None]:
class CNNtoRNN(nn.Module):
    """Image captioning model: Inception v3 encoder followed by an LSTM decoder.

    The image is encoded into a single feature vector that is fed to the LSTM
    as the first time step, followed by the embedded caption tokens.

    Args:
        vocab_size: Number of entries in the vocabulary (size of the output layer).
        embed_size: Size of the image feature vector and of the word embeddings.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_size=256, hidden_size=256, num_layers=2):
        super().__init__()

        # Image encoder: pretrained Inception v3 whose final fc layer is
        # replaced so that it emits an embed_size feature vector.
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        # Language model: embedding -> LSTM -> projection onto the vocabulary.
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def _encode(self, images):
        """Map a batch of images to feature vectors (inception -> ReLU -> dropout)."""
        return self.dropout(self.relu(self.inception(images)))

    def forward(self, images, captions):
        """Compute vocabulary scores for every time step.

        Args:
            images: Batch of images accepted by the Inception encoder.
            captions: Token indices with the sequence dimension first, as
                expected by the LSTM (``batch_first`` is left at ``False``).

        Returns:
            Scores with one more time step than ``captions``, because the
            image feature is prepended as step 0; the last dimension has
            ``vocab_size`` entries.
        """
        features = self._encode(images)

        embeddings = self.dropout(self.embed(captions))
        # Prepend the image feature as the first element of the sequence.
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        return self.linear(hiddens)

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        Args:
            image: Batch containing exactly one image (``.item()`` is called
                on the prediction of every step).
            vocabulary: Object with an ``itos`` index -> token mapping.
            max_length: Maximum number of tokens to generate.

        Returns:
            The generated tokens as a list of strings; generation stops after
            ``<EOS>`` has been emitted or ``max_length`` tokens were produced.
        """
        result_caption = []

        with torch.no_grad():
            # Same encoding path as forward(); the feature vector becomes the
            # first LSTM input (sequence length 1).
            x = self._encode(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.lstm(x, states)
                output = self.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                token_idx = predicted.item()
                result_caption.append(token_idx)
                # Feed the predicted token back in as the next input.
                x = self.embed(predicted).unsqueeze(0)

                if vocabulary.itos[token_idx] == "<EOS>":
                    break

        return [vocabulary.itos[idx] for idx in result_caption]

# Training

In [None]:
# The output layer needs one score per vocabulary entry.
model = CNNtoRNN(vocab_size=len(dataset.vocab))
model = model.to(device)

# Adam over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=3e-4)

In [None]:
# Freeze the pretrained Inception weights and train only its replaced final
# fc layer. The encoder is stored directly on the model as `model.inception`.
for name, param in model.inception.named_parameters():
    param.requires_grad = "fc.weight" in name or "fc.bias" in name

In [None]:
# Switch the model to training mode before the training loop runs.
model.train()
# Counter of processed batches; incremented once per batch in the training loop.
step = 0

In [None]:
# Build the loss once instead of on every step; padded target positions are
# excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])

for epoch in range(100):

    for idx, (imgs, captions) in tqdm(enumerate(loader), total=len(loader), leave=False):

        imgs = imgs.to(device)
        captions = captions.to(device)

        # The last token is not fed to the model: with the image feature
        # prepended as step 0, the outputs then line up with the full caption.
        outputs = model(imgs, captions[:-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        step += 1

        optimizer.zero_grad()
        # No argument: passing the loss as `gradient` would scale every
        # gradient by the loss value.
        loss.backward()
        optimizer.step()

# Inference

In [None]:
# Switch the model to evaluation mode for inference.
model.eval()

# Deterministic preprocessing for inference: resize straight to 299x299
# (no random crop), convert to a tensor and normalize as during training.
transform = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Load the example image and add a leading batch dimension of size 1.
test_img2 = Image.open("test_examples/child.jpg").convert("RGB")
test_img2 = transform(test_img2).unsqueeze(0)

print("Example 2 CORRECT: Child holding red frisbee outdoors")
print(
    "Example 2 OUTPUT: "
    + " ".join(model.caption_image(test_img2.to(device), dataset.vocab))
)