# Image Captioning Using Deep Learning

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pycocotools.coco import COCO

import nltk

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence
from torch.autograd import Variable

In [None]:
from utils import resize_image_due_to_pytorch_issue, load_image, resize_images, build_vocabulary, get_data_loader, show_plot_evaluation

## Set Configs

In [None]:
# image configs
# IMAGE_SIZE, IMAGE_PATH and RESIZED_IMAGE_PATH are the three arguments given to
# resize_images(); IMAGE_PATH is also the image directory given to get_data_loader().
IMAGE_SIZE = 256
IMAGE_PATH = './datasets/train2014/'
RESIZED_IMAGE_PATH = './datasets/resized2014/'
# Side length used by transforms.RandomCrop for training and forwarded to load_image()
# when captioning a single image.
CROP_SIZE = 100 # cannot set to 224 (resnet input) due to pytorch issue

# caption configs
# COCO caption annotation file (opened with pycocotools.COCO and by build_vocabulary).
CAPTION_PATH = 'datasets/annotations/captions_train2014.json'
# Passed to build_vocabulary as `vocabulary_path`; presumably a pickle cache of the
# vocabulary — confirm against utils.build_vocabulary.
VOCABULARY_PATH = './datasets/vocabulary.pkl'

# model configs
EMBEDDING_SIZE = 256  # width of the encoder output and of the decoder word embeddings
HIDDEN_SIZE = 512     # hidden size of the decoder LSTM
N_LAYERS = 1          # number of stacked LSTM layers in the decoder
N_EPOCHS = 5          # passes over the data loader in the training loop
BATCH_SIZE = 128      # batch size handed to get_data_loader()
LR = 0.001            # learning rate of the Adam optimizer
WEIGHT_PATH = './weights/'  # directory the training loop writes checkpoints into

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

## Prepare Images

In [None]:
# One-off preprocessing: hand the raw image directory, an output directory and the
# target size to the project helper.
# NOTE(review): presumably writes IMAGE_SIZE x IMAGE_SIZE copies of every image into
# RESIZED_IMAGE_PATH — confirm against utils.resize_images.
resize_images(IMAGE_PATH, RESIZED_IMAGE_PATH, IMAGE_SIZE)

## Prepare Captions

In [None]:
class Vocabulary:
    """Two-way mapping between words and consecutive integer indices.

    Indices are handed out in insertion order starting at 0. Calling the
    instance with a word returns its index; words that were never added map
    to the index of the ``'<unknown>'`` token, which therefore has to be
    added before the first out-of-vocabulary lookup (otherwise ``KeyError``).
    """

    def __init__(self):
        super().__init__()

        self.word2index = {}
        self.index2word = {}
        self.num_words = 0

    def add_word(self, word):
        """Register ``word`` under the next free index; known words are ignored."""
        if word in self.word2index:
            return

        new_index = self.num_words
        self.word2index[word] = new_index
        self.index2word[new_index] = word
        self.num_words = new_index + 1

    def __call__(self, word):
        """Return the index of ``word``, or of ``'<unknown>'`` if it is not known."""
        key = word if word in self.word2index else '<unknown>'
        return self.word2index[key]

    def __len__(self):
        """Return the number of distinct words stored."""
        return len(self.word2index)

In [None]:
# Build the vocabulary through the project helper, handing it the Vocabulary class,
# the caption annotations and the vocabulary file path.
# NOTE(review): `min_word_count=4` presumably drops rarer words — confirm in utils.
vocabulary = build_vocabulary(
    Vocabulary,
    min_word_count=4,
    caption_path=CAPTION_PATH,
    vocabulary_path=VOCABULARY_PATH,
)

print(f'Total Vocabulary Size: {len(vocabulary)}')

## Set Data Loader

In [None]:
class COCODataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one (image, caption) pair per COCO annotation.

    Args:
        image_path: directory that contains the image files.
        coco_path: path of the COCO caption annotation file.
        vocab: callable mapping a token (str) to its integer index.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, image_path, coco_path, vocab, transform=None):
        super().__init__()

        self.image_path = image_path
        self.coco = COCO(coco_path)
        # One sample per annotation id, so an image appears once per caption.
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th annotation.

        ``target`` holds the word indices of the lower-cased, tokenized caption
        wrapped in ``<start>`` / ``<end>``; it is built with ``torch.Tensor``,
        i.e. with that constructor's default dtype.
        """
        annotation = self.coco.anns[self.ids[index]]
        image_id = annotation['image_id']
        caption_text = annotation['caption']
        file_name = self.coco.loadImgs(image_id)[0]['file_name']

        image = Image.open(os.path.join(self.image_path, file_name)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # caption string -> [<start>, w1, ..., wn, <end>] as word indices
        words = nltk.tokenize.word_tokenize(str(caption_text).lower())
        word_ids = [
            self.vocab('<start>'),
            *(self.vocab(word) for word in words),
            self.vocab('<end>'),
        ]

        return image, torch.Tensor(word_ids)

    def __len__(self):
        """Return the number of caption annotations."""
        return len(self.ids)

In [None]:
# Training-time image pipeline: random crop + random horizontal flip as augmentation,
# then conversion to a tensor and per-channel normalization.
transform = transforms.Compose([
    transforms.RandomCrop(CROP_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [None]:
# Build the training data loader through the project helper.
# NOTE(review): this reads images from IMAGE_PATH, so whatever resize_images() wrote to
# RESIZED_IMAGE_PATH is never consumed in this notebook — confirm whether the loader was
# meant to point at RESIZED_IMAGE_PATH instead.
# NOTE(review): the training loop unpacks (images, captions, lengths) per batch, so the
# helper presumably installs a padding collate function — confirm in utils.
data_loader = get_data_loader(COCODataset, IMAGE_PATH, CAPTION_PATH, 
                              vocabulary, transform, BATCH_SIZE, shuffle=True, num_workers=2)

## Build [Image Captioning](https://arxiv.org/pdf/1411.4555.pdf) Network

In [None]:
class EncoderCNN(nn.Module):
    """Image encoder: pre-trained ResNet-152 backbone plus a trainable projection.

    The backbone (every ResNet child except its final fully-connected layer) is
    used as a fixed feature extractor: no gradient flows into it. Its pooled
    features are projected to ``embedding_size`` by ``fc_layer`` and
    batch-normalized by ``norm``; those two sub-modules are the trainable part.

    Args:
        embedding_size: size of the feature vector produced per image.
    """

    def __init__(self, embedding_size):

        super(EncoderCNN, self).__init__()

        resnet = models.resnet152(pretrained=True)  # use pre-trained resnet model

        # ResNet's classification head is the attribute `fc` (the last child).
        # Drop it so the backbone emits pooled features of width
        # `resnet.fc.in_features`, and attach our own projection as `fc_layer`,
        # which init_weights(), forward() and the optimizer set-up rely on.
        modules = list(resnet.children())[:-1]
        self.resnet_layer = nn.Sequential(*modules)
        self.fc_layer = nn.Linear(in_features=resnet.fc.in_features,
                                  out_features=embedding_size)
        self.norm = nn.BatchNorm1d(embedding_size, momentum=0.01)

        self.init_weights()

    def init_weights(self):
        """Initialize the projection layer: N(0, 0.02) weights, zero bias."""
        self.fc_layer.weight.data.normal_(0.0, 0.02)
        self.fc_layer.bias.data.fill_(0)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: image batch accepted by the ResNet backbone.

        Returns:
            Tensor of shape ``(batch_size, embedding_size)``.
        """
        # The backbone is not fine-tuned: skip building its autograd graph.
        with torch.no_grad():
            feature_vectors = self.resnet_layer(images)

        # flatten everything but the batch axis before the linear projection
        feature_vectors = feature_vectors.reshape(feature_vectors.size(0), -1)
        feature_vectors = self.norm(self.fc_layer(feature_vectors))

        return feature_vectors

In [None]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    Args:
        embedding_size: size of the word embeddings; the image feature vector
            fed as the first LSTM input must have the same size.
        hidden_size: hidden size of the LSTM.
        vocab_size: number of words in the vocabulary (output dimension).
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, embedding_size, hidden_size, vocab_size, n_layers):

        super(DecoderRNN, self).__init__()

        self.embedding_layer = nn.Embedding(vocab_size, embedding_size)
        self.lstm_layer = nn.LSTM(embedding_size, hidden_size, n_layers, batch_first=True)
        self.fc_layer = nn.Linear(hidden_size, vocab_size)

        self.init_weights()

    def init_weights(self):
        """Initialize embedding and output weights uniformly in [-0.1, 0.1]; zero bias."""
        self.embedding_layer.weight.data.uniform_(-0.1, 0.1)
        self.fc_layer.weight.data.uniform_(-0.1, 0.1)
        self.fc_layer.bias.data.fill_(0)

    def sample(self, features, states=None, max_sampling_length=20):
        """Greedily decode word ids from image features.

        Args:
            features: tensor of shape ``(batch_size, embedding_size)``.
            states: optional initial LSTM ``(h, c)`` state.
            max_sampling_length: number of decoding steps (words) to generate.

        Returns:
            Long tensor of shape ``(batch_size, max_sampling_length)``. The
            batch axis is always kept, so ``result[0]`` is the word-id
            sequence of the first image even when ``batch_size == 1``.
        """
        if max_sampling_length <= 0:
            return torch.empty((features.size(0), 0), dtype=torch.long,
                               device=features.device)

        sampled_ids = []
        inputs = features.unsqueeze(1)

        # Decoding only yields argmax indices, so no autograd graph is needed.
        with torch.no_grad():
            for _ in range(max_sampling_length):
                # hiddens shape: (batch_size, 1, hidden_size)
                # states: tuple (h, c), each (n_layers, batch_size, hidden_size)
                hiddens, states = self.lstm_layer(inputs, states)
                outputs = self.fc_layer(hiddens.squeeze(1))  # (batch_size, vocab_size)
                prediction = outputs.argmax(dim=1)           # (batch_size,)
                sampled_ids.append(prediction)
                # feed the predicted word back in as the next input
                inputs = self.embedding_layer(prediction).unsqueeze(1)

        # Per-step predictions are 1-D, so they are stacked along a new time
        # axis; concatenating along dim 1 is not possible for 1-D tensors.
        return torch.stack(sampled_ids, dim=1)

    def forward(self, feature_vectors, source_captions, lengths):
        """Compute vocabulary scores for every non-padded time step.

        Args:
            feature_vectors: ``(batch_size, embedding_size)`` image features,
                prepended to the caption embeddings as the first time step.
            source_captions: ``(batch_size, max_len)`` padded word indices.
            lengths: per-sample sequence lengths as expected by
                ``pack_padded_sequence``.

        Returns:
            Tensor of shape ``(sum(lengths), vocab_size)`` in packed order.
        """
        embeds = self.embedding_layer(source_captions)
        embeds = torch.cat((feature_vectors.unsqueeze(1), embeds), 1)
        packed = pack_padded_sequence(embeds, lengths, batch_first=True)

        hiddens, _ = self.lstm_layer(packed)
        # hiddens is a PackedSequence; element 0 is its flat data tensor
        outputs = self.fc_layer(hiddens[0])

        return outputs

#### Initialize Image Captioning Network

In [None]:
# Instantiate the image encoder and move it to the selected device.
encoder = EncoderCNN(EMBEDDING_SIZE)
encoder.to(device)

In [None]:
# Instantiate the caption decoder (output size = vocabulary size) and move it to the
# selected device.
decoder = DecoderRNN(EMBEDDING_SIZE, HIDDEN_SIZE, len(vocabulary), N_LAYERS)
decoder.to(device)

## Set Loss Function

In [None]:
# Cross-entropy between the decoder's per-step vocabulary scores and the packed target
# word indices (see the training loop).
ce_loss = nn.CrossEntropyLoss()
ce_loss.to(device)

## Set Optimizer

In [None]:
# Optimize the whole decoder plus the encoder's projection layer and batch-norm only;
# the ResNet backbone's parameters are not handed to the optimizer.
params = [
    weight
    for module in (decoder, encoder.fc_layer, encoder.norm)
    for weight in module.parameters()
]
optimizer = torch.optim.Adam(params, lr=LR)

## Train The Network

In [None]:
# Loss bookkeeping for the training loop: averaged losses kept for plotting, plus the
# running sums for the current print / plot window.
losses_history = []
total_loss_print = total_loss_plot = 0

# Cadences (in batches) for console logging, loss-history sampling and checkpointing.
print_every, plot_every, save_every = 10, 100, 1000

In [None]:
encoder.train()
decoder.train()

# torch.save() does not create missing directories.
os.makedirs(WEIGHT_PATH, exist_ok=True)

# Batches seen across all epochs; drives the print / plot windows so that every
# window averages exactly `print_every` / `plot_every` batches.
global_step = 0

print('Training the network...')
for epoch in range(1, N_EPOCHS+1):
    
    for i, (images, captions, lengths) in enumerate(data_loader):
        
        global_step += 1
        
        # this is only a workaround for a pytorch issue
        images = resize_image_due_to_pytorch_issue(images)
        images = torch.from_numpy(images)
        
        # set mini-batch datasets
        images = images.to(device)
        captions = captions.to(device)
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
        
        encoder.zero_grad()
        decoder.zero_grad()
        
        # forward propagation
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        
        # calculate losses
        loss = ce_loss(outputs, targets)
        loss.backward()
        
        optimizer.step()
        
        # accumulate losses as plain floats: summing the loss tensors themselves
        # would keep every batch's autograd graph alive
        batch_loss = loss.item()
        total_loss_print += batch_loss
        total_loss_plot += batch_loss
        
        if global_step % print_every == 0:
            avg_loss_print = total_loss_print / print_every
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, N_EPOCHS, i+1, len(data_loader), avg_loss_print, np.exp(batch_loss)))
            total_loss_print = 0
            
        # sample the loss history per `plot_every` batches (not per epoch, which
        # would never trigger with N_EPOCHS < plot_every)
        if global_step % plot_every == 0:
            avg_loss_plot = total_loss_plot / plot_every
            losses_history.append(avg_loss_plot)
            total_loss_plot = 0
            
        # save the model checkpoints
        if (i+1) % save_every == 0:
            torch.save(encoder.state_dict(), os.path.join(WEIGHT_PATH, f'encoder-{epoch}-{i+1}.hdf5'))
            torch.save(decoder.state_dict(), os.path.join(WEIGHT_PATH, f'decoder-{epoch}-{i+1}.hdf5'))

In [None]:
# Plot the loss values collected during training via the project helper.
# NOTE(review): the meaning of the second argument (1) is defined in
# utils.show_plot_evaluation — TODO confirm.
show_plot_evaluation(losses_history, 1)

## Evaluate The Network

In [None]:
def caption_image(image_path, encoder_path, decoder_path, 
                  crop_size, embedding_size, hidden_size, vocabulary, n_layers):
    """Generate a caption for one image and display it as the plot title.

    Args:
        image_path: path of the image file to caption.
        encoder_path: path of a saved ``EncoderCNN`` state dict.
        decoder_path: path of a saved ``DecoderRNN`` state dict.
        crop_size: size forwarded to ``load_image``.
        embedding_size: embedding size the models were trained with.
        hidden_size: LSTM hidden size the decoder was trained with.
        vocabulary: vocabulary object providing ``index2word`` and ``len()``.
        n_layers: number of LSTM layers the decoder was trained with.
    """
    # prepare image
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean = (0.485, 0.456, 0.406),
                                                         std = (0.229, 0.224, 0.225))])
    
    image = load_image(image_path, crop_size, transform)
    
    # this is only a workaround for a pytorch issue
    # NOTE(review): assumes load_image returns a CPU tensor/array with a leading
    # batch axis — confirm against utils.load_image.
    images = resize_image_due_to_pytorch_issue(np.asarray(image))
    # Tensor.to() is not in-place: keep the returned tensor.
    image = torch.from_numpy(images).to(device)
    
    # build models
    encoder = EncoderCNN(embedding_size)
    encoder.to(device)
    
    decoder = DecoderRNN(embedding_size, hidden_size, len(vocabulary), n_layers)
    decoder.to(device)
    
    # load the trained model parameters; map_location lets checkpoints saved on
    # a GPU be restored on a CPU-only machine
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    
    # inference mode: batch-norm must use its running statistics, since batch
    # statistics cannot be computed from a single image
    encoder.eval()
    decoder.eval()
    
    # generate a caption from the image
    with torch.no_grad():
        feature_vectors = encoder(image)
        sampled_ids = decoder.sample(feature_vectors)
    # a single image is captioned, so flatten to one word-id sequence whether or
    # not sample() kept the batch axis
    sampled_ids = sampled_ids.reshape(-1).cpu().numpy()
    
    # convert word_ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocabulary.index2word[int(word_id)]
        sampled_caption.append(word)
        
        if word == '<end>': break
            
    image_caption = ' '.join(sampled_caption).capitalize()
    
    image = Image.open(image_path)
    plt.imshow(np.asarray(image))
    plt.title(image_caption)
    plt.show()

In [None]:
# Inputs for caption_image(): the image to caption and the encoder / decoder checkpoints.
# NOTE(review): these values look like directory placeholders. caption_image() hands them
# to load_image() / Image.open() and torch.load(), which need concrete file paths (the
# training loop writes checkpoints named 'encoder-<epoch>-<step>.hdf5' and
# 'decoder-<epoch>-<step>.hdf5' under WEIGHT_PATH) — fill in real files before running.
SAMPLE_IMAGE_PATH = 'images/'
ENCODER_PATH = 'weights/'
DECODER_PATH = 'weights/'

In [None]:
# Caption the sample image with the saved weights, using the same model
# hyper-parameters as for training.
caption_image(SAMPLE_IMAGE_PATH, ENCODER_PATH, DECODER_PATH, 
              CROP_SIZE, EMBEDDING_SIZE, HIDDEN_SIZE, vocabulary, N_LAYERS)

---