📖 **Resources**:
 * COCO - https://cocodataset.org/
 * EfficientNet B6 - https://pytorch.org/vision/stable/generated/torchvision.models.efficientnet_b6.html#torchvision.models.efficientnet_b6
 * T5 Model HF - https://huggingface.co/docs/transformers/model_doc/t5#t5#
 * T5 Small HF - https://huggingface.co/t5-small
 * Aladdin Persson's image captioning repo https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/more_advanced/image_captioning
 * Guide to image captioning - https://towardsdatascience.com/a-guide-to-image-captioning-e9fd5517f350
 * Andrej Karpathy's Deep Visual-Semantic Alignments for Generating Image Descriptions - https://cs.stanford.edu/people/karpathy/deepimagesent/
 * Andrej Karpathy's minGPT implementation https://github.com/karpathy/minGPT
 * Aladdin Persson's transformer from scratch - https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/more_advanced/transformer_from_scratch
 * The role of the mask in transformer's decoder - https://ai.stackexchange.com/questions/23889/what-is-the-purpose-of-decoder-mask-triangular-mask-in-transformer
 * Semi-Autoregressive Transformer for Image Captioning (Relaxed mask idea) - https://arxiv.org/pdf/2106.09436.pdf
 * Transformer guide (training + evaluating) https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec

🔑 **Note**: Images from COCO are downloaded on the fly, because the model is trained on Google Colab and Google Drive only allows us to upload ~15 GB of data, while the COCO dataset is larger than that.

## 1. Import dependencies

In [1]:
import os
import json
import time
import string
import itertools
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import skimage.io as io
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from skimage.transform import resize
from torchvision import models, transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass

## 2. Data expectation

In [2]:
# Read the COCO 2017 training-caption annotations for a first look at the data
with open(os.path.join('data', 'annotations', 'captions_train2017.json'), 'r') as f:
    # Parse straight from the file handle
    data = json.load(f)

In [3]:
data['licenses']

[{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/',
  'id': 1,
  'name': 'Attribution-NonCommercial-ShareAlike License'},
 {'url': 'http://creativecommons.org/licenses/by-nc/2.0/',
  'id': 2,
  'name': 'Attribution-NonCommercial License'},
 {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/',
  'id': 3,
  'name': 'Attribution-NonCommercial-NoDerivs License'},
 {'url': 'http://creativecommons.org/licenses/by/2.0/',
  'id': 4,
  'name': 'Attribution License'},
 {'url': 'http://creativecommons.org/licenses/by-sa/2.0/',
  'id': 5,
  'name': 'Attribution-ShareAlike License'},
 {'url': 'http://creativecommons.org/licenses/by-nd/2.0/',
  'id': 6,
  'name': 'Attribution-NoDerivs License'},
 {'url': 'http://flickr.com/commons/usage/',
  'id': 7,
  'name': 'No known copyright restrictions'},
 {'url': 'http://www.usa.gov/copyright.shtml',
  'id': 8,
  'name': 'United States Government Work'}]

In [4]:
data['info']

{'description': 'COCO 2017 Dataset',
 'url': 'http://cocodataset.org',
 'version': '1.0',
 'year': 2017,
 'contributor': 'COCO Consortium',
 'date_created': '2017/09/01'}

In [5]:
pd.DataFrame(data['images'])

Unnamed: 0,license,file_name,coco_url,height,width,date_captured,flickr_url,id
0,3,000000391895.jpg,http://images.cocodataset.org/train2017/000000...,360,640,2013-11-14 11:18:45,http://farm9.staticflickr.com/8186/8119368305_...,391895
1,4,000000522418.jpg,http://images.cocodataset.org/train2017/000000...,480,640,2013-11-14 11:38:44,http://farm1.staticflickr.com/1/127244861_ab0c...,522418
2,3,000000184613.jpg,http://images.cocodataset.org/train2017/000000...,336,500,2013-11-14 12:36:29,http://farm3.staticflickr.com/2169/2118578392_...,184613
3,3,000000318219.jpg,http://images.cocodataset.org/train2017/000000...,640,556,2013-11-14 13:02:53,http://farm5.staticflickr.com/4125/5094763076_...,318219
4,3,000000554625.jpg,http://images.cocodataset.org/train2017/000000...,640,426,2013-11-14 16:03:19,http://farm5.staticflickr.com/4086/5094162993_...,554625
...,...,...,...,...,...,...,...,...
118282,1,000000444010.jpg,http://images.cocodataset.org/train2017/000000...,480,640,2013-11-25 14:46:11,http://farm4.staticflickr.com/3697/9303670993_...,444010
118283,3,000000565004.jpg,http://images.cocodataset.org/train2017/000000...,427,640,2013-11-25 19:59:30,http://farm2.staticflickr.com/1278/4677568591_...,565004
118284,3,000000516168.jpg,http://images.cocodataset.org/train2017/000000...,480,640,2013-11-25 21:03:34,http://farm3.staticflickr.com/2379/2293730995_...,516168
118285,4,000000547503.jpg,http://images.cocodataset.org/train2017/000000...,375,500,2013-11-25 21:20:21,http://farm1.staticflickr.com/178/423174638_1c...,547503


In [6]:
pd.DataFrame(data['annotations'])

Unnamed: 0,image_id,id,caption
0,203564,37,A bicycle replica with a clock as the front wh...
1,322141,49,A room with blue walls and a white sink and door.
2,16977,89,A car that seems to be parked illegally behind...
3,106140,98,A large passenger airplane flying through the ...
4,106140,101,There is a GOL plane taking off in a partly cl...
...,...,...,...
591748,133071,829655,a slice of bread is covered with a sour cream ...
591749,410182,829658,A long plate hold some fries with some sliders...
591750,180285,829665,Two women sit and pose with stuffed animals.
591751,133071,829693,White Plate with a lot of guacamole and an ext...


🔑 **Note**: As we can see, we don't need every column of the data, so we can drop some. In fact, we only need the annotations: the image URL can be generated from the image id.
🔑 **Note**: Each dataset item will contain one image together with its five captions, because downloading every image five times (once per caption) would take too much time during training.

In [7]:
# Drop the top-level sections that are not needed for captioning
del data['info']
del data['licenses']
del data['images']

# Keep only the annotations as a table; the annotation 'id' column is dropped,
# leaving 'image_id' and 'caption'
data = pd.DataFrame(data['annotations']).drop('id', axis=1)

In [8]:
data

Unnamed: 0,image_id,caption
0,203564,A bicycle replica with a clock as the front wh...
1,322141,A room with blue walls and a white sink and door.
2,16977,A car that seems to be parked illegally behind...
3,106140,A large passenger airplane flying through the ...
4,106140,There is a GOL plane taking off in a partly cl...
...,...,...
591748,133071,a slice of bread is covered with a sour cream ...
591749,410182,A long plate hold some fries with some sliders...
591750,180285,Two women sit and pose with stuffed animals.
591751,133071,White Plate with a lot of guacamole and an ext...


**Comparison of the number of samples with limited and unlimited caption length, and their ratio**

In [9]:
# Number of space-separated tokens in every caption (one value per annotation row)
lens = data.apply(lambda x: len(x['caption'].split(' ')), axis=1)

In [10]:
np.percentile(lens, 98.5)

18.0

In [11]:
len(lens[lens < 18]), len(lens), len(lens[lens < 18]) / len(lens)

(581751, 591753, 0.9830976775783139)

In [12]:
del data

In [13]:
@dataclass
class cfg:
    """Hyper-parameters of the captioning experiment.

    The values are read as class attributes (``cfg.max_len``) throughout the
    notebook. Every attribute carries a type annotation so that ``@dataclass``
    registers it as a field; without annotations the decorator sees no fields
    at all. Class-level access keeps working because each field has a default.
    """

    epochs: int = 5
    batch_size: int = 12  # 5*12 captions
    lr: float = 1e-5
    max_len: int = 18  # caption length used for padding/truncation
    width: int = 489  # images are resized to (height, width)
    height: int = 456
    embed_size: int = 128
    num_layers: int = 6
    num_heads: int = 8
    forward_expansion: int = 4
    dropout: float = 0.05

In [14]:
class Vocabulary:
    """Two-way mapping between words and integer ids.

    Ids 0-3 are reserved for the special tokens ``<unk>``, ``<pad>``,
    ``<sos>`` and ``<eos>``; every other word receives the next free id in
    the order it is appended.

    Indexing with a ``str`` returns the id of the word (the ``<unk>`` id for
    unknown words); indexing with an ``int`` returns the word stored under
    that id (``'<unk>'`` for ids that are not in the vocabulary).
    """

    def __init__(self):
        self.vocab = {
            '<unk>': 0,
            '<pad>': 1,
            '<sos>': 2,
            '<eos>': 3
        }
        # Reverse table: position i holds the word whose id is i. Kept in
        # sync by append_word so id -> word lookups do not scan the dict.
        self._index_to_word = list(self.vocab)

    def __getitem__(self, index):
        """Translate a word to its id or an id to its word.

        Raises:
            TypeError: if ``index`` is neither exactly ``str`` nor ``int``.
        """
        # Exact type check on purpose: subclasses such as bool are rejected,
        # as before. An explicit raise also survives ``python -O``.
        if type(index) not in (str, int):
            raise TypeError('Index type must be string or int')

        if isinstance(index, str):
            return self.vocab.get(index, self.vocab['<unk>'])

        # Negative ids must not wrap around to the end of the table.
        if 0 <= index < len(self._index_to_word):
            return self._index_to_word[index]
        return self._index_to_word[self.vocab['<unk>']]

    def __len__(self):
        return len(self.vocab)

    def append_word(self, word):
        """Register ``word`` under the next free id; known words are ignored."""
        if word not in self.vocab:
            self.vocab[word] = len(self.vocab)
            self._index_to_word.append(word)

    def build_vocab(self, data):
        """Add every word found in ``data`` to the vocabulary.

        ``data`` is an array-like with two levels of nesting above the words
        (rows -> captions -> words). Words are appended in sorted order so the
        resulting ids are deterministic.
        """
        words = set()
        for row in data:
            for caption in row:
                words.update(caption)

        for word in sorted(words):
            self.append_word(word)

In [15]:
class CocoCaptions(Dataset):
    """COCO 2017 captions dataset yielding one image with its five captions.

    The annotation file is read from ``data/annotations``; images are not
    stored locally but downloaded from the COCO server on every access.
    After construction ``self._data`` is a DataFrame indexed by ``image_id``
    with the columns ``cap-0`` .. ``cap-4``, each cell holding a list of
    lowercase words, and ``self.vocab`` is a ``Vocabulary`` built from them.
    """

    # Only images that have exactly this many captions are kept.
    CAPTIONS_PER_IMAGE = 5

    def __init__(self, path):
        """
            path: split name exactly as it appears in the COCO file names and
                  URLs, e.g. 'train' or 'val' (captions_<path>2017.json)
        """
        with open(os.path.join('data', 'annotations', f'captions_{path}2017.json'), 'r') as f:
            annotations = json.load(f)['annotations']

        self.path = path
        self._data = pd.DataFrame(annotations).drop('id', axis=1)

        self.split_data()
        self.preprocessing()

        self.vocab = Vocabulary()
        self.vocab.build_vocab(self._data.values)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        """Download and resize image ``index``; return ``(img, caps)``.

        ``img`` is the array returned by ``skimage.transform.resize`` for the
        target size ``(cfg.height, cfg.width)``; ``caps`` is a list with one
        list of word ids per caption.
        """
        row = self._data.iloc[index]

        img = io.imread(self.id_to_url(int(row.name)))
        if img.ndim == 2:
            # Grayscale image: replicate the single channel so that every item
            # has a channel axis and batches can be stacked.
            img = np.stack([img] * 3, axis=-1)
        img = resize(img, (cfg.height, cfg.width), anti_aliasing=True)

        caps = [[self.vocab[word] for word in cap] for cap in row.values]

        return img, caps

    def split_data(self):
        """Reshape the flat annotation table into one row per image.

        Images that do not have exactly ``CAPTIONS_PER_IMAGE`` captions are
        dropped. Rows come out sorted by ``image_id``; captions keep the order
        they have in the annotation file.
        """
        counts = self._data['image_id'].value_counts()
        complete_ids = counts.index[counts == self.CAPTIONS_PER_IMAGE]
        complete = self._data[self._data['image_id'].isin(complete_ids)]

        # A single grouped pass instead of re-filtering the whole table per image.
        rows = [
            [image_id, *captions.tolist()]
            for image_id, captions in complete.groupby('image_id')['caption']
        ]

        columns = ['image_id', *[f'cap-{i}' for i in range(self.CAPTIONS_PER_IMAGE)]]
        self._data = pd.DataFrame(rows, columns=columns).set_index('image_id')

    def preprocessing(self):
        """Turn every caption string into a list of lowercase words."""
        # Built once; removes every ASCII punctuation character.
        strip_punctuation = str.maketrans('', '', string.punctuation)

        for i in range(self.CAPTIONS_PER_IMAGE):
            column = f'cap-{i}'
            # lowercase -> remove punctuation -> split on whitespace.
            # split() without an argument avoids empty-string tokens caused by
            # repeated or trailing spaces.
            self._data[column] = self._data[column].apply(
                lambda text: text.lower().translate(strip_punctuation).split()
            )

    def id_to_url(self, id_):
        """Return the download URL of image ``id_`` (zero-padded to 12 digits)."""
        id_str = str(id_).zfill(12)

        return f'http://images.cocodataset.org/{self.path}2017/{id_str}.jpg'

In [16]:
# Build the dataset from the 'val' annotation file (captions_val2017.json)
data = CocoCaptions('val')
# Module-level shortcut to the dataset vocabulary, used by pad_seq and the model setup
vocab = data.vocab

In [17]:
# Below code to adjust for new dataset

In [18]:
def pad_seq(batch):
    """Collate function: frame, truncate and pad captions to a fixed length.

    Every caption becomes ``<sos> w1 .. wk <eos> <pad>*`` with a total length
    of ``cfg.max_len + 1``; captions that are too long are cut so that the
    last position is always ``<eos>``.

    Args:
        batch: iterable of ``(img, caps)`` pairs, where ``caps`` is a list of
            captions and each caption is a list of word ids.

    Returns:
        ``(imgs, captions)``: the stacked images as a float32 tensor and an
        int64 tensor of shape ``(batch, captions per image, cfg.max_len + 1)``.
    """
    # Resolve the special-token ids once, from the same module-level ``vocab``
    # for all three tokens.
    sos, eos, pad = vocab['<sos>'], vocab['<eos>'], vocab['<pad>']
    target_len = cfg.max_len + 1

    imgs = []
    captions = []

    for img, caps in batch:
        padded = []
        for cap in caps:
            # At most max_len - 1 words fit between <sos> and <eos>.
            tokens = [sos, *cap[:cfg.max_len - 1], eos]
            tokens.extend([pad] * (target_len - len(tokens)))
            padded.append(tokens)

        captions.append(torch.tensor(padded, dtype=torch.int64))
        imgs.append(torch.as_tensor(img, dtype=torch.float32))

    return torch.stack(imgs), torch.stack(captions)

In [19]:
# Batches are assembled by pad_seq, which pads every caption to a fixed length.
# NOTE(review): batch_size is hard-coded to 2 here instead of cfg.batch_size, and
# `data` was built from the 'val' split - confirm both are intended for training.
train_loader = DataLoader(data, batch_size=2, collate_fn=pad_seq, shuffle=True)

In [20]:
# for i in train_loader:
#     plt.imshow(i[0][0])
#     plt.show()
    
#     print('--CAPTIONS--')
#     for j in i[1][0]:
#         print(' '.join([vocab[word.item()] for word in j]))
#     break

In [21]:
class Classifier(nn.Module):
    """Head that turns a 2048-dim feature vector into a short sequence.

    A linear layer expands the vector to 32 channels x 20 steps; two 1-D
    convolutions then map the 32 channels to ``embed_size`` channels. With the
    kernel/stride/padding values used below, the 20 steps become 15 and then
    18, so the output has shape ``(N, 18, embed_size)``.
    """

    def __init__(self, embed_size):
        super().__init__()

        self.fc = nn.Linear(2048, 32 * 20)
        self.c1 = nn.Conv1d(32, embed_size, kernel_size=3, stride=2, padding=6)
        self.c2 = nn.Conv1d(embed_size, embed_size, kernel_size=2, stride=1, padding=2)

    def forward(self, x):
        """Map ``x`` of shape (N, 2048) to (N, steps, embed_size)."""
        expanded = F.gelu(self.fc(x))

        # (N, 640) -> (N, channels=32, steps=20)
        sequence = expanded.view(-1, 32, 20)
        sequence = self.c2(F.gelu(self.c1(sequence)))

        # Put the step axis before the channel axis.
        return sequence.permute(0, 2, 1)

In [22]:
class EffNet(nn.Module):
    """Image encoder: pretrained EfficientNet-B5 with a sequence head.

    The stock classifier of the network is replaced by ``Classifier`` so the
    encoder emits a sequence of ``embed_size``-dimensional vectors.

    Args:
        max_length: accepted for interface compatibility; not used by the
            encoder.
        embed_size: channel size of the emitted sequence. Defaults to 128,
            the value that used to be hard-coded; it has to match the
            embedding size of whatever attends over the encoder output.
    """

    def __init__(self, max_length, embed_size=128):
        super().__init__()

        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept as-is because the installed version
        # is not visible from here.
        self.eff = models.efficientnet_b5(pretrained=True, progress=True)
        self.eff.classifier = Classifier(embed_size)

        self._freeze_layers()

    def _freeze_layers(self):
        """Freeze every feature stage except the last two."""
        for param in self.eff.features[:-2].parameters():
            param.requires_grad = False

    def forward(self, img):
        return self.eff(img)

In [23]:
class SelfAttention(nn.Module):
    """Multi-head attention with per-head linear projections.

    Args:
        embed_size: size of the last axis of the inputs and of the output.
        heads: number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, 'Embed size needs to be divisible by heads'

        # The same head_dim x head_dim projection is applied to every head.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` over ``keys``/``values``.

        Args:
            values, keys, queries: tensors of shape (N, length, embed_size).
            mask: ``None`` or a tensor broadcastable to
                (N, heads, query_len, key_len); positions where it equals 0
                are excluded from the attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding axis into (heads, head_dim) and project per head.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(queries.reshape(N, query_len, self.heads, self.head_dim))

        # energy shape: N, heads, query_len, key_len
        energy = torch.einsum('nqhd,nkhd->nhqk', queries, keys)

        if mask is not None:
            # Masked positions get a huge negative score, i.e. ~0 after softmax.
            energy = energy.masked_fill(mask == 0, -1e20)

        # NOTE(review): scores are scaled by sqrt(embed_size), not by
        # sqrt(head_dim) as in the original Transformer paper. Left unchanged
        # to keep the numerical behaviour - confirm whether this is intended.
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # attention shape: N, heads, query_len, key_len
        # values shape: N, value_len, heads, head_dim
        # out shape: N, query_len, heads, head_dim
        out = torch.einsum('nhql,nlhd->nqhd', attention, values)

        # Concatenate the heads and mix them with the output projection; the
        # layer was created in __init__ but previously never applied.
        out = out.reshape(N, query_len, self.heads * self.head_dim)

        return self.fc_out(out)

In [24]:
class TransformerBlock(nn.Module):
    """Attention followed by a position-wise feed-forward network.

    Both sub-layers are wrapped as ``dropout(layer_norm(sublayer + residual))``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Expand to forward_expansion * embed_size, then project back.
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, queries, mask):
        """Return the block output; it has the same shape as ``queries``."""
        attended = self.attention(value, key, queries, mask)

        # Residual connection around the attention.
        hidden = self.norm1(attended + queries)
        hidden = self.dropout(hidden)

        # Residual connection around the feed-forward network.
        projected = self.feed_forward(hidden)
        return self.dropout(self.norm2(projected + hidden))

In [25]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, then attention over ``value``/``key``.

    ``device`` is accepted by the constructor but not used inside the block.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, trg_mask):
        """Run one decoder layer.

        ``x`` attends to itself under ``trg_mask``; the normalised result is
        then used as the query over ``value``/``key`` without any mask.
        """
        self_attended = self.attention(x, x, x, trg_mask)

        query = self.norm(self_attended + x)
        query = self.dropout(query)

        return self.transformer_block(value, key, query, None)

In [26]:
class Decoder(nn.Module):
    """Transformer decoder producing vocabulary logits for every position.

    Args:
        vocab_size: number of rows of the word embedding and size of the
            output logits.
        embed_size: embedding size used throughout the decoder.
        num_layers: number of stacked ``DecoderBlock`` layers.
        heads: attention heads per layer.
        forward_expansion: width multiplier of the feed-forward networks.
        dropout: dropout probability.
        device: stored on the instance and forwarded to the blocks.
        max_length: number of learned position embeddings, i.e. the longest
            sequence the decoder can process.
    """

    def __init__(
        self,
        vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):

        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Initialise Linear/Embedding weights with N(0, 0.02) and reset LayerNorm.

        Iterates over ``self.modules()``. The previous loop went over
        ``named_parameters()``, which yields ``(name, tensor)`` tuples, so none
        of the ``isinstance`` checks could match and nothing was initialised.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()

            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

    def forward(self, x, enc_out, mask):
        """Compute logits of shape (N, seq_length, vocab_size).

        Args:
            x: integer token ids of shape (N, seq_length).
            enc_out: encoder output, used as value and key in every layer.
            mask: mask for the self-attention inside every layer.
        """
        N, seq_length = x.shape
        # Positions are created on the device of the input so the lookup works
        # wherever the batch lives.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, mask)

        return self.fc_out(x)

In [27]:
class ImageCaptioner(nn.Module):
    """Image captioning model: ``EffNet`` encoder plus transformer ``Decoder``.

    All constructor arguments except ``max_len`` are forwarded to the decoder
    only; ``max_len`` is also handed to the encoder.
    """

    def __init__(self, vocab_size, embed_size, max_len, num_layers, num_heads, forward_expansion, dropout, device):
        super().__init__()

        self.eff_enc = EffNet(max_len)
        self.trans_dec = Decoder(
            vocab_size=vocab_size,
            embed_size=embed_size,
            num_layers=num_layers,
            heads=num_heads,
            forward_expansion=forward_expansion,
            dropout=dropout,
            device=device,
            max_length=max_len
        )

    def get_num_params(self):
        """Return the total number of parameters (trainable and frozen)."""
        return sum(par.numel() for par in self.parameters())

    def make_mask(self, trg):
        """Build the causal mask for ``trg`` of shape (N, trg_len).

        Returns a lower-triangular tensor of ones with shape
        (N, 1, trg_len, trg_len), created on the device of ``trg`` rather than
        on a module-level ``device`` global.
        """
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))

        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, img, trg):
        """Return decoder logits for images ``img`` and target tokens ``trg``.

        ``img`` is expected channels-last, (N, H, W, C). It is permuted to
        (N, C, H, W), the layout convolutional layers expect. The previous
        ``permute(0, 3, 2, 1)`` additionally swapped height and width, feeding
        transposed images to the pretrained encoder.
        NOTE(review): checkpoints trained with the old permutation were fitted
        on transposed images and may need retraining.
        """
        enc_out = self.eff_enc(img.permute(0, 3, 1, 2))

        return self.trans_dec(trg, enc_out, self.make_mask(trg))

In [28]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [29]:
# Build the captioner from the hyper-parameters in cfg; the vocabulary size
# comes from the vocabulary of the loaded dataset
model = ImageCaptioner(
            vocab_size=len(vocab),
            embed_size=cfg.embed_size,
            num_layers=cfg.num_layers,
            num_heads=cfg.num_heads,
            forward_expansion=cfg.forward_expansion,
            dropout=cfg.dropout,
            device=device,
            max_len=cfg.max_len
          )

In [30]:
# AdamW over all model parameters with the learning rate from cfg
optimizer = optim.AdamW(model.parameters(), lr=cfg.lr)
# Targets equal to the <pad> id are ignored by the loss
criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])

In [31]:
def train_epoch(model, optimizer, criterion, loader):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Every batch holds several captions per image; one optimisation step is
    taken per caption index, feeding all but the last token as input and all
    but the first token as target (teacher forcing).

    Args:
        model: callable as ``model(img, trg)`` returning logits of shape
            (N, seq_len, vocab_size).
        optimizer: optimiser over the parameters of ``model``.
        criterion: loss taking logits of shape (N, vocab_size, seq_len) and
            targets of shape (N, seq_len).
        loader: iterable of ``(img, trgs)`` batches with ``trgs`` of shape
            (N, captions per image, seq_len + 1).

    Returns:
        Mean loss over all optimisation steps, or ``nan`` when ``loader``
        yields no batches.
    """
    model.train()

    # Move the batches to wherever the model lives instead of relying on a
    # module-level ``device`` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss = []

    t0 = time.time()
    for img, trgs in loader:
        img = img.to(target_device)
        trgs = trgs.to(target_device)

        # (N, captions, seq) -> (captions, N, seq): one step per caption index
        for trg in trgs.permute(1, 0, 2):
            optimizer.zero_grad()

            scores = model(img, trg[:, :-1])
            # CrossEntropyLoss expects the class axis right after the batch axis.
            loss = criterion(scores.permute(0, 2, 1), trg[:, 1:])

            loss.backward()
            optimizer.step()

            running_loss.append(loss.item())

    # An empty loader no longer ends in a ZeroDivisionError.
    loss = sum(running_loss) / len(running_loss) if running_loss else float('nan')

    print(f'Loss: {loss}, time: {time.time() - t0}s')
    return loss

In [32]:
# Move the model to the selected device before training
model = model.to(device)

# Mean training loss of every epoch
loss = []

for epoch in range(cfg.epochs):
    loss.append(train_epoch(model, optimizer, criterion, train_loader))

plt.plot(loss)
plt.show()

# Make sure the target directory exists so the weights are not lost to a
# missing-path error after training has finished.
save_dir = os.path.join('drive', 'MyDrive', 'Colab Notebooks', 'models')
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_dir, 'captioner-1.pt'))