In [None]:
from collections import Counter
import numpy as np
import hickle
import json
import os
import pandas as pd
import pickle as pickle
from utils import load_coco_data

In [None]:
def load_pickle(path):
    """Read a pickled object from disk.

    Args:
        path: Location of the pickle file; it is opened in binary mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        Unpickling can execute arbitrary code, so only pass files that this
        project wrote itself.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
        print('Loaded %s..' %path)
    return loaded

def save_pickle(data, path):
    """Write ``data`` to ``path`` as a pickle.

    Args:
        data: Any picklable object.
        path: Destination file; it is opened in binary write mode, so an
            existing file at that location is overwritten.

    The object is serialized with ``pickle.HIGHEST_PROTOCOL`` and a short
    confirmation line is printed once the dump has completed.
    """
    with open(path, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print('Saved %s..' %path)

In [None]:
def _normalize_caption(caption):
    """Strip punctuation, collapse whitespace and lower-case one raw caption."""
    for char in ('.', ',', "'", '"', '(', ')'):
        caption = caption.replace(char, '')
    caption = caption.replace('&', 'and').replace('-', ' ')
    # str.split() with no argument also trims and collapses runs of whitespace.
    return " ".join(caption.split()).lower()


def process_caption_data(caption_file, image_dir, max_length):
    """Load a caption JSON file and return a cleaned caption table.

    The JSON document must provide an ``images`` list (entries with ``id`` and
    ``file_name``) and an ``annotations`` list (entries with ``id``,
    ``image_id`` and ``caption``).

    Processing steps:
      1. Attach ``file_name`` (``image_dir`` joined with the image's file name)
         to every annotation and drop the annotation ``id`` column.
      2. Sort by ``image_id``. A stable sort is used so that captions of the
         same image keep their original relative order on every run.
      3. Normalize each caption (punctuation removed, ``&`` -> ``and``,
         ``-`` -> space, whitespace collapsed, lower-cased).
      4. Drop captions with more than ``max_length`` words.
      5. Drop every image that still has more than 5 captions; images with
         5 or fewer captions are kept.
      6. Wrap each caption as ``'<start> ... <end>'``.

    Args:
        caption_file: Path to the caption JSON file.
        image_dir: Directory prefix joined onto each image file name.
        max_length: Maximum number of space-separated words per caption.

    Returns:
        pandas.DataFrame with the annotation fields minus ``id`` plus
        ``file_name``, indexed 0..n-1.

    Raises:
        KeyError: If an annotation references an ``image_id`` that is missing
            from ``images``, or a required key is absent.
    """
    with open(caption_file) as f:
        raw = json.load(f)

    id_to_filename = {image['id']: image['file_name'] for image in raw['images']}

    # Copy each annotation instead of mutating the parsed JSON in place.
    records = [
        dict(annotation,
             file_name=os.path.join(image_dir, id_to_filename[annotation['image_id']]))
        for annotation in raw['annotations']
    ]

    caption_data = pd.DataFrame(records)
    del caption_data['id']
    caption_data = caption_data.sort_values(by='image_id', kind='mergesort')
    caption_data = caption_data.reset_index(drop=True)

    # Clean all captions in one pass and assign the column once; this replaces
    # the per-row DataFrame.set_value call, which no longer exists in pandas >= 1.0.
    cleaned = [_normalize_caption(caption) for caption in caption_data['caption']]
    caption_data['caption'] = cleaned
    within_limit = pd.Series(
        [len(caption.split(" ")) <= max_length for caption in cleaned],
        index=caption_data.index,
        dtype=bool,
    )

    print("The number of captions before deletion: {}".format(len(caption_data)))
    caption_data = caption_data[within_limit].reset_index(drop=True)
    print("The number of captions after deletion: {}".format(len(caption_data)))

    # Keep only images that have at most 5 captions left after length filtering.
    captions_per_image = caption_data.groupby('image_id')['caption'].transform('count')
    caption_data = caption_data[captions_per_image <= 5].reset_index(drop=True)
    print("The number of captions after deletion: {}".format(len(caption_data)))

    caption_data['caption'] = '<start> ' + caption_data['caption'] + ' <end>'

    return caption_data

In [None]:
def build_vocab(annotations, threshold=1):
    """Create a word -> integer index mapping from a set of captions.

    Args:
        annotations: Mapping or DataFrame whose ``'caption'`` entry iterates
            over caption strings. Each caption is tokenized by splitting on a
            single space.
        threshold: Minimum number of occurrences a word needs in order to be
            included in the vocabulary.

    Returns:
        dict mapping each word to its index. Indices 0, 1 and 2 are reserved
        for ``<NULL>``, ``<START>`` and ``<END>``; the remaining words are
        numbered from 3 in the order they were first encountered.

    NOTE(review): tokens are taken verbatim, so lower-case ``<start>`` /
    ``<end>`` markers embedded in the captions are counted as ordinary words
    and receive their own indices, separate from ``<START>`` / ``<END>``.
    Confirm this is what the consumer of ``word_to_idx`` expects.
    """
    word_counts = Counter()
    longest = 0
    for sentence in annotations['caption']:
        tokens = sentence.split(' ')  # captions contain only lower-case words
        word_counts.update(tokens)
        longest = max(longest, len(tokens))

    vocab = [token for token, count in word_counts.items() if count >= threshold]
    print ('Filtered %d words to %d words with word count threshold %d.' % (len(word_counts), len(vocab), threshold))

    word_to_idx = {u'<NULL>': 0, u'<START>': 1, u'<END>': 2}
    for offset, token in enumerate(vocab):
        word_to_idx[token] = offset + 3
    print("Max length of caption: ", longest)
    return word_to_idx

In [None]:
# Word limit per caption; process_caption_data drops captions longer than this.
max_length = 50

In [None]:
# Build the cleaned caption tables for the train2014 and val2014 caption files.
caption_file = 'dataset/annotations/captions_train2014.json'
image_dir = 'dataset/train2014/'

train_dataset = process_caption_data(
    caption_file=caption_file,
    image_dir=image_dir,
    max_length=max_length,
)
val_dataset = process_caption_data(
    'dataset/annotations/captions_val2014.json',
    'dataset/val2014',
    max_length,
)

In [None]:
# Row boundaries used to carve the validation and test subsets out of
# val_dataset. They were previously undefined, so this cell raised NameError
# on a fresh kernel.
# NOTE(review): the first 10% of rows become "val" and the next 10% become
# "test"; these fractions are an assumption - confirm the intended split sizes.
val_cutoff = int(0.1 * len(val_dataset))
test_cutoff = int(0.2 * len(val_dataset))

save_pickle(train_dataset, 'data/train/train.annotations_3.pkl')
save_pickle(val_dataset[:val_cutoff], 'data/val/val.annotations_3.pkl')
save_pickle(val_dataset[val_cutoff:test_cutoff], 'data/test/test.annotations_3.pkl')

In [None]:
# Build the vocabulary from the training annotations written by this notebook
# ('train.annotations_3.pkl'). The previous path, 'train.annotations.pkl', is
# not produced anywhere in this notebook, so a fresh run would either fail or
# silently use a stale file.
# NOTE(review): confirm that no other consumer relies on the un-suffixed file.
annotations = load_pickle('data/train/train.annotations_3.pkl')

word_to_idx = build_vocab(annotations=annotations, threshold=1)
save_pickle(word_to_idx, 'data/train/word_to_idx.pkl')