# Small Image captioning transformer

 Image captioning (generating a text description of an image's content) using a computationally efficient architecture:
 - a small CNN or Vision Transformer encodes the image context
 - an encoder-decoder combines text and vision information and generates the caption autoregressively

Author: fvilmos

See references:
- Attention Is All You Need - https://arxiv.org/abs/1706.03762


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import glob
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from thop import profile
from utils.mardown_display import mm

from pycocotools.coco import COCO
from PIL import Image
import os

from utils.vocabulary import Vocabulary
from utils.decoder import ImageCaptioner, generate_caption
from utils.vision_encoder import CNNEncoder, ViTEncoder, MobileNetV2Encoder


In [None]:
# ---- training hyper-parameters -------------------------------------------
BATCH_SIZE = 80
IN_CHANNELS = 3
IMG_SIZE = 224      # images are resized to IMG_SIZE x IMG_SIZE
EPOCS = 20          # number of training epochs
LR = 3e-4
WEIGHT_DECAY = 1e-4
LOG_STEP = 100      # print a progress line every LOG_STEP batches
MAX_LEN = 100       # max_len passed to the captioner
MAX_GEN = 50        # token budget passed to generate_caption

# ---- dataset / checkpoint locations --------------------------------------
_COCO_ROOT = '../../datasets/_coco'

# training
save_path = "./sic_model.pth"
root_dir = f'{_COCO_ROOT}/train2017'
annotation_file = f'{_COCO_ROOT}/captions_train2017.json'

# model evaluation
model_path = "./sic_model.pth"
root_dir_val = f'{_COCO_ROOT}/val2017'
annotation_file_val = f'{_COCO_ROOT}/captions_val2017.json'

# model test
# NOTE(review): annotation_file_tst is not read anywhere in this notebook;
# verify the file actually exists before relying on it.
root_dir_tst = f'{_COCO_ROOT}/test2017'
annotation_file_tst = f'{_COCO_ROOT}/captions_test2017.json'

workers = 0  # presumably intended as the DataLoader num_workers setting

_use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _use_cuda else 'cpu')
print("available device ==>", device, "\n")

# JPEGs from the test split, used for qualitative captioning samples.
img_list = glob.glob(f'{root_dir_tst}/*.jpg')
print("test images:", len(img_list))

In [None]:
# Architecture overview of the whole captioner (not just the MobileNetV2
# backbone), rendered by utils.mardown_display.mm — presumably a Mermaid
# renderer, confirm. The diagram shows the vision encoder output feeding
# cross attention, which also receives masked self-attention over the
# word + positional embeddings of the caption tokens, followed by decoding.
mm("""
   flowchart LR;
   subgraph imagein [Image processing];
      A(["input\nimage"])--|3xwxh|-->B["VisionEncoder"];
   end;
   subgraph transformer [transformer blocks];
      B---->C["Cross attention"];
   end;
   subgraph textin ["text processing"]
      D["Masked self-attention"]-->C
      E["Word embedding \n positional embedding"]-->D
      G(["start sequence"])-->E
   end
   C --> F["Decoding"]
   """)

# Vocabulary generator

In [None]:
def build_vocab(json_path, threshold=3):
    """Build a Vocabulary from every caption in a COCO captions annotation file.

    Each caption is tokenized with ``Vocabulary.custom_word_tokenize``. A
    token is added to the vocabulary only if it occurs at least
    ``threshold`` times across all captions. Tokens are added in the order
    of their first appearance.

    Args:
        json_path: path to a COCO-format captions JSON file.
        threshold: minimum number of occurrences for a token to be kept.

    Returns:
        The populated ``Vocabulary`` instance.
    """
    from collections import Counter

    coco = COCO(json_path)
    vocab = Vocabulary()

    # Counter keeps first-seen insertion order, so the vocabulary order is
    # identical to the hand-rolled dict counting it replaces.
    counter = Counter()
    for ann in coco.anns.values():
        counter.update(vocab.custom_word_tokenize(str(ann['caption'])))

    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    return vocab

# Build the vocabulary from the training captions, keeping tokens that occur
# at least 4 times, then write it to ./voc.json so a later session can
# reload it with Vocabulary.import_vocabulary.
voc = build_vocab(annotation_file, threshold=4)
print("voc len ==>", len(voc))

voc.export_vocabulary("./voc.json")
print('vocabulary exported!')


## DataLoader

In [None]:
class CocoCaptionDataset(data.Dataset):
    """
    Map-style dataset over COCO caption annotations.

    There is one item per caption annotation, so an image with several
    captions yields several items. Each item is ``(image, target)``:
    ``image`` is the RGB PIL image passed through ``transform`` (if one is
    given), and ``target`` is a 1-D ``torch.long`` tensor of vocabulary
    indices framed by ``<start>`` and ``<end>``.
    """
    def __init__(self, root, json, vocab, transform=None):
        # root: directory holding the image files
        # json: path to the COCO captions annotation file
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        ann = self.coco.anns[self.ids[index]]
        caption = ann['caption']
        path = self.coco.loadImgs(ann['image_id'])[0]['file_name']

        # Open in a context manager so the file handle is released right
        # away. convert() loads the pixels into a new, independent image.
        with Image.open(os.path.join(self.root, path)) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        tokens = self.vocab.custom_word_tokenize(str(caption))
        token_ids = [self.vocab('<start>')]
        token_ids.extend(self.vocab(token) for token in tokens)
        token_ids.append(self.vocab('<end>'))
        # Token indices are integers, so store them as int64 rather than the
        # float32 default of torch.Tensor.
        target = torch.tensor(token_ids, dtype=torch.long)
        return image, target

    def __len__(self):
        return len(self.ids)

def collate_fn(data):
    """Batch ``(image, caption)`` pairs for the DataLoader.

    The incoming list is sorted in place by caption length, longest first.
    Images are stacked along a new first dimension. Captions are copied
    into a zero-filled int64 matrix whose width is the longest caption.

    Returns:
        ``(images, targets, lengths)``, where ``lengths`` lists each
        caption's original length in the sorted order.
    """
    data.sort(key=lambda sample: len(sample[1]), reverse=True)
    image_seq, caption_seq = zip(*data)

    batch_images = torch.stack(image_seq, 0)

    lengths = list(map(len, caption_seq))
    targets = torch.zeros(len(caption_seq), max(lengths)).long()
    for row, (cap, n_tokens) in enumerate(zip(caption_seq, lengths)):
        targets[row, :n_tokens] = cap[:n_tokens]
    return batch_images, targets, lengths

def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Create a DataLoader over a CocoCaptionDataset that batches with collate_fn.

    ``root``/``json``/``vocab``/``transform`` are forwarded to the dataset.
    ``batch_size``/``shuffle``/``num_workers`` are forwarded to the loader.
    """
    dataset = CocoCaptionDataset(root=root, json=json, vocab=vocab, transform=transform)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )



# Resize to a square IMG_SIZE input, convert to a tensor and normalise each
# channel (these mean/std values are the commonly used ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Take the worker count from the `workers` setting in the configuration cell
# rather than a hard-coded 0, so changing that setting actually has an effect.
data_loader = get_loader(root=root_dir,
                         json=annotation_file,
                         vocab=voc,
                         transform=transform,
                         batch_size=BATCH_SIZE,
                         shuffle=True,
                         num_workers=workers)

## Model and complexity

In [None]:
# Active configuration: MobileNetV2 vision backbone with a small text model
# (dim=256, 4 heads, 2 layers).
# NOTE(review): vis_out_dimension / vis_hxw_out appear to be the backbone's
# feature channels and number of spatial positions (49 = 7x7); confirm
# against utils.vision_encoder.
model = ImageCaptioner(
    vocab_size=len(voc),
    dim=256,
    num_heads=4,
    num_layers=2,
    vis_out_dimension=1280,
    vis_hxw_out=49,
    max_len=MAX_LEN,
    VisionEncoder=MobileNetV2Encoder,
).to(device)

# Alternative: plain CNN backbone
# model = ImageCaptioner(
#     vocab_size=len(voc),
#     dim=768,
#     num_heads=8,
#     num_layers=4,
#     vis_out_dimension=512,
#     vis_hxw_out=49,
#     max_len=MAX_LEN,
#     VisionEncoder=CNNEncoder,
# ).to(device)

# Alternative: ViT backbone
# model = ImageCaptioner(
#     vocab_size=len(voc),
#     dim=768,
#     num_heads=8,
#     num_layers=4,
#     vis_out_dimension=768,
#     vis_hxw_out=50,  # 49 + CLS
#     max_len=MAX_LEN,
#     VisionEncoder=ViTEncoder,
# ).to(device)


# AdamW over the trainable parameters only; parameters with
# requires_grad=False (if any) are left out.
# The learning rate and weight decay come from the configuration cell. Older
# revisions hard-coded weight_decay=0.01 here, which silently overrode
# WEIGHT_DECAY; re-tune WEIGHT_DECAY if results regress.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# get model params and complexity
voc_size = len(voc)
dummy_images = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
dummy_captions = torch.randint(0, voc_size, (1, MAX_LEN - 1)).long().to(device)

# Profile the model using thop
macs, params = profile(model, inputs=(dummy_images, dummy_captions))

print(f"Total MACs: {macs / 1e9:.2f} G, FLOPS: {2*macs / 1e9:.2f} G")
print(f"Total Parameters: {params / 1e6:.2f} M")


## Training loop

In [None]:
best_loss = float('inf')                   # best mean epoch training loss so far
last_model_path = "./sic_model_last.pth"   # latest weights, written every epoch


def save_model(in_model, path):
    """Write ``in_model.state_dict()`` to ``path`` with ``torch.save``."""
    torch.save(in_model.state_dict(), path)


# NOTE(review): collate_fn pads captions with 0, so this assumes
# voc('<pad>') == 0; confirm in utils.vocabulary.
pad_idx = voc('<pad>')

for epoch in range(EPOCS):
    # Train the captioner, but keep the vision sub-module in eval mode
    # (presumably a frozen/pretrained backbone; confirm). Nothing inside the
    # batch loop changes the mode, so setting it once per epoch is enough.
    model.train()
    model.vision.eval()

    epoch_loss_sum = 0.0
    num_steps = 0
    for i, (images, captions, _lengths) in enumerate(data_loader):
        images = images.to(device)
        captions = captions.to(device)
        # Teacher forcing: the model sees tokens [:-1] and is trained to
        # predict the next token at every position, i.e. tokens [1:].
        targets = captions[:, 1:]

        outputs = model(images, captions[:, :-1])

        loss = torch.nn.functional.cross_entropy(
            outputs.reshape(-1, len(voc)), targets.reshape(-1), ignore_index=pad_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        epoch_loss_sum += step_loss
        num_steps += 1

        if i % LOG_STEP == 0:
            print(f'Epoch [{epoch+1}/{EPOCS}], Step [{i}/{len(data_loader)}],', f'Loss: {step_loss:.4f}, Perplexity: {torch.exp(loss).item():.4f}')

    # Checkpointing. The best model, judged by mean training loss over the
    # epoch (far less noisy than a single batch, and saved once per epoch
    # instead of on every new per-batch minimum), goes to save_path. The
    # latest weights go to a *different* file, so they never overwrite the
    # best checkpoint.
    if num_steps:
        epoch_loss = epoch_loss_sum / num_steps
        print(f'Epoch [{epoch+1}/{EPOCS}] mean loss: {epoch_loss:.4f}')
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            save_model(model, save_path)
    save_model(model, last_model_path)

    # Generate 3 sample captions as a qualitative check, in eval mode and
    # without building autograd graphs.
    if img_list:
        model.eval()
        with torch.no_grad():
            for _ in range(3):
                val = np.random.randint(len(img_list))
                with Image.open(img_list[val]) as img:
                    image_o = img.convert('RGB')
                image = transform(image_o.copy()).to(device)

                # return_att=True keeps `cap` in the same form as before; the
                # attention output is not used here.
                cap, _att = generate_caption(model, image, voc, MAX_GEN, return_att=True)

                print("caption:", cap)
                plt.imshow(image_o)
                plt.show()


# Test captioning

In [None]:
# Reuse the in-memory vocabulary if the cells above ran in this session.
# Otherwise (voc undefined) reload the exported one from ./voc.json. Only
# NameError is caught, so real errors and KeyboardInterrupt still surface.
try:
    print("voc len:", len(voc))
except NameError:
    voc = Vocabulary()
    voc.import_vocabulary('./voc.json')
    print("voc len:", len(voc))

In [None]:
# Rebuild the captioner with the training hyper-parameters and load the
# saved weights into a separate object. The trained `model` is not rebound,
# because the optimizer still references that model's parameters.
test_model = ImageCaptioner(vocab_size=len(voc),
                            dim=256,
                            num_heads=4,
                            num_layers=2,
                            vis_out_dimension=1280,
                            vis_hxw_out=49,
                            max_len=MAX_LEN,
                            VisionEncoder=MobileNetV2Encoder).to(device)

# strict=False tolerates key mismatches between the checkpoint and the
# model. Report any mismatches instead of ignoring them silently.
load_result = test_model.load_state_dict(
    torch.load(model_path, map_location=device), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("warning: missing keys:", load_result.missing_keys)
    print("warning: unexpected keys:", load_result.unexpected_keys)

# Inference mode: disables dropout and uses running norm statistics.
test_model.eval()

# Same preprocessing as used for training.
transform_nn = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

if img_list:
    with torch.no_grad():
        for _ in range(3):
            val = np.random.randint(len(img_list))
            with Image.open(img_list[val]) as img:
                image_o = img.convert('RGB')

            image = transform_nn(image_o.copy()).to(device)

            cap = generate_caption(test_model, image, voc, MAX_GEN, return_att=False)

            print("caption:", " ".join(cap[0]))
            # Create a new figure per sample. plt.show() closes the current
            # figure, so a single figure created before the loop only sized
            # the first image.
            fig = plt.figure(figsize=(8, 8))
            plt.imshow(image_o)
            plt.axis('off')
            # NOTE(review): [1:-2] presumably strips <start> and the trailing
            # end/punctuation tokens; confirm against generate_caption output.
            plt.title(" ".join(cap[0][1:-2]))
            plt.show()
else:
    print("no test images found in", root_dir_tst)
