#### 0. Import libraries and define constants

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torch.optim as optim
import string
import re
import seaborn as sns
import torchvision
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt

Define some constants for later use

In [None]:
# Dataset locations (relative paths, resolved against the kernel's working directory).
INPUT_IMAGES_DIR = "../datasets/flickr30k/images/"
LABEL_PATH = "../datasets/flickr30k/captions.txt"
OUTPUT_PATH = "../CNN-LSTM/working"

# Special vocabulary tokens used by the RNN: unknown word, padding,
# start-of-sequence and end-of-sequence markers.
UNK, PAD, START, END = "#UNK", "#PAD", "#START", "#END"

Read the pipe-delimited captions file

In [None]:
# The captions file is pipe-delimited ("|"), not comma-separated.
df = pd.read_csv(LABEL_PATH, sep="|")

In [None]:
df.head()  # preview the first rows of the raw captions table

#### 1. Data visualization & analysis

Preprocess raw text to:
1. lower case
2. remove punctuation

In [None]:
# Matches any single ASCII punctuation character (string.punctuation), escaped for use in a class.
regex = re.compile('[%s]' % re.escape(string.punctuation))

def clean_text(row):
    """Return `row` as a lowercase string with outer whitespace and all punctuation removed."""
    lowered = str(row).strip().lower()
    without_punct = regex.sub("", lowered)
    return without_punct

In [None]:
# Header names may carry stray whitespace around the "|" separators; normalize them,
# then clean every caption with clean_text.
df.columns = df.columns.str.strip()
df["caption_text"] = df["caption_text"].map(clean_text)

Compute the length (in words) of each caption and add it as a new column of the dataframe

In [None]:
# Number of whitespace-separated words in each cleaned caption.
df["length"] = df["caption_text"].map(lambda text: len(text.split()))

In [None]:
df.head()  # check the cleaned captions and the new length column

In [None]:
# KDE of caption lengths. `palette` only takes effect together with a `hue`
# variable (otherwise seaborn ignores it with a warning), so pass a single color.
length_grid = sns.displot(data=df, x="length", kind="kde", fill=True, aspect=4,
                          color=sns.color_palette("mako")[2])
length_grid.set_axis_labels("Caption length (words)", "Density");

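The plot shows that almost all captions have at most 30 words, so 30 is a reasonable choice for `max_length`. Note that `convert_captions` reserves two of those positions for `#START` and `#END`, so each caption keeps at most 28 words.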

In [None]:
# Plain Python list of the cleaned caption strings.
captions = list(df["caption_text"])

In [None]:
captions[:10]  # spot-check the first 10 cleaned captions

Count how often each word occurs across all captions

In [None]:
# Occurrence count of every whitespace-separated word across all cleaned captions,
# keyed in order of first appearance.
word_freq = {}
for text in captions:
    for token in text.split():
        word_freq[token] = word_freq.get(token, 0) + 1

Show the 30 least frequent and the 30 most frequent words

In [None]:
# Words ordered by ascending count; display the 30 rarest.
words_by_count = sorted(word_freq.items(), key=lambda kv: kv[1])
dict(words_by_count[:30])

In [None]:
# Words ordered by descending count; display the 30 most common.
words_by_count_desc = sorted(word_freq.items(), key=lambda kv: kv[1], reverse=True)
dict(words_by_count_desc[:30])

#### 2. Data preparation

With the captions preprocessed, we build the vocabulary of our dataset and convert each caption string into a sequence of token indices

In [None]:
def build_vocab(captions, word_freq, count_threshold=5):
    """
    Build a word -> index vocabulary (and its inverse) from text captions.

    The special tokens PAD, UNK, START, END always occupy indices 0-3. Every other
    word is added, in order of first appearance, if it occurs at least
    `count_threshold` times according to `word_freq`.

    Parameters
    ----------
    captions: a list of preprocessed text captions above.
    word_freq: a dictionary of word occurrence frequency. Words missing from it
        are treated as having frequency 0.
    count_threshold: minimum number of occurrences for a word to enter the vocab.

    Returns
    -------
    vocab: a dictionary vocabulary of key-value pair which is:
        -> key: string text
        -> value: token index
    inv_vocab: an inverse dictionary vocabulary of key-value pair which is:
        -> key: token index
        -> value: string text

    E.g: vocab = {"two": 4, "young": 5, "guys": 6, ...}
         inv_vocab = {4: "two", 5: "young", 6: "guys", ...}
    """
    vocab = {PAD: 0, UNK: 1, START: 2, END: 3}

    for caption in captions:
        # split() (any whitespace, never empty strings) matches how `word_freq` and
        # `convert_captions` tokenize; split(" ") could yield tokens such as
        # "dog\truns" that are absent from word_freq and raise KeyError.
        for word in caption.split():
            if word not in vocab and word_freq.get(word, 0) >= count_threshold:
                # Next free index: the 4 special tokens plus the words added so far.
                vocab[word] = len(vocab)

    inv_vocab = {index: word for word, index in vocab.items()}
    return vocab, inv_vocab

In [None]:
# Keep words seen at least 5 times (build_vocab's default count_threshold);
# rarer words are later mapped to UNK by convert_captions.
vocab, inv_vocab = build_vocab(captions, word_freq)

In [None]:
def convert_captions(captions, vocab, max_length=30):
    """
    Convert text captions to fixed-length lists of token indices based on `vocab`.

    Each caption becomes [START, w1, ..., wk, END, PAD, ..., PAD] with exactly
    `max_length` entries. Words not in `vocab` are replaced by the index of `UNK`.
    Captions longer than `max_length - 2` words are truncated so that START and
    END always fit; shorter ones are filled with `PAD` after END.

    Parameters
    ----------
    captions: a list of preprocessed text captions above.
    vocab: a dictionary vocabulary of key-value pair which is:
        -> key: string text
        -> value: token index
    max_length: an int denotes fixed maximum length to the captions (>= 2).

    Returns
    -------
    tokens: a list with one entry per caption, each a list of `max_length` token indices.

    Raises
    ------
    ValueError: if `max_length` < 2 (no room for both START and END).
    """
    if max_length < 2:
        raise ValueError(
            "max_length must be >= 2 to hold START and END, got {}".format(max_length))

    tokens = []
    for caption in captions:
        # Reserve two slots for START and END.
        words = caption.strip().split()[:max_length - 2]
        ids = [vocab[START]]
        ids.extend(vocab[word] if word in vocab else vocab[UNK] for word in words)
        ids.append(vocab[END])
        ids.extend([vocab[PAD]] * (max_length - len(ids)))
        tokens.append(ids)
    return tokens

In [None]:
# Caption i is paired with image img_paths[i]; both come from the same dataframe rows.
img_paths = df["image_name"].tolist()
tokens = convert_captions(captions, vocab, max_length=30)

Define the PyTorch `Dataset` class

In [None]:
class ImageCaptioningDataset(torch.utils.data.Dataset):
    """Pairs images on disk with their pre-tokenized captions."""

    def __init__(self, img_paths, tokens, images_dir=None, img_size=(300, 300)):
        """
        img_paths: a list of image file names we get from the dataframe
        tokens: a list of token-index lists converted from text captions, aligned with `img_paths`
        images_dir: directory prepended to every path; defaults to INPUT_IMAGES_DIR
        img_size: output size passed to cv2.resize, i.e. (width, height)
        """
        if images_dir is None:
            images_dir = INPUT_IMAGES_DIR
        self.img_paths = [os.path.join(images_dir, p) for p in img_paths]
        self.tokens = tokens
        self.img_size = img_size
        assert len(self.img_paths) == len(self.tokens), "Make sure len(img_paths) == len(tokens)."

    def __getitem__(self, index):
        """
        Load image `index`, center-crop it to a square, resize it, and return
        (image tensor from ToTensor, caption token tensor).
        Raises FileNotFoundError if the image cannot be read.
        """
        img_path = self.img_paths[index]
        img = cv2.imread(img_path)
        # cv2.imread signals a missing/unreadable file by returning None instead of raising.
        if img is None:
            raise FileNotFoundError("Could not read image: {}".format(img_path))
        # NOTE(review): cv2 loads channels in BGR order and no mean/std normalization is
        # applied, whereas torchvision's ImageNet weights expect normalized RGB input.
        # Kept as-is to preserve current training behaviour; consider
        # cv2.cvtColor(img, cv2.COLOR_BGR2RGB) plus torchvision.transforms.Normalize.
        img = self._resize_img(img, shape=self.img_size)
        img = torchvision.transforms.ToTensor()(img)
        token = torch.as_tensor(self.tokens[index])
        return img, token

    def __len__(self):
        return len(self.img_paths)

    def _resize_img(self, img, shape=(300, 300)):
        """Center-crop `img` (an H x W x C array) to a square, then resize to `shape` (width, height)."""
        h, w = img.shape[0], img.shape[1]
        # Margins trimmed from each side of the longer dimension; an odd excess
        # goes to the top/left.
        crop_top = crop_bottom = crop_left = crop_right = 0
        if h > w:
            diff = h - w
            crop_top, crop_bottom = diff - diff // 2, diff // 2
        else:
            diff = w - h
            crop_left, crop_right = diff - diff // 2, diff // 2
        square = img[crop_top:h - crop_bottom, crop_left:w - crop_right, :]
        return cv2.resize(square, shape)

In [None]:
# Images are only read from disk when an item is requested (__getitem__).
dataset = ImageCaptioningDataset(img_paths, tokens)

#### 3. Model architecture

A picture is worth a thousand words:
![Image Captioning Model](https://raw.githubusercontent.com/yunjey/pytorch-tutorial/master/tutorials/03-advanced/image_captioning/png/model.png)

In this architecture, the encoder is a CNN that outputs a feature vector. The decoder, an RNN (LSTM), projects this feature vector and uses it as its initial hidden and cell states.

In [None]:
# Caption length in tokens; must equal the max_length used when tokenizing the
# captions (convert_captions, default 30). It also caps greedy decoding length.
MAX_LENGTH = 30
NUM_VOCAB = len(vocab)
BATCH_SIZE = 128
EPOCH = 5
# Prefer an NVIDIA GPU, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

##### 3.1 Define CNN encoder class:

Best practice is to use a model pretrained on ImageNet (VGG, ResNet, AlexNet, GoogLeNet, ...) as the encoder. Such pretrained models are called backbones.

In [None]:
class CNNEncoder(nn.Module):
    """ImageNet-pretrained ResNet-34 used as the image encoder.

    forward returns the network's final fully-connected output (1000 values per
    image), which is what RNNDecoder's Linear(1000, ...) bottleneck consumes.
    """

    def __init__(self) -> None:
        super().__init__()
        weights_enum = getattr(torchvision.models, "ResNet34_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
            # checkpoint that flag used to load.
            self.cnn = torchvision.models.resnet34(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision only understands the boolean flag.
            self.cnn = torchvision.models.resnet34(pretrained=True)

    def forward(self, img):
        """Run `img` through ResNet-34; presumably a (batch, 3, H, W) float tensor — TODO confirm."""
        return self.cnn(img)


##### 3.2 Define LSTM decoder class:

This class defines the `nn.Embedding`, `nn.LSTM`, and `nn.Linear` layers needed to train the decoder.

In [None]:
class RNNDecoder(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature (last dim 1000, e.g. CNNEncoder's output) is projected to 256
    dims by `bottleneck` and used as both the initial hidden state and the initial
    cell state of a single-layer, unidirectional LSTM.
    """

    def __init__(self, num_vocab) -> None:
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        self.num_vocab = num_vocab
        # Index 0 is PAD in build_vocab; its embedding stays zero and gets no gradient.
        self.embedding = nn.Embedding(num_embeddings=num_vocab, embedding_dim=256, padding_idx=0)
        self.num_layers = 1
        self.bidirectional = False
        self.rnn = nn.LSTM(input_size=256, hidden_size=256, num_layers=self.num_layers, batch_first=False, bidirectional=self.bidirectional)
        self.classifier = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, num_vocab)
        )
        # Kept for backward compatibility. Greedy decoding takes argmax of the raw
        # logits, which picks the same token as argmax after softmax.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input, img_embeded, prediction=False):
        """
        Decode captions conditioned on image features.

        input: when prediction=False, a (batch, seq) tensor of token indices (teacher
            forcing); when prediction=True, a single start-token index (int).
        img_embeded: image features whose last dim is 1000, (batch, 1000).
        prediction: select greedy generation instead of teacher forcing.

        Returns (batch, seq, num_vocab) logits when prediction=False. Otherwise
        returns a list of at most MAX_LENGTH generated token indices, ending with
        vocab[END] if it was produced. Greedy mode supports batch size 1 only, since
        it calls .item() on each step's prediction.
        """
        img_embeded = self.bottleneck(img_embeded)
        # LSTM initial states must be shaped (num_layers, batch, hidden_size).
        img_embeded = torch.stack([img_embeded] * self.num_layers, dim=0)
        hidden = (img_embeded, img_embeded)
        if prediction:
            return self._greedy_decode(input, hidden)
        return self._teacher_forcing(input, hidden)

    def _teacher_forcing(self, input, hidden):
        """Score the next token at every position of the ground-truth prefix `input` (batch, seq)."""
        embedded = self.embedding(input).permute(1, 0, 2)  # (seq, batch, 256) since batch_first=False
        output, _ = self.rnn(embedded, hidden)
        return self.classifier(output.permute(1, 0, 2))  # (batch, seq, num_vocab)

    def _greedy_decode(self, start_token, hidden):
        """Generate tokens one at a time, feeding back the argmax, until END or MAX_LENGTH tokens."""
        output = []
        token = start_token
        # Build step inputs on the same device as the image features.
        device = hidden[0].device
        while token != vocab[END] and len(output) < MAX_LENGTH:
            step = torch.tensor([[token]], device=device)
            step = self.embedding(step).permute(1, 0, 2)
            step, hidden = self.rnn(step, hidden)
            logits = self.classifier(step.permute(1, 0, 2))
            token = torch.argmax(logits, dim=-1).squeeze().item()
            output.append(token)
        return output

#### 4. Train model

This part combines the components defined above (dataset, encoder, decoder, ...) to train the model.

In [None]:
class ImageCaptioningModel:
    """Trains an RNN decoder on top of a fixed CNN encoder and captions images."""

    def __init__(self, encoder: "CNNEncoder", decoder: "RNNDecoder", train_dataset: "ImageCaptioningDataset"):
        self.encoder = encoder.to(DEVICE)
        # The encoder is never optimized; eval mode makes layers such as BatchNorm
        # use their stored statistics.
        self.encoder.eval()
        self.decoder = decoder.to(DEVICE)
        self.train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        # Only the decoder's parameters are updated.
        self.optimizer = optim.Adam(decoder.parameters())
        # Padding positions carry no caption content; exclude them from the loss so
        # the decoder is not trained to predict PAD.
        self.loss = nn.CrossEntropyLoss(ignore_index=vocab[PAD])

    def predict(self, img):
        """
        Greedily caption a single image.

        img: an image tensor with a leading batch dimension of 1 (the decoder's
            greedy mode handles one sequence at a time). It is moved to DEVICE.
        Returns the predicted words joined by spaces, without the END token.
        """
        with torch.no_grad():
            img_embed = self.encoder(img.to(DEVICE))
            token_ids = self.decoder(vocab[START], img_embed, prediction=True)
        words = [inv_vocab[t] for t in token_ids if t != vocab[END]]
        return " ".join(words)

    def train(self):
        """Train the decoder for EPOCH epochs with teacher forcing, previewing a caption every 100 batches."""
        for e in range(EPOCH):
            pbar = tqdm(self.train_dataloader, desc="Epoch: {}".format(e+1))
            for i, (img, caption) in enumerate(pbar):
                img = img.to(DEVICE)
                caption = caption.to(DEVICE)
                # Skip building an autograd graph through the frozen encoder: it saves
                # compute and avoids accumulating .grad on weights that are never updated.
                with torch.no_grad():
                    img_embed = self.encoder(img)
                # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
                output = self.decoder(caption[:, :-1], img_embed)
                # CrossEntropyLoss expects class scores on dim 1: (batch, num_vocab, seq).
                output = output.permute(0, 2, 1)
                loss = self.loss(output, caption[:, 1:])

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                pbar.set_description(desc="Epoch " + str(e+1) + " - Loss: %.5f" % (loss.item()))

                if (i + 1) % 100 == 0:
                    self._show_sample(img[-1])

    def _show_sample(self, img):
        """Display one (C, H, W) training image titled with the model's current caption for it."""
        # Assumes BGR channel order, as produced by cv2.imread in ImageCaptioningDataset;
        # reverse the channels so matplotlib shows true colors.
        plt.imshow(img.cpu().numpy().transpose((1, 2, 0))[..., ::-1])
        plt.title(self.predict(img.unsqueeze(0)))
        plt.show()

In [None]:
# Assemble the pretrained encoder, a fresh decoder, and the training data into one trainer.
cnn = CNNEncoder()
rnn = RNNDecoder(num_vocab=NUM_VOCAB)
model = ImageCaptioningModel(cnn, rnn, dataset)

In [None]:
def resize_img(img, shape=(300, 300)):
    """
    Center-crop `img` (an H x W x C array) to a square, then resize it to `shape` (width, height).

    Uses the same cropping as ImageCaptioningDataset._resize_img, so a standalone test
    image is preprocessed like the training images.
    """
    h, w = img.shape[0], img.shape[1]
    # Margins trimmed from each side of the longer dimension; an odd excess goes to the top/left.
    crop_top = crop_bottom = crop_left = crop_right = 0
    if h > w:
        diff = h - w
        crop_top, crop_bottom = diff - diff // 2, diff // 2
    else:
        diff = w - h
        crop_left, crop_right = diff - diff // 2, diff // 2
    square = img[crop_top:h - crop_bottom, crop_left:w - crop_right, :]
    return cv2.resize(square, shape)

img_path = "../datasets/flickr30k/test_examples/boat.png"
img = cv2.imread(img_path)
# cv2.imread returns None (instead of raising) when the file is missing or unreadable.
if img is None:
    raise FileNotFoundError("Could not read image: {}".format(img_path))
img = resize_img(img, shape=(300, 300))
img = torchvision.transforms.ToTensor()(img)

In [None]:
# Trains the decoder for EPOCH epochs; a sample image with its predicted caption is plotted every 100 batches.
model.train()

#### 5. Predict

In [None]:
# TODO 7: run predictions on new images once the model has been trained.