# NeuralTalk

* Deep Visual-Semantic Alignments for Generating Image Descriptions - Andrej Karpathy, Li Fei-Fei.

The following notebook glues together the components defined across several files in this folder into a linear story that reasonably reproduces the results of the paper.

I'm not going to be particular about the details. For example, I'll use an LSTM instead of a vanilla RNN, and put `Linear` layers in between to change sizes, so I can experiment with how the number of parameters affects the results.
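A minimal sketch of such a `Linear` bridge, assuming a 512-dimensional image feature (the size of a pooled `Resnet18` feature) and a hypothetical `bridge_size` to experiment with. Counting only the parameters with `requires_grad=True` is one simple way to compare variants.

```python
import torch
from torch import nn

def count_trainable(module: nn.Module) -> int:
    """Number of parameters that will receive gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

feature_size = 512   # pooled Resnet18 feature size
bridge_size = 256    # hypothetical size to experiment with

# Linear layer that resizes the image feature before it reaches the decoder.
bridge = nn.Linear(feature_size, bridge_size)
features = torch.randn(4, feature_size)   # a fake batch of 4 image features
hidden = bridge(features)

print(hidden.shape)              # torch.Size([4, 256])
print(count_trainable(bridge))   # 512 * 256 weights + 256 biases = 131328
```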


In [1]:
#!/usr/bin/python3
from argparse import ArgumentParser
import json
from pprint import pprint
from itertools import chain
from tqdm import tqdm_notebook as tqdm
import matplotlib
#matplotlib.use('Agg')
from matplotlib import pyplot as plt
import os
import sys
import string
import torch
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import datetime
from tensorboardX import SummaryWriter
import numpy as np

## Preprocessing

The tokens are already available, preprocessed, in the `.json` file, and we reuse them, so only parsing is required.
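A sketch of what that parsing might look like, assuming the usual layout of Karpathy's split files: a top-level `"images"` list whose entries carry a `"split"` and a list of `"sentences"`, each with pre-tokenized `"tokens"`. `extract_tokens_sketch` is a hypothetical stand-in; it takes already-parsed JSON and is not the `preproc.extract_tokens` used in this notebook.

```python
import json
from itertools import chain

def extract_tokens_sketch(data, split=None):
    """Flatten the pre-tokenized captions, optionally restricted to one split."""
    sentences = chain.from_iterable(
        image["sentences"] for image in data["images"]
        if split is None or image["split"] == split
    )
    return list(chain.from_iterable(s["tokens"] for s in sentences))

# Tiny inline example in the assumed format (the real file is dataset_flickr8k.json).
sample = json.loads("""
{"images": [
  {"filename": "a.jpg", "split": "train",
   "sentences": [{"tokens": ["a", "dog", "runs"]}, {"tokens": ["a", "brown", "dog"]}]},
  {"filename": "b.jpg", "split": "val",
   "sentences": [{"tokens": ["two", "kids"]}]}
]}
""")

print(extract_tokens_sketch(sample, split="train"))
# ['a', 'dog', 'runs', 'a', 'brown', 'dog']
```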

A Flickr DataLoader is adapted from [here](https://github.com/fartashf/vsepp/blob/master/data.py). It batches variable-length targets together by padding them in a collate function, which is the general practice in PyTorch.
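A minimal sketch of such a padding collate function, assuming each dataset item is an `(image, caption_ids)` pair with a `3xHxW` image tensor and a 1-D `LongTensor` of token ids, and that padding uses id `0`. `collate_fn_sketch` and `pad_id` are hypothetical; the real `FlickrDataset.collate_fn` may return additional fields.

```python
import torch

def collate_fn_sketch(batch, pad_id=0):
    """Stack images and right-pad variable-length captions to a common length."""
    # Longest caption first (vsepp sorts this way; useful for packed sequences).
    batch = sorted(batch, key=lambda item: len(item[1]), reverse=True)
    images, captions = zip(*batch)
    images = torch.stack(images, dim=0)                     # (B, 3, H, W)
    lengths = [len(c) for c in captions]
    targets = torch.full((len(captions), max(lengths)), pad_id, dtype=torch.long)
    for i, cap in enumerate(captions):
        targets[i, :len(cap)] = cap                         # remaining slots stay pad_id
    return images, targets, lengths

demo_batch = [
    (torch.zeros(3, 4, 4), torch.tensor([1, 7, 2])),
    (torch.zeros(3, 4, 4), torch.tensor([1, 5, 6, 2])),
]
imgs, caps, lens = collate_fn_sketch(demo_batch)
print(imgs.shape, caps.tolist(), lens)
# torch.Size([2, 3, 4, 4]) [[1, 5, 6, 2], [1, 7, 2, 0]] [4, 3]
```

Such a function is passed as `collate_fn=` to `torch.utils.data.DataLoader`, exactly as `batch_params` does with `FlickrDataset.collate_fn`.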

The Dataset class additionally requires a `Vocab` object: a callable that returns a unique index for each unique token. We'll code our own, keeping two things in mind:
1. `vocab(token)` returns an id and never raises; unknown tokens map to the id of a dedicated unknown token.
2. `len(vocab)` returns the total number of words in the vocabulary.
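A sketch of a vocabulary meeting both requirements. `VocabSketch` is hypothetical, not the `preproc.Vocab` imported below; the special tokens `<pad>` and `<unk>` and the `min_count` cutoff are assumptions, while the `word2idx` attribute matches what `Interpreter` reads later.

```python
from collections import Counter

class VocabSketch:
    """Hypothetical token -> index mapping; specials get the lowest ids."""

    def __init__(self, tokens, specials=("<pad>", "<start>", "<end>", "<unk>"), min_count=1):
        self.word2idx = {}
        for sp in specials:
            self.add(sp)
        for word, count in sorted(Counter(tokens).items()):
            if count >= min_count:
                self.add(word)
        self.unk = self.word2idx["<unk>"]

    def add(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = len(self.word2idx)

    def __call__(self, token):
        # Requirement 1: never raise; unknown tokens map to <unk>.
        return self.word2idx.get(token, self.unk)

    def __len__(self):
        # Requirement 2: total number of entries, specials included.
        return len(self.word2idx)

v = VocabSketch(["a", "dog", "runs", "a"])
print(v("dog"), v("cat"), len(v))   # 5 3 7
```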

In [2]:
uspath = '/home/jerin/code/ultimate-sandbox'
flickr_root = '/tmp/Flickr-8K/Flicker8k_Dataset'
sys.path.insert(0, uspath)
from preproc import Vocab, img_preprocess, extract_tokens
from usandbox.data import FlickrDataset

In [3]:
# Construct vocabulary
json_file = 'dataset_flickr8k.json'
tokens = extract_tokens(json_file)
vocab = Vocab(tokens)
vocab('you')

1919

In [4]:
def get_dataset(name):
    return FlickrDataset(flickr_root, json_file, name, 
                        vocab, transform=img_preprocess)

dataset = {}
for phase in ['train', 'test', 'val']:
    dataset[phase] = get_dataset(phase)


In [5]:
!rsync -rz --info=progress2 ada:/share1/dataset/Flickr-8K/ /tmp/Flickr-8K/ --append

              0   0%    0.00kB/s    0:00:00 (xfr#0, to-chk=0/16197)   


In [6]:
batch_params = {
    "batch_size": 512,
    "shuffle"  : True,
    "num_workers" : 20,
    "collate_fn": FlickrDataset.collate_fn
}

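def n_batches(dataset):
    # Ceiling division: the number of batches the DataLoader yields (drop_last=False).
    bs = batch_params["batch_size"]
    return (len(dataset) + bs - 1) // bs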

def loader(dataset):
    return torch.utils.data.DataLoader(dataset=dataset, **batch_params)

## Models

Models used in composition are defined in `models.py`. Some are small enough to be declared later in this notebook from PyTorch's predefined building blocks. The model is composed of:


### 1. CNN Feature Extractor
A modified `Resnet18`, which I call `ResnetMinus`, supplies the hidden representation for the generative RNN that predicts the captions. It lacks the final fully connected (classification) layer, so it outputs dense features representing the image rather than class scores.

Using the pretrained `Resnet18`, dropping layers from it, and freezing its parameters so that no gradients are computed for them while training the captioner were all first-time, hands-on lessons for me. It turns out PyTorch only requires you to iterate through `model.parameters()` and set `requires_grad = False` on each of them to freeze them.
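A small illustration of freezing, using tiny stand-in modules instead of `Resnet18` (the `backbone` and `head` names are hypothetical). Frozen parameters get no `.grad` after `backward()`, and, as in this notebook's optimizer setup, only parameters with `requires_grad=True` are handed to the optimizer.

```python
import torch
from torch import nn

# Stand-ins: a "pretrained" feature extractor to freeze and a new head to train.
backbone = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
head = nn.Linear(4, 3)

# Freeze: no gradients will be computed for these parameters.
for p in backbone.parameters():
    p.requires_grad = False

demo_model = nn.Sequential(backbone, head)
# Only trainable parameters go to the optimizer.
demo_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, demo_model.parameters()), lr=0.1)

loss = demo_model(torch.randn(5, 8)).sum()
loss.backward()
demo_optimizer.step()
print(backbone[0].weight.grad)         # None: frozen
print(head.weight.grad is not None)    # True: trainable
```

`backbone.requires_grad_(False)` is an equivalent one-liner for the loop.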

In [7]:
import models

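class Encoder(nn.Module):
    # input_size and hidden_size are currently unused: the feature size
    # is fixed by ResnetMinus (512 for Resnet18).
    def __init__(self, input_size, hidden_size, embedding):
        super().__init__()
        self.embed = embedding
        self.resnet = models.ResnetMinus()
    
    def forward(self, x):
        r = self.resnet(self.embed(x))   # (B, 512, 1, 1) pooled features
        B, H, _, _ = r.size()
        r = r.reshape(B, H)              # flatten to (B, 512)
        return r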

### 2. Generative RNN
We train an RNN (an LSTM variant here) that takes the image feature given by `ResnetMinus` as its initial hidden representation. The input at the first time step is the index of the `<start>` token. An additional `nn.Embedding` layer embeds the words, given by their indices, in a dense space. Let the caption to be learned be `z[1:t]`.

The sequence we're trying to learn is `z[1:(t-1)] -> z[2:t]`: at every step the model predicts the next token of the caption.
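A toy illustration of that shift with the image feature as the LSTM's initial hidden state, assuming batch-first tensors and made-up sizes. It mirrors the teacher-forced path of the `Decoder` in this notebook, but runs the whole sequence in a single `nn.LSTM` call instead of step by step.

```python
import torch
from torch import nn
import torch.nn.functional as F

B, T, V, E, H, L = 2, 5, 10, 6, 8, 2       # batch, caption length, vocab, embed, hidden, layers
toy_caps = torch.randint(0, V, (B, T))     # token ids; column 0 would be <start> in practice
toy_feat = torch.randn(B, H)               # image feature, already H-dimensional

embed = nn.Embedding(V, E)
lstm = nn.LSTM(E, H, num_layers=L, batch_first=True)
generator = nn.Linear(H, V)

inputs, targets = toy_caps[:, :-1], toy_caps[:, 1:]   # z[1:(t-1)] -> z[2:t]
h0 = toy_feat.unsqueeze(0).repeat(L, 1, 1)            # (L, B, H): same feature for every layer
c0 = torch.zeros_like(h0)

out, _ = lstm(embed(inputs), (h0, c0))     # (B, T-1, H)
logits = generator(out)                    # (B, T-1, V)
loss = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))
print(logits.shape)                        # torch.Size([2, 4, 10])
```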

The outputs are passed through a softmax over the vocabulary, and captions can then be decoded with `Greedy` or `Beam-Search` decoding. We'll stick to `Greedy` in this particular attempt.
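Greedy decoding picks the single most likely token at each step and feeds it back as the next input, which is what `Decoder.forward` does with `teacher_forcing=False` (except that the notebook always runs a fixed number of steps). Below is a framework-free sketch; `step` is a hypothetical callable standing in for one LSTM plus `generator` step. Beam search would instead keep the k best partial captions at every step.

```python
def greedy_decode(step, start_id, end_id, state, max_len=20):
    """Repeatedly feed back the highest-scoring token until <end> or max_len."""
    token, output = start_id, []
    for _ in range(max_len):
        scores, state = step(token, state)                       # one score per vocab entry
        token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if token == end_id:
            break
        output.append(token)
    return output

# Toy "model": fixed next-token scores (ids: 0=<start>, 1=<end>, 2=a, 3=dog, 4=runs).
table = {
    0: [0.0, 0.1, 0.7, 0.1, 0.1],
    2: [0.0, 0.1, 0.0, 0.8, 0.1],
    3: [0.0, 0.2, 0.0, 0.0, 0.8],
    4: [0.0, 0.9, 0.0, 0.0, 0.1],
}

def toy_step(token, state):
    return table[token], state

print(greedy_decode(toy_step, start_id=0, end_id=1, state=None))   # [2, 3, 4] -> "a dog runs"
```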

In [8]:
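class Decoder(nn.Module):
    def __init__(self, params, embedding, interpreter, n_classes):
        super().__init__()
        self.params = params
        self.lstm = nn.LSTM(**params)
        self.embed = embedding
        self.interpreter = interpreter
        # Projects LSTM outputs to unnormalised scores over the vocabulary.
        self.generator = nn.Linear(params["hidden_size"], n_classes)
        
    def forward(self, context, seed, teacher_forcing=True):
        # The image feature (B, H) initialises every layer's hidden state: (num_layers, B, H).
        h = context.repeat(self.params["num_layers"], 1, 1)
        c = torch.zeros_like(h)
        tgt = self.embed(seed)             # (B, T, E) embedded ground-truth caption
        y_prev = seed[:, 0:1]              # <start> token ids, (B, 1)
        B, T, _ = tgt.size()
        ys = []
        # T-1 steps: predict tokens 2..T from tokens 1..T-1.
        for t in range(T - 1):
            # Teacher forcing feeds the ground truth; otherwise feed back the last prediction.
            yt = tgt[:, t:t+1, :] if teacher_forcing else self.embed(y_prev)
            y, (h, c) = self.lstm(yt, (h, c))
            yt = self.generator(y)         # (B, 1, n_classes)
            ys.append(yt)
            y_prev = self.interpreter.argmax(yt.detach())
        return torch.cat(ys, dim=1)        # (B, T-1, n_classes)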

In [9]:
class Interpreter:
    def __init__(self, vocab):      
        self.softmax = nn.Softmax(dim=2)
        self.build_vocab(vocab)
        
    def build_vocab(self, vocab):
        self.idx2word = {}
        for key, value in vocab.word2idx.items():
            self.idx2word[value] = key
            
    def inverse(self, indices):
        tokens = list(map(lambda x: self.idx2word[x], indices))
        return tokens
    
    def argmax(self, acts):
        probs = self.softmax(acts)
        max_value, max_index = probs.max(dim=2)
        return max_index
    
    def decode(self, acts):
        B, T, H = acts.size()
        batch = []
        indices = self.argmax(acts)
        for i in range(B):
            tokens = self.inverse(indices[i, :].tolist())
            ostr = ' '.join(tokens)
            batch.append(ostr)
        return batch

In [10]:
class TCrossEntropy(nn.Module):
    """Cross-entropy over all time steps: flattens (B, T, V) scores and (B, T) targets."""
    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        
    def forward(self, y, z):
        B, T, V = y.size()
        y = y.reshape(-1, V)   # (B*T, V)
        z = z.reshape(-1)      # (B*T,); reshape also handles the non-contiguous z[:, 1:]
        return self.criterion(y, z)

In [11]:
input_size, hidden_size = 512, 512
src_embed = models.Identity()
encoder = Encoder(input_size, hidden_size, src_embed)
lparams = {
    "input_size": 100,
    "hidden_size": 512,
    "num_layers": 5,
    "dropout": 0.2,
    "bidirectional": False,
    "batch_first": True
}

tgt_embed = nn.Embedding(len(vocab), lparams["input_size"])
interpreter = Interpreter(vocab)
decoder = Decoder(lparams, tgt_embed, interpreter, len(vocab))
net = models.EncoderDecoder(encoder, decoder)
device = torch.device("cuda:0")
net = net.to(device)
criterion = TCrossEntropy()        
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=0.5)

## Training


In [12]:
from usandbox.stats import Meter    
from usandbox.logs import Logger
from torch.nn.utils import clip_grad_norm_

class Trainer:
    def __init__(self, model, loss, optimizer, dataset, run, decoder=None):
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.logger = {phase: SummaryWriter(log_dir="/tmp/jerin/logs/{}/{}".format(phase, run)) 
                       for phase in ['train', 'val']}
        self.dataset = dataset
        self.best_loss = float("inf")
        
    def run_epochs(self, max_epochs):
        for epoch in tqdm(range(max_epochs), desc='epoch', leave=True):
            self.epoch = epoch
            self.train()
            self.validate()
        
    def train(self):
        self.model.train()
        loss = self.process("train")
    
    def validate(self):
        self.model.eval()
        loss = self.process("val")
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_model = self.export()
    
    def export(self):
        checkpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict()
        }
        return checkpoint
    
    def load(self, checkpoint):
        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        
    
    def clip_grad(self):
        max_grad_norm = 5
        params = list(filter(lambda x: x.requires_grad, self.model.parameters()))
        clip_grad_norm_(params, max_grad_norm)
    
    def log_decode(self, phase, x, y):
        decodes = self.model.decoder.interpreter.decode(y)
        def invert(img):
            img = img.numpy().transpose((1, 2, 0))
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img = std * img + mean
            img = np.clip(img, 0, 1)
            return img 
        
        plt.switch_backend('agg')
        for i, caption in enumerate(decodes):
            for undesirable in ["<start>", "<end>", "</start>"]:
                caption = caption.replace(undesirable, "")
            caption = caption.strip()
            plt.clf()
            figure = plt.figure(figsize=(10, 10))
            plt.title(caption, fontsize=16)
            img = x[i, :, :, :].cpu()
            img = invert(img)
            plt.imshow(img)
            plt.axis("off")
            self.logger[phase].add_figure('captions-{}'.format(self.epoch), figure, i)

    
    def process(self, phase):
        meter = Meter()
        dataset = self.dataset[phase]
        total = n_batches(dataset)
        # Build the autograd graph only while training; validation runs without gradients.
        with torch.set_grad_enabled(phase == "train"):
            for i, b in tqdm(enumerate(loader(dataset)),
                             total=total,
                             desc=phase,
                             leave=False):
                x, z, *_ = b
                x = x.to(device)
                z = z.to(device)
                # Free-running decode; targets are the captions shifted by one token.
                y = self.model(x, z, teacher_forcing=False)
                loss = self.loss(y, z[:, 1:])
                meter.report(loss.item())

                # Global step, so per-batch curves from different epochs don't overwrite each other.
                self.logger[phase].add_scalar('loss', loss.item(), self.epoch * total + i)
                if phase == "train":
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.clip_grad()
                    self.optimizer.step()

        self.log_decode(phase, x, y)
        self.logger[phase].add_scalar('loss/avg', meter.avg(), self.epoch)
        return meter.avg()

In [13]:
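run = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
trainer = Trainer(net, criterion, optimizer, dataset, run)
# Resume from a previous checkpoint if one exists.
if os.path.exists("model.weights"):
    with open("model.weights", "rb") as fp:
        checkpoint = torch.load(fp, map_location=device)
        trainer.load(checkpoint)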


In [None]:
trainer.run_epochs(25)


In [None]:
with open("model.weights", "wb+") as ofp:
    torch.save(trainer.export(), ofp)

In [None]:
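# x and z only exist inside Trainer.process, so fetch a validation batch here.
x, z, *_ = next(iter(loader(dataset["val"])))
x, z = x.to(device), z.to(device)
net.eval()
with torch.no_grad():
    y = net(x, z, teacher_forcing=False)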

In [None]:
decodes = interpreter.decode(y)

In [None]:


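def invert(img):
    """Undo the ImageNet normalisation so a CHW tensor can be shown with imshow."""
    img = img.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    return np.clip(std * img + mean, 0, 1)

for i, caption in enumerate(decodes):
    # Strip the special tokens before using the caption as a title.
    for undesirable in ["<start>", "<end>", "</start>"]:
        caption = caption.replace(undesirable, "")
    caption = caption.strip()
    figure = plt.figure(figsize=(10, 10))
    plt.title(caption, fontsize=16)
    plt.imshow(invert(x[i, :, :, :].cpu()))
    plt.axis("off")
    # Use the sample index as the step so every figure is kept.
    trainer.logger['train'].add_figure('caption-attempt-7', figure, i)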