# Image Captioning Using Deep Learning With Attention Mechanism

In [None]:
import os
import h5py
import json
import random
import numpy as np
from scipy.misc import imread, imresize
from collections import Counter
from tqdm import tqdm_notebook

In [None]:
import torch
import torch.nn as nn
import torchvision

## Configuration

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

## Utilities

In [None]:
def _sample_captions(captions, captions_per_image):
    """Return exactly ``captions_per_image`` captions for one image.

    If the image has too few captions, random repeats are appended; otherwise
    a random subset (without replacement) is drawn.
    """
    if len(captions) < captions_per_image:
        return captions + [random.choice(captions) for _ in range(captions_per_image - len(captions))]
    return random.sample(captions, k=captions_per_image)


def _load_image(path):
    """Read the image at ``path`` and return it as a uint8 array of shape (3, 256, 256).

    NOTE: ``imread``/``imresize`` come from ``scipy.misc``, which newer SciPy
    releases no longer provide; this requires an older SciPy (or a replacement
    reader/resizer).
    """
    image = imread(path)
    if image.ndim == 2:
        # grayscale: replicate the single channel into three
        image = np.stack([image, image, image], axis=2)
    # drop an alpha channel (e.g. RGBA PNGs) so the shape check below holds
    image = image[:, :, :3]
    image = imresize(image, (256, 256))
    # HWC -> CHW, the layout stored in the HDF5 file
    image = image.transpose(2, 0, 1)

    # sanity check
    assert image.shape == (3, 256, 256)
    assert np.max(image) <= 255
    return image


def _encode_caption(caption, word_vocab, max_length):
    """Map tokens to ids as ``<start> tokens <end>``, right-padded with ``<pad>`` to ``max_length + 2`` ids."""
    return ([word_vocab['<start>']]
            + [word_vocab.get(word, word_vocab['<unk>']) for word in caption]
            + [word_vocab['<end>']]
            + [word_vocab['<pad>']] * (max_length - len(caption)))


def create_input_files(datasets, karpathy_json_path, image_dir, output_dir, captions_per_image, min_word_freq, max_length=100):
    """Build the HDF5 image files and JSON caption files read by ``CaptionDataset``.

    Reads Karpathy's split JSON and builds a word vocabulary. For each of the
    TRAIN / VALID / TEST splits it writes, under ``<output_dir>/data/``:

    * ``<SPLIT>_IMAGES_<base>.hdf5``: an ``images`` dataset of shape
      ``(n_images, 3, 256, 256)`` (uint8) plus a ``captions_per_image`` attribute;
    * ``<SPLIT>_CAPTIONS_<base>.json``: ``captions_per_image`` encoded captions
      per image, in image order, each ``max_length + 2`` ids long;
    * ``<SPLIT>_CAPLENS_<base>.json``: the unpadded length (tokens + 2) of each
      encoded caption.

    The vocabulary is written to ``WORD_VOCAB_<base>.json`` in the same folder.
    Existing output files are overwritten.

    Args:
        datasets: one of ``'coco'``, ``'flickr8k'``, ``'flickr30k'``.
        karpathy_json_path: path to Karpathy's split JSON.
        image_dir: folder containing the images (for COCO, the parent of each
            image's ``filepath`` sub-folder).
        output_dir: folder under which ``data/`` is created and written to.
        captions_per_image: number of captions sampled per image.
        min_word_freq: words occurring this many times or fewer map to ``<unk>``.
        max_length: captions with more tokens than this are discarded.
    """
    assert datasets in {'coco', 'flickr8k', 'flickr30k'}

    # read Karpathy's json
    with open(karpathy_json_path, 'r') as file:
        data = json.load(file)

    # image paths and their tokenized captions, grouped by output split
    split_paths = {'TRAIN': [], 'VALID': [], 'TEST': []}
    split_captions = {'TRAIN': [], 'VALID': [], 'TEST': []}
    # Karpathy split name -> output split ('restval' images are used for training);
    # images with any other split value are skipped
    split_names = {'train': 'TRAIN', 'restval': 'TRAIN', 'val': 'VALID', 'test': 'TEST'}

    word_freq = Counter()

    for image in tqdm_notebook(data['images']):
        captions = []
        for sentence in image['sentences']:
            # every sentence counts toward the vocabulary, even ones dropped for length
            word_freq.update(sentence['tokens'])
            if len(sentence['tokens']) <= max_length:
                captions.append(sentence['tokens'])

        if not captions:
            continue

        if datasets == 'coco':
            path = os.path.join(image_dir, image['filepath'], image['filename'])
        else:
            path = os.path.join(image_dir, image['filename'])

        split = split_names.get(image['split'])
        if split is not None:
            split_paths[split].append(path)
            split_captions[split].append(captions)

    # create vocabulary: ids start at 1 so that 0 is free for <pad>
    words = [word for word in word_freq.keys() if word_freq[word] > min_word_freq]
    word_vocab = {word: index for index, word in enumerate(words, start=1)}
    word_vocab['<unk>'] = len(word_vocab) + 1
    word_vocab['<start>'] = len(word_vocab) + 1
    word_vocab['<end>'] = len(word_vocab) + 1
    word_vocab['<pad>'] = 0

    # create a base/ root name for all output files
    base_filename = datasets + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq'

    data_dir = os.path.join(output_dir, 'data')
    os.makedirs(data_dir, exist_ok=True)

    # save word vocabulary to a JSON
    with open(os.path.join(data_dir, 'WORD_VOCAB_' + base_filename + '.json'), 'w') as file:
        json.dump(word_vocab, file)

    # sample captions for each image, save images to HDF5 file and captions and their lengths to JSON files
    random.seed(9)
    for split in ('TRAIN', 'VALID', 'TEST'):
        image_paths = split_paths[split]
        image_captions = split_captions[split]

        # 'w' (not 'a') so a re-run replaces the file instead of failing on create_dataset
        with h5py.File(os.path.join(data_dir, split + '_IMAGES_' + base_filename + '.hdf5'), 'w') as file:

            # make a note of the number of captions we are sampling per image
            file.attrs['captions_per_image'] = captions_per_image

            # must be named 'images': CaptionDataset reads hdf5['images']
            images = file.create_dataset('images', (len(image_paths), 3, 256, 256), dtype='uint8')

            print(f'\nReading {split} images and captions, storing to file...\n')

            encoded_captions = []
            captions_length = []

            for i, (path, candidates) in enumerate(zip(image_paths, image_captions)):
                captions = _sample_captions(candidates, captions_per_image)

                # sanity check
                assert len(captions) == captions_per_image

                # save image to HDF5 file
                images[i] = _load_image(path)

                for caption in captions:
                    encoded_captions.append(_encode_caption(caption, word_vocab, max_length))
                    # <start> + tokens + <end>, excluding padding
                    captions_length.append(len(caption) + 2)

            # sanity check
            assert images.shape[0] * captions_per_image == len(encoded_captions) == len(captions_length)

        # save encoded captions and their lengths to JSON files
        with open(os.path.join(data_dir, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as file:
            json.dump(encoded_captions, file)

        with open(os.path.join(data_dir, split + '_CAPLENS_' + base_filename + '.json'), 'w') as file:
            json.dump(captions_length, file)

In [None]:
# Preprocess MS-COCO: 5 sampled captions per image, words seen 5 times or
# fewer become <unk>, and captions longer than 50 tokens are dropped.
create_input_files(
    datasets='coco',
    karpathy_json_path='./datasets/karpathy_captions/datasets_coco.json',
    image_dir='./datasets/',
    output_dir='./datasets/',
    captions_per_image=5,
    min_word_freq=5,
    max_length=50,
)

## Data Loader

In [None]:
class CaptionDataset(torch.utils.data.Dataset):
    """PyTorch dataset over the files written by ``create_input_files``.

    Item ``i`` is the ``i``-th encoded caption, paired with image
    ``i // captions_per_image``. TRAIN items are ``(image, caption, caplen)``.
    VALID/TEST items additionally include ``all_captions``, every caption of
    that image, for BLEU-4 scoring.
    """

    def __init__(self, data_folder, data_name, split, transform=None):
        """
        Args:
            data_folder: folder containing the ``data/`` sub-folder of processed files.
            data_name: base filename shared by the processed files.
            split: one of ``'TRAIN'``, ``'VALID'``, ``'TEST'``.
            transform: optional callable applied to each image tensor.
        """
        super(CaptionDataset, self).__init__()

        self.split = split
        assert self.split in {'TRAIN', 'VALID', 'TEST'}

        # open hdf5 file where images are stored (kept open for the dataset's lifetime)
        self.hdf5 = h5py.File(self._data_path(data_folder, self.split, 'IMAGES', data_name, 'hdf5'), 'r')
        self.images = self.hdf5['images']

        # captions per image
        self.cpi = self.hdf5.attrs['captions_per_image']

        # load encoded captions (completely into memory)
        with open(self._data_path(data_folder, self.split, 'CAPTIONS', data_name, 'json'), 'r') as file:
            self.captions = json.load(file)

        # load captions lengths (completely into memory)
        with open(self._data_path(data_folder, self.split, 'CAPLENS', data_name, 'json'), 'r') as file:
            self.caplens = json.load(file)

        # pytorch transformation pipeline for the image (normalizing, etc.)
        self.transform = transform

        # total number of data points
        self.dataset_size = len(self.captions)

    @staticmethod
    def _data_path(data_folder, split, kind, data_name, extension):
        """Return ``<data_folder>/data/<split>_<kind>_<data_name>.<extension>``."""
        # every component is relative, so os.path.join keeps data_folder as the root
        return os.path.join(data_folder, 'data', split + '_' + kind + '_' + data_name + '.' + extension)

    def __getitem__(self, i):
        # remember, the Nth caption corresponds to the (N // captions_per_image)th image
        image_index = i // self.cpi
        # uint8 pixels scaled to [0, 1] floats
        image = torch.tensor(self.images[image_index] / 255., dtype=torch.float)
        if self.transform is not None:
            image = self.transform(image)

        caption = torch.LongTensor(self.captions[i])
        caplen = torch.LongTensor([self.caplens[i]])

        if self.split == 'TRAIN':
            return image, caption, caplen

        # for validation or testing, also return all 'captions_per_image' captions to find BLEU-4 score
        first = image_index * self.cpi
        all_captions = torch.LongTensor(self.captions[first:first + self.cpi])
        return image, caption, caplen, all_captions

    def __len__(self):
        return self.dataset_size

## Build [Image Captioning](https://arxiv.org/pdf/1411.4555.pdf) Network with [Attention](https://arxiv.org/pdf/1502.03044.pdf)

In [None]:
class EncoderCNN(nn.Module):
    """Encoder network of the captioning model (stub: no layers or forward pass defined yet)."""

    def __init__(self):
        super().__init__()

In [None]:
class Attention(nn.Module):
    """Attention module of the captioning model (stub: no layers or forward pass defined yet)."""

    def __init__(self):
        super().__init__()

In [None]:
class AttentionDecoderRNN(nn.Module):
    """Attention-based RNN decoder of the captioning model (stub: no layers or forward pass defined yet)."""

    def __init__(self):
        super().__init__()

---