# Small Video Captioning transformer

Video captioning (describing the content of a video) using a compute-efficient architecture.
- a small CNN, MobileNetV2, or Vision Transformer as the selectable vision backbone
- an encoder-decoder that combines text and vision information and generates the video caption autoregressively

Author: fvilmos
https://github.com/fvilmos

See references:
- Attention Is All You Need - https://arxiv.org/abs/1706.03762

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import glob
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms
import argparse
import os
from PIL import Image
from thop import profile

from utils.vocabulary import Vocabulary
from utils.msvd_dataset import MsvdDataset
from utils.video_captioning_transformer import VideoCaptioner, generate_caption
from utils.video_encoder import CNNEncoder, ViTEncoder, MobileNetV2Encoder


In [None]:
# ---- Training / model configuration ----
# Number of samples per mini-batch handed to the DataLoader.
BATCH_SIZE=128
# Number of colour channels per frame.
IN_CHANNELS = 3
# Frames are resized to IMG_SIZE x IMG_SIZE before being fed to the model.
IMG_SIZE = 224
# Number of training epochs.
EPOCS = 20
# AdamW learning rate and weight decay.
LR = 3e-4
WEIGHT_DECAY=1e-4
# Training progress is printed every LOG_STEP batches.
LOG_STEP = 20
# Passed to VideoCaptioner as max_len (presumably the maximum caption length in tokens -- confirm).
MAX_LEN= 100
# Passed to generate_caption as max_len when captions are generated at test time.
MAX_GEN = 50
# Number of frames per video, passed to both the dataset and the model.
NR_OF_FRAMES = 8
# Vision backbone class handed to VideoCaptioner (CNNEncoder and ViTEncoder are imported as alternatives).
VISION_Encoder = MobileNetV2Encoder
# DataLoader num_workers.
WORKERS=0

# ---- Paths ----
# Checkpoint file written during training.
save_path = "./vct_model.pth"
# Directory containing the video clips and the caption annotation file.
root_dir = '../../datasets/msvd/YouTubeClips'
annotation_file = '../../datasets/msvd/annotations.txt'
# Vocabulary JSON file.
voc_path = './voc.json'
# Checkpoint file loaded in the test section (same file as save_path).
model_path = "./vct_model.pth"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print ("Run on:", device)

### Build Vocabulary

In [None]:
def get_ann_dict(fname):
    """Parse a caption annotation file into an index-keyed dictionary.

    Every non-blank line is split on single spaces: the first field is taken
    as the video id and the remaining fields as the caption words.

    Each caption line gets its own entry, so a video id that appears on
    several lines contributes several entries. Keys are consecutive integers
    starting at 0.

    Args:
        fname: Path of the annotation text file.

    Returns:
        dict mapping entry index -> {'id': <video id>, 'caption': <list of words>}.
    """
    ann_dict = {}
    with open(fname, 'r') as file:
        for line in file:
            fields = line.strip().split(' ')
            video_id = fields[0]

            # A blank line yields an empty id; skip it instead of storing a
            # bogus entry, and keep the keys consecutive.
            if not video_id:
                continue

            # NOTE: an earlier `id not in ann_dict` guard compared the string
            # id against the integer keys and therefore never filtered
            # anything. It is dropped so that "keep every caption line" is
            # explicit rather than accidental.
            ann_dict[len(ann_dict)] = {'id': video_id, 'caption': fields[1:]}
    return ann_dict
    
# Parse the annotation file once; the result is reused for the vocabulary and the dataset.
ann_dict = get_ann_dict(annotation_file)

# Sanity check: show the first parsed entry and the total number of entries.
print (ann_dict[0], len(ann_dict))

In [None]:
def build_vocab(ann_dict, threshold=3):
    """Create a Vocabulary from the captions stored in ``ann_dict``.

    Each caption (a list of words) is joined with spaces and tokenized with
    ``Vocabulary.custom_word_tokenize``. A token is added to the vocabulary
    only if it occurs at least ``threshold`` times over all captions; tokens
    are added in order of first appearance.

    Args:
        ann_dict: Mapping whose values are dicts with a 'caption' word list.
        threshold: Minimum number of occurrences required to keep a token.

    Returns:
        The populated Vocabulary instance.
    """
    vocab = Vocabulary()

    # token -> number of occurrences across every caption
    frequencies = {}
    for entry in ann_dict.values():
        sentence = " ".join(entry['caption'])
        for token in vocab.custom_word_tokenize(sentence):
            if token in frequencies:
                frequencies[token] += 1
            else:
                frequencies[token] = 1

    # Register only the sufficiently frequent tokens.
    for token, count in frequencies.items():
        if count >= threshold:
            vocab.add_word(token)

    return vocab

# Build the vocabulary; tokens seen fewer than 4 times across the captions are left out.
voc = build_vocab(ann_dict=ann_dict, threshold=4)
print ("voc len ==>",len(voc))

# Export the vocabulary to the configured location (voc_path) instead of
# repeating the path literal, so the export follows the configuration.
voc.export_vocabulary(voc_path)

print ('vocabulary exported!')

### Dataloader

In [None]:
def collate_fn(data, pad_idx=0):
    """Build a mini-batch from a list of ``(video, caption)`` samples.

    Samples are ordered by caption length (longest first), the video tensors
    are stacked along a new leading batch dimension, and the captions are
    right-padded with ``pad_idx`` to the length of the longest caption.

    Args:
        data: List of ``(video, caption)`` tuples; ``caption`` is a 1D tensor
            of token indices. All videos must share the same shape so they
            can be stacked. The list itself is not modified.
        pad_idx: Token index used to fill the padded positions. Defaults to 0,
            which is the value the captions were padded with before this
            parameter existed.

    Returns:
        Tuple ``(videos, targets, lengths)``:
            videos: stacked video tensor with a leading batch dimension.
            targets: ``(batch, max_caption_len)`` long tensor of padded captions.
            lengths: list with the unpadded length of each caption.
    """
    # sorted() returns a new list, so the caller's batch keeps its order.
    batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    videos, captions = zip(*batch)

    # Merge videos: a tuple of N-D tensors becomes one (N+1)-D tensor.
    videos = torch.stack(videos, 0)

    # Merge captions: a tuple of 1D tensors becomes one padded 2D tensor.
    lengths = [len(cap) for cap in captions]
    targets = torch.full((len(captions), max(lengths)), pad_idx, dtype=torch.long)
    for row, (cap, length) in enumerate(zip(captions, lengths)):
        targets[row, :length] = cap[:length]
    return videos, targets, lengths

In [None]:
# Per-frame preprocessing: resize to IMG_SIZE x IMG_SIZE, convert to a tensor,
# then normalize each channel (these are the ImageNet mean/std values commonly
# used with torchvision models).
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225),
        ),
    ]
)

# Dataset of (video, caption) samples built from the parsed annotations.
dataset = MsvdDataset(
    root=root_dir,
    annotation_dict=ann_dict,
    vocab=voc,
    transform=transform,
    num_frames=NR_OF_FRAMES,
)

# Shuffled mini-batches; collate_fn pads the captions of each batch.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
    collate_fn=collate_fn,
)

# Fetch a single batch so its contents can be inspected.
videos, targets, lenghts = next(iter(data_loader))

In [None]:
# Show the shapes of the sampled video / caption tensors and how many caption lengths came with them.
print (videos.shape, targets.shape, len(lenghts))

### Create model

In [None]:
# Captioning model; the same hyper-parameters are used again when the
# checkpoint is reloaded for testing.
model = VideoCaptioner(
    vocab_size=len(voc),
    dim=256,
    num_heads=4,
    num_layers=2,
    vis_out_dimension=1280,
    vis_hxw_out=49,
    num_frames=NR_OF_FRAMES,
    max_len=MAX_LEN,
    VisionEncoder=VISION_Encoder,
)
model = model.to(device)

# Only parameters that require gradients are handed to the optimizer.
optimizer = torch.optim.AdamW(
    (param for param in model.parameters() if param.requires_grad),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# NOTE(review): T_max is fixed at 100 while training runs for EPOCS epochs
# with one scheduler step per epoch, so only part of the cosine curve is
# traversed -- confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

### Model Complexity

In [None]:
import copy

voc_size = len(voc)

# B, NUM_FRAMES, C, H, W -- built from the configuration constants so the
# dummy input keeps matching the model when NR_OF_FRAMES / IN_CHANNELS change.
dummy_images = torch.randn(1, NR_OF_FRAMES, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)

# B, Nmax-1
dummy_captions = torch.randint(0, voc_size, (1, MAX_LEN - 1)).long().to(device)

# Profile a throwaway copy of the model: thop instruments the modules it
# profiles (some versions leave total_ops / total_params buffers behind),
# and that bookkeeping should not end up in the state_dict of the model
# that is trained and saved afterwards.
macs, params = profile(copy.deepcopy(model), inputs=(dummy_images, dummy_captions))

print(f"Total MACs: {macs / 1e9:.2f} G, FLOPS: {2*macs / 1e9:.2f} G")
print(f"Total Parameters: {params / 1e6:.2f} M")

### Train

In [None]:
# Lowest training loss observed so far; starts high so the first comparison succeeds.
best_loss = 1e9


def save_model(in_model, path):
    """Write the parameters (state_dict) of ``in_model`` to ``path``."""
    weights = in_model.state_dict()
    torch.save(weights, path)

# Training loop
#
# Checkpointing is driven by the mean training loss of a whole epoch and
# happens after the epoch's updates. Selecting on a single mini-batch loss
# is noisy, and saving before optimizer.step() stores weights that are not
# the ones just updated.
pad_idx = voc('<pad>')
vocab_size = len(voc)

for epoch in range(EPOCS):
    # Train mode for the captioner, while the wrapped vision backbone stays
    # in eval mode. Nothing inside the batch loop changes the mode, so it is
    # set once per epoch.
    model.train()
    model.vision_encoder.vision_encoder.eval()

    epoch_loss = 0.0
    num_batches = 0

    for i, (videos, captions, lengths) in enumerate(data_loader):
        # Move data to device
        videos = videos.to(device)
        captions = captions.to(device)

        # Teacher forcing: feed tokens [0, n-1) and predict tokens [1, n).
        inputs = captions[:, :-1]
        targets = captions[:, 1:]

        # Forward pass
        outputs = model(videos, inputs)

        # Calculate loss; padded positions do not contribute.
        loss = torch.nn.functional.cross_entropy(outputs.reshape(-1, vocab_size),
                                                 targets.reshape(-1),
                                                 ignore_index=pad_idx)

        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

        # log progress
        if i % LOG_STEP == 0:
            print(f'Epoch [{epoch + 1}/{EPOCS}], Step [{i}/{len(data_loader)}],',
                  f'Loss: {loss.item():.4f}, Perplexity: {torch.exp(loss).item():.4f}')

    scheduler.step()

    # Keep the weights of the epoch with the lowest mean training loss.
    if num_batches > 0:
        mean_loss = epoch_loss / num_batches
        if mean_loss < best_loss:
            save_model(model, save_path)
            best_loss = mean_loss

## Test Model


In [None]:
# Rebuild the architecture with the same hyper-parameters as the training
# model so the checkpoint tensors fit.
test_model = VideoCaptioner(vocab_size=len(voc),
                            dim=256,
                            num_heads=4,
                            num_layers=2,
                            vis_out_dimension=1280,
                            vis_hxw_out=49,
                            num_frames=NR_OF_FRAMES,
                            max_len=MAX_LEN,
                            VisionEncoder=VISION_Encoder).to(device)

# load the trained model
# strict=False tolerates keys that exist on only one side; report any such
# mismatch so a partially loaded model does not go unnoticed.
load_result = test_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Warning: checkpoint / model key mismatch -",
          "missing:", load_result.missing_keys,
          "unexpected:", load_result.unexpected_keys)

# Switch to inference behaviour (e.g. dropout disabled) before generating captions.
test_model.eval()

# get a list of test videos
test_video_paths = glob.glob(os.path.join(root_dir, '*.avi'))

# The dataset object is only needed for its frame loader, so build it once
# instead of once per video.
dataset = MsvdDataset(root=root_dir, annotation_dict=ann_dict, vocab=voc, transform=transform, num_frames=NR_OF_FRAMES)

# generate captions for a few sample videos
for video_path in test_video_paths[:5]:
    print("video path:", video_path)
    video = dataset._load_frames(video_path)

    # Assumes _load_frames returns a sequence of frames (it is indexed and
    # iterated below); skip clips that produced no frames instead of crashing.
    if len(video) == 0:
        print("No frames loaded, skipping:", video_path)
        continue

    tvideo = torch.stack([transform(frame) for frame in video]).to(device)

    # Generate caption; no gradients are needed at inference time.
    with torch.no_grad():
        cap = generate_caption(test_model,tvideo, voc, max_len=MAX_GEN)
    print("Caption:", " ".join(cap[0]))

    # Show frame 5 as a preview, clamped so clips with fewer frames still work.
    preview_idx = min(5, len(video) - 1)
    plt.imshow(video[preview_idx])
    plt.axis('off')
    # Title omits the first and last token of the generated sequence.
    plt.title(" ".join(cap[0][1:-1]))
    plt.show()