<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fgeneration/applications/generation/image_captioning/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kaggle



In [None]:
from google.colab import files
# Upload kaggle.json. The result is bound to a name so the cell does not echo the returned
# {filename: bytes} dict, which would print the API key into the notebook output.
uploaded = files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"ravirajag","key":"<REDACTED>"}'}

In [None]:
!mkdir ~/.kaggle

mkdir: cannot create directory ‘/root/.kaggle’: File exists


In [None]:
# copy the uploaded credential file into ~/.kaggle
!cp kaggle.json ~/.kaggle

In [None]:
# restrict the credential file to owner read/write only (mode 600)
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# download the Flickr8k dataset archive (flickr8k.zip) into the current directory
!kaggle datasets download -d adityajn105/flickr8k

Downloading flickr8k.zip to /content
 99% 1.02G/1.04G [00:06<00:00, 220MB/s]
100% 1.04G/1.04G [00:06<00:00, 168MB/s]


In [None]:
# check that flickr8k.zip is present
!ls

flickr8k.zip  kaggle.json  sample_data


In [None]:
!unzip flickr8k.zip

Streaming output truncated to the last 5000 lines.
  inflating: Images/2844963839_ff09cdb81f.jpg  
  inflating: Images/2845246160_d0d1bbd6f0.jpg  
  inflating: Images/2845691057_d4ab89d889.jpg  
  inflating: Images/2845845721_d0bc113ff7.jpg  
  inflating: Images/2846037553_1a1de50709.jpg  
  inflating: Images/2846785268_904c5fcf9f.jpg  
  inflating: Images/2846843520_b0e6211478.jpg  
  inflating: Images/2847514745_9a35493023.jpg  
  inflating: Images/2847615962_c330bded6e.jpg  
  inflating: Images/2847859796_4d9cb0d31f.jpg  
  inflating: Images/2848266893_9693c66275.jpg  
  inflating: Images/2848571082_26454cb981.jpg  
  inflating: Images/2848895544_6d06210e9d.jpg  
  inflating: Images/2848977044_446a31d86e.jpg  
  inflating: Images/2849194983_2968c72832.jpg  
  inflating: Images/2850719435_221f15e951.jpg  
  inflating: Images/2851198725_37b6027625.jpg  
  inflating: Images/2851304910_b5721199bc.jpg  
  inflating: Images/2851931813_eaf8ed7be3.jpg  
  inflating: Images/285

In [None]:
# inspect the extracted contents (captions.txt and the Images/ folder)
!ls -lah

total 1.1G
drwxr-xr-x 1 root root 4.0K Jun 29 03:20 .
drwxr-xr-x 1 root root 4.0K Jun 29 03:06 ..
-rw-r--r-- 1 root root 3.2M Apr 27 07:29 captions.txt
drwxr-xr-x 1 root root 4.0K Jun 25 17:02 .config
-rw-r--r-- 1 root root 1.1G Jun 29 03:19 flickr8k.zip
drwxr-xr-x 2 root root 420K Jun 29 03:20 Images
-rw-r--r-- 1 root root   65 Jun 29 03:19 kaggle.json
drwxr-xr-x 1 root root 4.0K Jun 17 16:18 sample_data


### Imports

In [None]:
import os
import h5py
import json
import numpy as np

import torch
import torch.nn as nn
import torchvision

from torch.utils.data import Dataset

from scipy.misc import imread, imresize

from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
def create_input_files(dataset, karapathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder, max_len=100):
    """Preprocess a Karpathy-split captioning dataset into HDF5 images plus JSON caption files.

    Reads a Karpathy split JSON whose ``images`` entries carry ``filename``, ``split`` and
    ``sentences`` (each with a ``tokens`` list). It builds a word map from token frequencies.
    For each of the TRAIN / VAL / TEST splits it writes:

    * ``<split>_IMAGES_<base>.hdf5``: an ``images`` dataset of shape ``(n_images, 3, 256, 256)``
      (uint8), plus a ``captions_per_image`` attribute;
    * ``<split>_CAPTIONS_<base>.json``: ``captions_per_image`` encoded captions per image, each
      ``<start>`` + tokens + ``<end>`` right-padded with ``<pad>`` to length ``max_len + 2``;
    * ``<split>_CAPLENS_<base>.json``: each caption's length, counting ``<start>`` and ``<end>``.

    It also writes ``WORDMAP_<base>.json``, where ``<base>`` is
    ``{dataset}_{captions_per_image}_cap_per_img_{min_word_freq}_min_word_freq``.

    Args:
        dataset: dataset name; used only to build the output file names.
        karapathy_json_path: path to the Karpathy split JSON file.
        image_folder: folder containing the image files named in the JSON.
        captions_per_image: exact number of captions stored per image. Images with more captions
            are subsampled; images with fewer are topped up with randomly repeated captions.
        min_word_freq: words occurring ``min_word_freq`` times or fewer are mapped to ``<unk>``.
        output_folder: folder the output files are written to (created if missing).
        max_len: captions with more than ``max_len`` tokens are dropped.
    """
    # read the Karpathy split json file
    with open(karapathy_json_path, 'r') as f:
        data = json.load(f)

    # Read image paths and captions for each image
    train_image_paths = []
    train_image_captions = []
    val_image_paths = []
    val_image_captions = []
    test_image_paths = []
    test_image_captions = []
    word_freq = Counter()

    for img in data['images']:
        # captions of *this* image that are short enough to keep (must be reset per image)
        captions = []
        for c in img['sentences']:
            # every caption counts towards the vocabulary, including ones dropped for length
            word_freq.update(c['tokens'])
            if len(c['tokens']) <= max_len:
                captions.append(c['tokens'])

        # skip images left without any usable caption
        if len(captions) == 0:
            continue

        path = os.path.join(image_folder, img['filename'])

        # 'restval' images are folded into the training split
        if img['split'] in {'train', 'restval'}:
            train_image_paths.append(path)
            train_image_captions.append(captions)
        elif img['split'] in {'val'}:
            val_image_paths.append(path)
            val_image_captions.append(captions)
        elif img['split'] in {'test'}:
            test_image_paths.append(path)
            test_image_captions.append(captions)

    assert len(train_image_paths) == len(train_image_captions)
    assert len(val_image_paths) == len(val_image_captions)
    assert len(test_image_paths) == len(test_image_captions)

    # create word map: frequent words get ids 1..N, then <unk>, <start>, <end>; <pad> is 0
    words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
    word_map = {k: v + 1 for v, k in enumerate(words)}
    word_map['<unk>'] = len(word_map) + 1
    word_map['<start>'] = len(word_map) + 1
    word_map['<end>'] = len(word_map) + 1
    word_map['<pad>'] = 0

    # Create a base/root name for all output files
    base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq'

    os.makedirs(output_folder, exist_ok=True)

    # Save word map to a JSON
    with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w') as j:
        json.dump(word_map, j)

    # fixed seed so the caption sampling below is reproducible across runs
    seed(123)

    # Sample captions for each image, save images to HDF5 file, and captions and their lengths to JSON files
    for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'),
                                   (val_image_paths, val_image_captions, 'VAL'),
                                   (test_image_paths, test_image_captions, 'TEST')]:
        # 'w' (not 'a'): re-running overwrites the file instead of failing because the
        # 'images' dataset already exists
        with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'w') as h:
            h.attrs['captions_per_image'] = captions_per_image

            # one channels-first 3 x 256 x 256 uint8 slot per image of this split
            images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8')

            print(f"\nReading {split} images and captions, storing to file...\n")

            enc_captions = []
            caplens = []

            for i, path in enumerate(tqdm(impaths)):

                # sample exactly captions_per_image captions: top up with random repeats, or subsample
                if len(imcaps[i]) < captions_per_image:
                    captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))]
                else:
                    captions = sample(imcaps[i], k=captions_per_image)

                assert len(captions) == captions_per_image

                # read images
                # NOTE(review): scipy.misc.imread / imresize were removed from SciPy (1.2 / 1.3);
                # this requires an older SciPy or a replacement image library.
                img = imread(path)
                if len(img.shape) == 2:
                    # grayscale -> three identical channels
                    img = img[:, :, np.newaxis]
                    img = np.concatenate([img, img, img], axis=2)
                # drop an alpha channel if the file has one
                img = img[:, :, :3]

                img = imresize(img, (256, 256))

                # transpose the image to make channels the first axis
                img = img.transpose(2, 0, 1)
                assert img.shape == (3, 256, 256)
                assert np.max(img) <= 255

                # save image to HDF5 file
                images[i] = img

                for c in captions:
                    # encode caption: <start> tokens <end>, right-padded with <pad> to max_len + 2
                    enc_c = ([word_map['<start>']]
                             + [word_map.get(word, word_map['<unk>']) for word in c]
                             + [word_map['<end>']]
                             + [word_map['<pad>']] * (max_len - len(c)))

                    # caption length counts <start> and <end>
                    c_len = len(c) + 2
                    enc_captions.append(enc_c)
                    caplens.append(c_len)

            assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens)

            # Save encoded captions and their lengths to JSON files
            with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as j:
                json.dump(enc_captions, j)

            with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'w') as j:
                json.dump(caplens, j)


In [None]:
class CaptionDataset(Dataset):
    """PyTorch Dataset over the preprocessed ``<split>_IMAGES/CAPTIONS/CAPLENS_<data_name>`` files.

    There is one item per *caption*: item ``k`` pairs caption ``k`` with image
    ``k // captions_per_image``.

    Args:
        data_folder: folder holding the preprocessed files.
        data_name: base name shared by the files (the ``<data_name>`` part of the file names).
        split: one of ``"TRAIN"``, ``"VAL"``, ``"TEST"``.
        transform: optional callable applied to each image tensor.
    """
    def __init__(self, data_folder, data_name, split, transform=None):
        super().__init__()

        self.split = split
        assert split in {"TRAIN", "VAL", "TEST"}

        # Keep the HDF5 file open for the lifetime of the dataset. An h5py dataset handle
        # cannot be read once its file is closed, so the file must not be opened in a
        # `with` block here.
        self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5'), 'r')
        # images
        self.imgs = self.h['images']
        # captions per image
        self.cpi = self.h.attrs['captions_per_image']

        # Load encoded captions (completely into memory)
        with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + data_name + '.json'), 'r') as j:
            self.captions = json.load(j)

        # Load caption lengths (completely into memory)
        with open(os.path.join(data_folder, self.split + '_CAPLENS_' + data_name + '.json'), 'r') as j:
            self.caplens = json.load(j)

        self.transform = transform

        self.dataset_size = len(self.captions)

    def __getitem__(self, item):
        """Return ``(img, caption, caplen)``; VAL/TEST additionally return all captions of the image.

        ``img`` is the image divided by 255 as a float tensor. ``caption`` is a LongTensor of token
        ids, and ``caplen`` is a 1-element LongTensor.
        """
        img = torch.FloatTensor(self.imgs[item // self.cpi] / 255.)

        if self.transform is not None:
            img = self.transform(img)

        caption = torch.LongTensor(self.captions[item])
        caplen = torch.LongTensor([self.caplens[item]])

        # compare by value: `is` tests object identity and is unreliable for strings
        if self.split == "TRAIN":
            return img, caption, caplen

        # For validation / testing, also return all captions of the image (for the BLEU score)
        start = (item // self.cpi) * self.cpi
        all_captions = torch.LongTensor(self.captions[start:start + self.cpi])
        return img, caption, caplen, all_captions

    def __len__(self):
        """Number of captions (not images) in the split."""
        return self.dataset_size

In [None]:
class Encoder(nn.Module):
    """CNN image encoder built on an ImageNet-pretrained ResNet-101.

    The ResNet's final pooling and linear layers are dropped. The remaining feature map is
    pooled to a fixed ``encoded_image_size x encoded_image_size`` grid, so inputs of varying
    spatial size give an output of constant size.

    Args:
        encoded_image_size: side length of the output feature grid.
    """

    def __init__(self, encoded_image_size=14):
        super().__init__()

        self.enc_image_size = encoded_image_size

        backbone = torchvision.models.resnet101(pretrained=True)
        # keep every child except the last two (pool + linear): we want spatial features, not class scores
        feature_layers = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*feature_layers)

        grid = (encoded_image_size, encoded_image_size)
        self.adaptive_pool = nn.AdaptiveAvgPool2d(grid)

        self.fine_tune()

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: [batch_size, 3, image_size, image_size]

        Returns:
            Channels-last feature map [batch_size, encoded_image_size, encoded_image_size, 2048].
        """
        feature_map = self.resnet(images)         # [batch_size, 2048, image_size / 32, image_size / 32]
        pooled = self.adaptive_pool(feature_map)  # [batch_size, 2048, encoded_image_size, encoded_image_size]
        return pooled.permute(0, 2, 3, 1)

    def fine_tune(self, fine_tune=True):
        """Freeze the backbone, except children from index 5 onward, which get ``requires_grad=fine_tune``.

        NOTE(review): for torchvision's ResNet child order this presumably corresponds to the
        later residual stages; confirm against the torchvision version in use.
        """
        for idx, child in enumerate(self.resnet.children()):
            trainable = fine_tune if idx >= 5 else False
            for param in child.parameters():
                param.requires_grad = trainable

In [None]:
class Attention(nn.Module):
    """Additive attention over encoder pixels, conditioned on the decoder hidden state.

    For each pixel p: score_p = full_attn(ReLU(encoder_attn(enc_p) + decoder_attn(h))).
    The weights alpha are a softmax over pixels, and the context is the alpha-weighted sum
    of the encoder pixels.

    Args:
        encoder_dim: feature size of each encoder pixel.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared projection space.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()

        # project both inputs into the attention space, then score each pixel with a single unit
        self.encoder_attn = nn.Linear(encoder_dim, attention_dim)
        self.decoder_attn = nn.Linear(decoder_dim, attention_dim)
        self.full_attn = nn.Linear(attention_dim, 1)

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalise across the pixel axis

    def forward(self, encoder_out, decoder_hidden):
        """Compute the attention context and weights.

        Args:
            encoder_out: [batch_size, num_pixels, encoder_dim]
            decoder_hidden: [batch_size, decoder_dim]

        Returns:
            (context [batch_size, encoder_dim], alpha [batch_size, num_pixels])
        """
        pixel_proj = self.encoder_attn(encoder_out)                 # [batch_size, num_pixels, attn_dim]
        state_proj = self.decoder_attn(decoder_hidden)[:, None, :]  # [batch_size, 1, attn_dim], broadcast over pixels
        scores = self.full_attn(self.relu(pixel_proj + state_proj)).squeeze(-1)  # [batch_size, num_pixels]
        alpha = self.softmax(scores)
        context = torch.sum(encoder_out * alpha[:, :, None], dim=1)
        return context, alpha

In [None]:
class Decoder(nn.Module):
    """LSTM caption decoder with gated soft attention over encoder pixels.

    At each step the LSTM cell receives the current word embedding concatenated with an
    attention context. The context is gated by ``sigmoid(f_beta(h))``, and the output layer
    maps the (dropped-out) hidden state to vocabulary scores.

    Args:
        attention_dim: size of the attention projection space.
        emb_dim: word embedding size.
        decoder_dim: LSTM hidden size.
        vocab_size: number of output tokens.
        encoder_dim: feature size of each encoder pixel.
        dropout: dropout probability applied before the output layer.
    """
    def __init__(self, attention_dim, emb_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = emb_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        self.embedding = nn.Embedding(vocab_size, emb_dim)
        # each step consumes [word embedding ; gated attention context]
        self.decode_step = nn.LSTMCell(emb_dim + encoder_dim, decoder_dim, bias=True)
        # initial h / c are projected from the mean encoder feature
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # sigmoid gate on the attention context, computed from the hidden state
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)

        self.init_weights()

    def init_weights(self):
        """Uniform(-0.1, 0.1) init of the embedding and output weights; zero output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        # in-place `uniform_`; tensors have no non-underscore `uniform` method
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embedding(self, embeddings):
        """Replace the embedding matrix with ``embeddings`` (expected shape [vocab_size, emb_dim])."""
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Enable or disable gradient updates of the embedding layer."""
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """Initial LSTM state from the mean over pixels of ``encoder_out``.

        Args:
            encoder_out: [batch_size, num_pixels, encoder_dim]

        Returns:
            (h, c), each [batch_size, decoder_dim]
        """
        mean_encoder_out = torch.mean(encoder_out, dim=1)
        # mean_encoder_out => [batch_size, encoder_dim]

        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        # h, c => [batch_size, decoder_dim]

        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Teacher-forced decoding of a batch of captions.

        Args:
            encoder_out: [batch_size, enc_image_size, enc_image_size, encoder_dim]
            encoded_captions: [batch_size, max_caption_len] token ids
            caption_lengths: [batch_size, 1] caption lengths

        Returns:
            predictions: [batch_size, max(decode_lengths), vocab_size] scores. Positions past a
                caption's decode length are left at 0.
            encoded_captions: the captions re-ordered by decreasing length.
            decode_lengths: list of per-caption decode lengths (length - 1), in sorted order.
            alphas: [batch_size, max(decode_lengths), num_pixels] attention weights.
            sort_ind: indices that sorted the batch (positions in the original order).
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten the spatial grid into a sequence of pixels
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        # encoder_out => [batch_size, num_pixels, encoder_dim]
        num_pixels = encoder_out.size(1)

        # Sort by decreasing length so the captions still being decoded at step t form a prefix of the batch
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)
        # embeddings => [batch_size, max_caption_len, emb_dim]

        # initialize lstm states
        h, c = self.init_hidden_state(encoder_out)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_len = max(decode_lengths)

        # Allocate on the inputs' device instead of relying on a global `device`
        predictions = torch.zeros(batch_size, max_decode_len, vocab_size, device=encoder_out.device)
        alphas = torch.zeros(batch_size, max_decode_len, num_pixels, device=encoder_out.device)

        for t in range(max_decode_len):
            # number of captions still being decoded at step t (a prefix, thanks to the sort)
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])

            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            # gate => [batch_size_t, encoder_dim]
            attention_weighted_encoding = gate * attention_weighted_encoding

            # step-t input: embedding of the t-th token (teacher forcing) + gated context
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t])
            )
            # h, c => [batch_size_t, decoder_dim]

            preds = self.fc(self.dropout(h))
            # preds => [batch_size_t, vocab_size]

            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [None]:
def train():
    """Placeholder for the training loop — not implemented yet (TODO)."""
    return None

In [None]:
def validate():
    """Placeholder for the validation loop — not implemented yet (TODO)."""
    return None