# Download datasets

In [None]:
!wget https://bashupload.com/D7AwJ/train2014.zip
!unzip train2014.zip
!rm train2014.zip
!wget https://bashupload.com/BFu0A/val2014.zip
!unzip val2014.zip
!rm val2014.zip
!wget http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip
!unzip caption_datasets.zip
!rm caption_datasets.zip
!rm dataset_flickr30k.json
!rm dataset_flickr8k.json

In [None]:
!ls train2014

# Import

In [None]:
import os
import numpy as np
import h5py
import json
import torch
from PIL import Image
from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample
import cv2

# Data preprocessing

In [None]:
# Datasets:
#   train2014 is a folder of image files for training
#   val2014 is a folder of image files for validation
#   dataset_coco.json is a JSON file that tells you {image -> captions}

## Data loading


In [None]:
# Read the caption annotation file from disk and parse it into a dict
json_path = 'dataset_coco.json'
with open(json_path) as annotations_fh:
    data = json.loads(annotations_fh.read())
print(data['images'][0])

# Fields of an image record:
#   'filename'  -> name of the image file
#   'filepath'  -> name of the folder holding the image
#   'imgid'     -> id of the image
#   'sentences' -> list of human-written captions
#   'tokens'    -> list of words of one caption (found inside each sentence)

In [None]:
# Display the first image record as rich cell output
data['images'][0]

In [None]:
# Display the caption entries attached to the first image record
data['images'][0]["sentences"]

In [None]:
# Check which Python type the 'filename' field is parsed into
type(data['images'][0]['filename'])

In [None]:
# Number of captions kept per image. Images can come with a varying number
# of captions; using the same count for every image avoids over-representing
# images that happen to have more captions.
captions_per_image = 5

# Longest caption (in words) that is kept.
#   longer captions  -> dropped
#   shorter captions -> padded with <pad> when they are encoded later
max_len = 50

# Per-split lists of [image_path, captions], where captions is a list of word lists
train_img_cap_pairs = []
val_img_cap_pairs = []

# Every distinct word seen in any caption of the annotation file
word_set = set()

# Route each record to the list of its split; records of any other split are ignored
pairs_by_split = {'train': train_img_cap_pairs, 'val': val_img_cap_pairs}

for img_obj in data['images']:
    token_lists = [sentence['tokens'] for sentence in img_obj['sentences']]

    # The vocabulary is collected from all captions, including over-long ones
    for tokens in token_lists:
        word_set.update(tokens)

    captions = [tokens for tokens in token_lists if len(tokens) <= max_len]

    # No usable caption left for this image -> skip it
    if not captions:
        continue

    img_path = img_obj['filepath'] + '/' + img_obj['filename']

    # Image file is not on disk -> skip it
    if not os.path.exists(img_path):
        continue

    target_pairs = pairs_by_split.get(img_obj['split'])
    if target_pairs is not None:
        target_pairs.append([img_path, captions])

In [None]:
# Show the size and a couple of entries instead of dumping the whole list
print(len(train_img_cap_pairs))
train_img_cap_pairs[:2]

## Data transformation

In [None]:
# HDF5: HDF5 is a unique technology suite that makes possible the management
# of extremely large and complex data collections.

# 1. We will create 2 hdf5 files: 
#      train_images.hdf5, val_images.hdf5
# 2. We will create 5 json files: 
#      word_map.json -- contains a (word -> number) hash object
#      train_captions.json -- contains a list of encoded training captions
#      val_captions.json -- contains a list of encoded validation captions
#      train_caption_length.json -- contains a list of training caption lengths
#      val_caption_length.json -- contains a list of validation caption lengths

In [None]:
# Word Encoding
# word_map: word    -> number (starting from 1)
#           <pad>   -> 0
#           <start> -> the second highest number
#           <end>   -> the highest number
#
# The vocabulary is sorted before numbering: a set of strings has no stable
# iteration order between interpreter runs, so numbering it directly would
# give a different word -> number mapping every time the kernel restarts.
word_map = {word: index for index, word in enumerate(sorted(word_set), start=1)}
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0

# Save word map to a JSON
with open('word_map.json', 'w') as j:
  json.dump(word_map, j)

In [None]:
# For each split, write the resized images to <split>_images.hdf5 and the
# encoded captions / caption lengths to JSON files.

# Seed the RNG so the caption sampling below is repeatable across runs
seed(123)

for img_cap_pairs, split in [[train_img_cap_pairs,'train'], [val_img_cap_pairs, 'val']]:
    h5py_path = split + '_images.hdf5'

    # Mode 'w' creates the file and truncates any copy left by an earlier run
    with h5py.File(h5py_path, 'w') as h:
        # Make a note of the number of captions we are sampling per image
        h.attrs['captions_per_image'] = captions_per_image

        # Create dataset inside HDF5 file to store images
        # we do channel first for the image
        images = h.create_dataset('images', (len(img_cap_pairs), 3, 256, 256), dtype='uint8')

        enc_captions = []
        caplens = []
        for index, (img_path, captions) in enumerate(tqdm(img_cap_pairs, desc=split)):
            if len(captions) < captions_per_image:
                # Top up by drawing with replacement. sample() cannot be used
                # here: it raises ValueError when asked for more items than
                # the list holds (e.g. 2 captions available, 3 missing).
                missing = captions_per_image - len(captions)
                captions = captions + [choice(captions) for _ in range(missing)]
            else:
                # randomly sample k from captions
                captions = sample(captions, captions_per_image)

            # Sanity check
            assert len(captions) == captions_per_image

            # Read image and transform it into (3, 256, 256).
            # cv2.imread signals an unreadable file by returning None rather
            # than raising, so fail here with the offending path.
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError('Could not read image: ' + img_path)
            # NOTE(review): cv2 loads channels in BGR order -- confirm whether
            # the code consuming this HDF5 file expects BGR or RGB.
            img = cv2.resize(img, (256, 256))
            img = img.transpose(2, 0, 1)

            assert img.shape == (3, 256, 256)

            # Save image to HDF5 file
            images[index] = img

            for caption in captions:
                # Encode captions
                #   a list of numbers
                #   Format is <start> word1 word2 ... wordN <end> <pad> <pad>...
                #   The total length is max_len + 2 (the words padded to
                #   max_len, plus the <start> and <end> markers)
                enc_c = [word_map['<start>']] + [word_map[word] for word in caption] + \
                        [word_map['<end>']] + [word_map['<pad>']]*(max_len - len(caption))
                enc_captions.append(enc_c)
                # Length counts the words plus <start> and <end>, not the padding
                caplens.append(len(caption) + 2)

        # Every stored image must have exactly captions_per_image captions
        assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens)

    # Save encoded captions and their lengths to JSON files
    with open(split + '_captions.json', 'w') as j:
        json.dump(enc_captions, j)

    with open(split + '_caption_length.json', 'w') as j:
        json.dump(caplens, j)

# Sanity check on the last caption processed (skipped if the last split was empty)
if caplens:
    print('caption length:', caplens[-1])
    print('caption:', caption)
    print('caption encoding:', enc_c)


caption length: 12
caption: ['a', 'horse', 'stands', 'on', 'grass', 'and', 'looks', 'at', 'the', 'camera']
caption encoding: [27930, 25913, 12943, 22226, 8382, 16078, 23802, 15480, 11368, 1448, 19596, 27931, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
