# Captioning Image

Image captioning is the process of describing what is happening in an image. Since a CNN is not good at keeping temporal information, the image captioning task can be divided between two models: an image-based model, which extracts features from the image, and a language model, which takes the features from the first model and generates the description, much like a language translation task.



## Downloading the data:

To demonstrate the concept of image captioning, we will be using the Flickr8k data. The Flickr8k dataset was released by Flickr. Flickr8k provides five different captions for each image, describing the image in different ways. You may download this dataset from the Flickr8k image captioning dataset page: https://forms.illinois.edu/sec/1713398. As an alternative, Academic Torrents can be used to download the dataset for non-commercial purposes. The Flickr8k dataset can be downloaded from Academic Torrents at this link: http://academictorrents.com/details/9dea07ba660a722ae1008c4c8afdd303b6f6e53b

> Download and place the dataset in the `data/` folder before moving ahead

# Importing Requirements

In [None]:
import io
import os

import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tensorboardX import SummaryWriter
from torch.nn.utils.rnn import pack_padded_sequence
nltk.download('punkt')
import itertools

import time
import matplotlib.pyplot as plt

import random

# One seed for every random number generator the notebook uses, so repeated
# runs draw the same random numbers.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Safe to call without CUDA; the call is ignored in that case.
torch.cuda.manual_seed(SEED)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data Preprocessing 

Preprocessing takes several steps; each step is given below with a helpful comment.

In [None]:
##Training images list
# Read the training split (not the test split): one image file name per line.
with open('data/Flickr8k/Flickr8k_text/Flickr_8k.trainImages.txt','r') as f:
    train_img_list = [line.strip() for line in f]

In [None]:
##Test images list
# One stripped entry per line of the test split file.
with open('data/Flickr8k/Flickr8k_text/Flickr_8k.testImages.txt','r') as f:
    test_img_list = [line.strip() for line in f]

In [None]:
# Raw lines of the caption file, newline characters included.
with open('data/Flickr8k/Flickr8k_text/Flickr8k.token.txt','r') as f:
    img_caption = f.readlines()


In [None]:
##Store all the captions for each image
# The caption lines are consumed in consecutive groups of five. Each line is
# split on tabs: the first field (up to '#') is the image name, the second is
# the caption. Every image name maps to five one-element lists holding the
# lower-cased captions.
annot = {}
for start in range(0, len(img_caption), 5):
    fields = [img_caption[k].strip().split('\t') for k in range(start, start + 5)]
    image_name = fields[0][0].split('#')[0]
    annot[image_name] = [[caption_fields[1].lower()] for caption_fields in fields]


In [None]:
##Caption and Image List
# Only the first caption line of every group of five is kept per image.
cap_dict = {}
for first_line in img_caption[::5]:
    fields = first_line.strip().split('\t')
    cap_dict[fields[0].split('#')[0]] = fields[1].lower()

In [None]:
##Training captions
# Caption of every image listed in the training split.
train_cap_dict = {image_name: cap_dict[image_name] for image_name in train_img_list}

In [None]:
##Test captions
# Caption of every image listed in the test split.
test_cap_dict = {image_name: cap_dict[image_name] for image_name in test_img_list}

In [None]:
##Tokenize train captions
# train_token keeps [image name, tokens] pairs; train_tok keeps an independent
# copy of the token lists only.
train_token = [
    [image_name, nltk.word_tokenize(caption)]
    for image_name, caption in train_cap_dict.items()
]
train_tok = [list(tokens) for _, tokens in train_token]

In [None]:
##Tokenize test captions
# [image name, tokens] pair for every test caption.
test_token = [
    [image_name, nltk.word_tokenize(caption)]
    for image_name, caption in test_cap_dict.items()
]

In [None]:
##word_to_id and id_to_word
# Collect the distinct training tokens once; the position of a token in this
# list is its provisional id, so both lookups agree by construction.
unique_tokens = list(set(itertools.chain.from_iterable(train_tok)))
word_to_id = {token: idx for idx, token in enumerate(unique_tokens)}
id_to_word = np.asarray(unique_tokens)


In [None]:
##Sort the indices by word frequency

# Translate every training caption to provisional ids, count how often each id
# occurs, then reorder the vocabulary from most to least frequent.
train_token_ids = [[word_to_id[token] for token in tokens] for _, tokens in train_token]
count = np.zeros(id_to_word.shape)
for token_id in itertools.chain.from_iterable(train_token_ids):
    count[token_id] += 1
indices = np.argsort(-count)
id_to_word, count = id_to_word[indices], count[indices]


In [None]:
##Recreate word_to_id based on sorted list
# Position in the frequency-sorted array becomes the new id of each word.
word_to_id = dict(zip(id_to_word, range(len(id_to_word))))

In [None]:
# Report how many distinct words the vocabulary holds.
print("Vocabulary size: {}".format(len(word_to_id)))

In [None]:
## Tokens missing from the dictionary get -4, and every id is then shifted by +4,
## so unknown tokens end up with id 0 and ids 0-3 stay reserved for special tokens
train_token_ids = [
    [word_to_id.get(token, -4) + 4 for token in tokens] for _, tokens in train_token
]
test_token_ids = [
    [word_to_id.get(token, -4) + 4 for token in tokens] for _, tokens in test_token
]

In [None]:
# Register the special tokens with negative placeholder ids ...
word_to_id.update({'<unknown>': -4, '<start>': -3, '<end>': -2, '<pad>': -1})

# ... then shift every id up by four, so the special tokens take ids 0-3 and
# regular words start at id 4.
for token in list(word_to_id):
    word_to_id[token] += 4


# In[18]:


# Reverse lookup: regular words are numbered from 4 upwards in frequency order,
# ids 0-3 belong to the special tokens.
id_to_word_dict = dict(enumerate(id_to_word, 4))
id_to_word_dict.update({0: '<unknown>', 1: '<start>', 2: '<end>', 3: '<pad>'})

In [None]:
##Length of each caption
# Token count plus two, leaving room for the <start> and <end> markers.
train_cap_length = {image_name: len(tokens) + 2 for image_name, tokens in train_token}

test_cap_length = {image_name: len(tokens) + 2 for image_name, tokens in test_token}

In [None]:
##Add <start> and <end> tokens to each caption
# Both splits are handled the same way; every id list is modified in place.
for token_ids in itertools.chain(train_token_ids, test_token_ids):
    token_ids.insert(0, word_to_id['<start>'])
    token_ids.append(word_to_id['<end>'])


In [None]:
##Pad train captions
# Every caption shorter than the longest training caption is extended in place
# with <pad> ids until it reaches that length.
length = list(train_cap_length.values())
max_len = max(length)

for row, (image_name, _tokens) in enumerate(train_token):
    missing = max_len - train_cap_length[image_name]
    if missing > 0:
        train_token_ids[row].extend([word_to_id['<pad>']] * missing)
        

In [None]:
##Convert token ids to dictionary for train
# Image name -> padded id list (the same list objects as in train_token_ids).
train_token_ids_dict = {
    entry[0]: train_token_ids[row] for row, entry in enumerate(train_token)
}


In [None]:
##Pad test captions
# Same procedure as for the training captions, using the longest test caption
# as the target length.
length = list(test_cap_length.values())
max_len = max(length)

for row, (image_name, _tokens) in enumerate(test_token):
    missing = max_len - test_cap_length[image_name]
    if missing > 0:
        test_token_ids[row].extend([word_to_id['<pad>']] * missing)


In [None]:
##Convert token ids to dictionary for test
# Image name -> padded id list (the same list objects as in test_token_ids).
test_token_ids_dict = {
    entry[0]: test_token_ids[row] for row, entry in enumerate(test_token)
}

In [None]:
## save dictionary
# id_to_word holds only the corpus words in frequency order; the four special
# tokens are not part of this array, so position p corresponds to token id p + 4.
np.save('data/Flickr8k/Flickr8k_text/flickr8k_dictionary.npy',np.asarray(id_to_word))

In [None]:
## save training data to single text file
# One line per image: the image name, then every token id followed by a space.
with io.open('data/Flickr8k/Flickr8k_text/train_captions.txt','w',encoding='utf-8') as f:
    for row, tokens in enumerate(train_token_ids):
        ids_text = "".join("%i " % token for token in tokens)
        f.write("%s %s\n" % (train_token[row][0], ids_text))


In [None]:
## save test data to single text file
# One line per image: the image name, then every token id followed by a space.
with io.open('data/Flickr8k/Flickr8k_text/test_captions.txt','w',encoding='utf-8') as f:
    for row, tokens in enumerate(test_token_ids):
        ids_text = "".join("%i " % token for token in tokens)
        f.write("%s %s\n" % (test_token[row][0], ids_text))

In [None]:
# ## Image preprocessing
def resize_image(image, size):
    """Return a copy of ``image`` resized to ``size``.

    Args:
        image: a PIL image.
        size: target size as a two-element sequence, passed straight to
            ``Image.resize``.

    ``Image.LANCZOS`` is the filter that ``Image.ANTIALIAS`` used to alias;
    the old name no longer exists in Pillow 10 and later.
    """
    return image.resize(size, Image.LANCZOS)

def resize_images(image_dir, output_dir, size):
    """Resize the images in 'image_dir' and save into 'output_dir'.

    Args:
        image_dir: directory whose entries are all opened as images.
        output_dir: destination directory, created when missing. Every
            resized image keeps the file name of its source.
        size: target size handed to ``resize_image``.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        # Read-only access is enough; 'r+b' would fail on read-only datasets.
        with open(os.path.join(image_dir, image), 'rb') as f:
            with Image.open(f) as img:
                # A resized copy carries no format of its own, so remember the
                # source format before resizing and save with it.
                source_format = img.format
                resized = resize_image(img, size)
                resized.save(os.path.join(output_dir, image), source_format)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))


In [None]:
##Resize image
# Raw dataset images are read from image_dir; 256x256 copies go to output_dir.
image_dir = 'data/Flickr8k/Flickr8k_Dataset/'
output_dir = 'Flickr8k_resized_image/'
image_size = [256, 256]
resize_images(image_dir=image_dir, output_dir=output_dir, size=image_size)


# Constructing Dataloader

In [None]:
class Dataset(data.Dataset):
    """Serve (image, caption ids, caption length) triples, one per image name."""

    def __init__(self, img_dir, img_id, cap_dictionary, cap_length, transform=None):
        """Store the lookups used by ``__getitem__``.

        Args:
            img_dir: prefix prepended verbatim to every file name, so it has
                to end with a path separator.
            img_id: sequence of image file names; defines the dataset order.
            cap_dictionary: maps an image file name to its caption ids.
            cap_length: maps an image file name to its caption length.
            transform: optional callable applied to the opened RGB image.
        """
        self.img_dir = img_dir
        self.img_id = img_id
        self.transform = transform
        self.cap_dictionary = cap_dictionary
        self.cap_length = cap_length

    def __len__(self):
        """Number of image names in the dataset."""
        return len(self.img_id)

    def __getitem__(self, index):
        """Return the transformed image, its caption ids as an array, and its length."""
        name = self.img_id[index]
        picture = Image.open(self.img_dir + name).convert('RGB')
        if self.transform is not None:
            picture = self.transform(picture)
        return picture, np.array(self.cap_dictionary[name]), self.cap_length[name]


**Image Augmentation:** Image augmentation is often used for better generalization. It means creating additional training images by applying edits to the existing ones, which effectively increases the training data. Here, too, we will augment the images, using the `torchvision.transforms` functions. As shown below, we will apply a random crop and a random horizontal flip, and then normalize the image.

In [None]:
# Training pipeline: random crop and random horizontal flip as augmentation,
# then tensor conversion and per-channel normalisation.
transform_train = transforms.Compose([ 
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])

# Evaluation pipeline: a fixed centre crop instead of a random one, so the same
# image always yields the same input and test results are repeatable.
transform_test = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])

In [None]:
# Both splits read the resized images; each split gets its own caption lookup,
# length lookup and transform pipeline.
img_dir = 'Flickr8k_resized_image/'
train_data = Dataset(
    img_dir=img_dir,
    img_id=train_img_list,
    cap_dictionary=train_token_ids_dict,
    cap_length=train_cap_length,
    transform=transform_train,
)

test_data = Dataset(
    img_dir=img_dir,
    img_id=test_img_list,
    cap_dictionary=test_token_ids_dict,
    cap_length=test_cap_length,
    transform=transform_test,
)


In [None]:
# Shuffled batches of 32 samples, loaded by two worker processes per split.
train_dataloader = data.DataLoader(
    dataset=train_data, num_workers=2, shuffle=True, batch_size=32)
test_dataloader = data.DataLoader(
    dataset=test_data, num_workers=2, shuffle=True, batch_size=32)

# Model

So far we have used RNNs and CNNs separately in many tasks, namely classification, translation and embedding generation. In this chapter we will use a CNN to read the image and pass the information it learns on to an LSTM. Here the RNN acts as the generative model and generates an appropriate description of the image. We will train the model in a supervised manner. The CNN is used as the encoder and the RNN as the decoder.

The schematic diagram below shows how the task is accomplished. This is the simplest model: a few CNN layers followed by Linear/Dense layers. The output of the Dense layer is passed to the RNN units. The RNN unit is fed the start-of-sequence token `<SOS>` and generates the next word. The word generated at time step t is fed to the RNN at time step t+1, and a new word is generated. This generation continues until the end-of-sequence token `<EOS>` is reached.
    

![](figures/image_captioning.png)

Figure. Schematic diagram of a model architecture for image captioning
Source: https://en.wikipedia.org/wiki/Bat
This seems simple, doesn't it? It actually is very simple to build the image captioning model; the only hard part is dealing with the training data. To train this task we will be using the MS-COCO data, which is 13 GB in size. Knowing the data size, you must have realized that this model requires a high-end machine with a GPU to train. Due to the data size, one cannot train this model on Google Colab. I have trained the model on my personal PC with 32 GB of RAM and an Nvidia 1080 Ti with 11 GB of VRAM. You may go ahead and use AWS or Google Cloud. Coding and converging this model is a next-level experience and will surely boost your confidence in building models with PyTorch.


**Encoder Module:** As shown in the schematic diagram of the model architecture for image captioning, the encoder is made up of convolution layers. The encoder takes an image and converts it into an image context vector. Generally, a pre-trained model is used to convert an image into a context vector. This pre-trained model can be any network such as ResNet, DenseNet or VGG. Here a ResNet model is loaded. The last layer of the pre-trained network is removed so that it gives an n-dimensional vector for any image. This n-dimensional vector holds information about the image and is later consumed by the decoder module.

**Decoder Module:** The decoder module is very simple; it is very similar to the decoder module we used for language translation in chapter 4: Using RNN for NLP. The decoder module has one LSTM layer followed by a linear transformation. During training, generation takes place using teacher forcing.



In [None]:

class EncoderCNN(nn.Module):
    """Pretrained ResNet-50 backbone followed by a trainable projection head."""

    def __init__(self, embed_size):
        """Load the pretrained ResNet-50 and swap its classifier for an embedding head.

        Args:
            embed_size: size of the feature vector produced per image.
        """
        super(EncoderCNN, self).__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep every child module except the final fully connected classifier.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.linear = nn.Linear(backbone.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Return one batch-normalised feature vector per input image."""
        # The backbone runs without gradient tracking; only the projection and
        # batch-norm layers below take part in backpropagation.
        with torch.no_grad():
            pooled = self.resnet(images)
        flat = pooled.reshape(pooled.size(0), -1)
        return self.bn(self.linear(flat))


class DecoderRNN(nn.Module):
    """LSTM decoder that turns image features into caption token scores or ids."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Create the embedding, LSTM and output projection layers.

        Args:
            embed_size: size of the token embeddings and of the image features.
            hidden_size: size of the LSTM hidden state.
            vocab_size: number of token ids the decoder can embed and predict.
            num_layers: number of stacked LSTM layers.
            max_seq_length: number of tokens produced by ``sample``.
        """
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # NOTE: the attribute is spelled "seg"; the name is kept as-is because
        # ``sample`` reads it under this spelling.
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Return vocabulary scores for every non-padded step of the batch.

        The image feature is prepended to the caption embeddings as the first
        input step, the sequence is packed according to ``lengths``, and the
        packed LSTM outputs are projected onto the vocabulary.
        """
        image_step = features.unsqueeze(1)
        sequence = torch.cat((image_step, self.embed(captions)), 1)
        packed = pack_padded_sequence(sequence, lengths, batch_first=True)
        packed_hiddens, _ = self.lstm(packed)
        # Element 0 of a PackedSequence is its flat data tensor.
        return self.linear(packed_hiddens[0])

    def sample(self, features, states=None):
        """Generate ``max_seg_length`` token ids per image using greedy search."""
        step_input = features.unsqueeze(1)                       # (batch_size, 1, embed_size)
        steps = []
        for _ in range(self.max_seg_length):
            hiddens, states = self.lstm(step_input, states)      # (batch_size, 1, hidden_size)
            scores = self.linear(hiddens.squeeze(1))             # (batch_size, vocab_size)
            predicted = scores.max(1)[1]                         # (batch_size)
            steps.append(predicted)
            # The chosen token becomes the next input step.
            step_input = self.embed(predicted).unsqueeze(1)      # (batch_size, 1, embed_size)
        return torch.stack(steps, 1)                             # (batch_size, max_seq_length)

In [None]:
# Image features and token embeddings share the size 1024; the decoder output
# covers every entry of word_to_id, special tokens included.
encoder = EncoderCNN(embed_size=1024)
decoder = DecoderRNN(embed_size=1024, hidden_size=1024,
                     vocab_size=len(word_to_id), num_layers=1)

In [None]:
# Switch the encoder to training mode.
encoder.train()

In [None]:
# Switch the decoder to training mode.
decoder.train()

In [None]:
##Function to sort the captions and images according to caption length
def sorting(image,caption,length):
    """Reorder a batch so that caption lengths are in descending order.

    Args:
        image: batch of images, indexed along the first dimension.
        caption: batch of captions, indexed along the first dimension.
        length: tensor holding one caption length per sample.

    Returns:
        The images and captions in the new order, plus the sorted lengths.
    """
    sorted_length, order = length.sort(descending=True)
    return image[order], caption[order], sorted_length

**Appropriate Loss Function and Optimizers:** We are using the cross-entropy loss function. Ideally, I would need to take care of the padding in each batch by not calculating the loss for pad tokens, but as I want to keep this implementation as simple as possible, I am using `nn.CrossEntropyLoss()` from PyTorch.
We are using the Adam optimizer with a learning rate of 0.001.



In [None]:
# Loss and optimizer
for network in (encoder, decoder):
    network.to(device)
criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's projection and batch-norm layers are
# handed to the optimizer; the ResNet backbone parameters are left out.
params = list(itertools.chain(
    decoder.parameters(),
    encoder.linear.parameters(),
    encoder.bn.parameters(),
))
optimizer = torch.optim.Adam(params, lr=0.001)

# Train the model

Training consists of the following steps:

1. The captions and images of a batch are sorted according to caption length.
2. The PyTorch function `pack_padded_sequence` packs the variable-length captions of the batch, which have been padded to a common maximum length.
3. The images are passed to the encoder to obtain the image vector / context vector.
4. The decoder module takes these features and generates the caption word by word.
5. Loss calculation and backpropagation take place.
 

In [None]:
# Train encoder and decoder, logging the loss of every step to TensorBoard and
# saving both networks once training has finished.
encoder.train()
decoder.train()
writer  =  SummaryWriter() 
train_loss=[]
time1=time.time()
epochs=30
num_iteration = 0
total_step=len(train_dataloader)
for epoch in range(epochs):
    for i, (images,captions,lengths) in enumerate(train_dataloader):
        
        images = images.to(device)
        captions = captions.to(device)
        
        # Packing requires the batch to be ordered by decreasing caption length.
        images,captions,lengths=sorting(images,captions,lengths)
        
        # Flat tensor of the non-padded target token ids.
        targets = pack_padded_sequence(captions,lengths,batch_first=True)[0]
        
        
        ##Forward,backward and optimization
        features = encoder(images)
        outputs = decoder(features,captions,lengths)
        loss = criterion(outputs,targets)
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Keep a plain float: appending the tensor itself would keep every
        # step's loss tensor alive on the device for the whole run.
        loss_value = loss.item()
        train_loss.append(loss_value)
        
        # Print log info
        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, epochs, i, total_step, loss_value, np.exp(loss_value))) 
        writer.add_scalar('Train/Loss', loss_value, num_iteration)
        num_iteration = num_iteration+1

# Flush pending events and release the log file.
writer.close()

print('RUNNING TIME: {}'.format(time.time()-time1))

torch.save(encoder,os.path.join('./','encoder.model'))
torch.save(decoder,os.path.join('./','decoder.model'))


In [None]:
def load_image(image_path, transform=None):
    """Open an image, resize it to 224x224 and optionally transform it.

    Args:
        image_path: path of the image file.
        transform: optional callable; when given, its result gets a leading
            batch dimension via ``unsqueeze(0)``.

    Returns:
        The resized PIL image when ``transform`` is None, otherwise the
        transformed image with a batch dimension of one.
    """
    # Force three channels, as Dataset does, so greyscale or RGBA files still
    # match the three-channel normalisation.
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image


In [None]:
# Switch both networks to evaluation mode.
encoder.eval()
decoder.eval()

In [None]:

# Move both networks to the selected device.
encoder.to(device)
decoder.to(device)


# Results
Below are sample test images together with the captions the model generates for them. The generated captions are very accurate.

In [None]:

def getRandomFile(img_list):
    """
    Pick one entry of ``img_list`` uniformly at random and return it.
    """
    return img_list[random.randrange(0, len(img_list))]



In [None]:
# Number of sample captions to generate.
num_generation = 5

In [None]:
# Caption a few randomly chosen test images and display each image.
for i in range(num_generation):
    file = getRandomFile(test_img_list)
    image_path = 'Flickr8k_resized_image/'+ str(file)
    image_tensor = load_image(image_path, transform_test).to(device)
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        feature = encoder(image_tensor)
        sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].cpu().numpy()
    sampled_caption = []
    for word_id in sampled_ids:
        # Convert the NumPy integer to a plain int for the dictionary lookup.
        word = id_to_word_dict[int(word_id)]
        sampled_caption.append(word)
        # Stop at the end marker; later ids are not part of the caption.
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)
    print (sentence)
    # Close the file as soon as the pixel data has been copied out.
    with Image.open(image_path) as image:
        plt.imshow(np.asarray(image))
    plt.show()


## Training Progress

![](figures/image_captioning_progress.png)

# Performance Evaluation

In [None]:
##Test the model
encoder.eval()
decoder.eval()

test_loss=[]
time1=time.time()

total_step=len(test_dataloader)
for i, (images,captions,lengths) in enumerate(test_dataloader):
        
    images = images.to(device)
    captions = captions.to(device)
        
    # Packing requires the batch to be ordered by decreasing caption length.
    images,captions,lengths=sorting(images,captions,lengths)
        
    targets = pack_padded_sequence(captions,lengths,batch_first=True)[0]
        
    # Forward pass only; no gradients are needed for evaluation.
    with torch.no_grad():    
        features = encoder(images)
        outputs = decoder(features,captions,lengths)
        loss = criterion(outputs,targets)
        
    # Store plain floats so the batch losses can be averaged afterwards.
    loss_value = loss.item()
    test_loss.append(loss_value)
        
    # Print log info
    if i % 100 == 0:
        print('Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                 .format(i, total_step, loss_value, np.exp(loss_value))) 
    
print('RUNNING TIME: {}'.format(time.time()-time1))
# Perplexity of the mean batch loss over the whole test set, rather than of
# whichever batch happened to come last.
print('PERPLEXITY: {}'.format(np.exp(np.mean(test_loss))))