In [24]:
import sys
import pickle
from collections import Counter, OrderedDict
import nltk
import numpy as np
from tqdm import tqdm
from keras.models import load_model
from nltk.tokenize import word_tokenize

In [10]:
# Download the NLTK 'punkt' tokenizer models if they are not already present
# (presumably needed by word_tokenize, which is used further down — confirm in the NLTK docs).
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/qhduan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
# File name of the saved Keras model; it is handed to load_model() to restore the network.
keras_model = 'keras_image_caption_model.dat'

In [4]:
# Restore the saved model. predict_weights() reads `model` as a module-level global,
# so this cell has to run before any caption is generated.
model = load_model(keras_model)

In [13]:
# NOTE: pickle.load can execute arbitrary code while unpickling; only load files
# that you created yourself or otherwise trust.

# Read image file_name and their captions.
# The `with` block guarantees the file handle is closed as soon as loading finishes.
with open("file_name_caption.bat", "rb") as caption_file:
    file_name_caption = pickle.load(caption_file)

# Read image file_name and their VGG 16 weights
with open("file_name_images.bat", "rb") as images_file:
    file_name_images = pickle.load(images_file)

# Number of entries in the caption mapping.
train_size = len(file_name_caption)
print('train_size', train_size)

# Special tokens; they occupy vocabulary indices 0-3 in exactly this order.
START = '<start>'
END = '<end>'
UNK = '<unk>'
PAD = '<pad>'

# A word must occur at least this many times to be kept in the vocabulary.
min_count = 3

# Single pass over every caption: accumulate word frequencies, the total
# number of tokens, and the token length of the longest caption.
vocabulary = Counter()
train_words_size = 0
max_len = 0
for caption in tqdm(file_name_caption.values(), file=sys.stdout, total=len(file_name_caption)):
    sent = word_tokenize(caption)
    vocabulary.update(sent)
    train_words_size += len(sent)
    max_len = max(max_len, len(sent))

# Discard rare words and keep the survivors as a sorted list
# (from here on `vocabulary` is a list of words, no longer a Counter).
vocabulary = sorted({token for token, count in vocabulary.items() if count >= min_count})

# Two-way lookup tables: the special tokens come first, then the sorted words.
index_word = OrderedDict()
word_index = OrderedDict()
for index, word in enumerate([START, END, UNK, PAD] + vocabulary):
    index_word[index] = word
    word_index[word] = index
vocabulary_size = len(word_index)

print('vocabulary_size', vocabulary_size)
print('max_len', max_len)
print('train_words_size', train_words_size)

train_size 123287
100%|██████████| 123287/123287 [00:12<00:00, 10198.75it/s]
vocabulary_size 8075
max_len 55
train_words_size 1393524


In [14]:
# Materialize the keys of file_name_images as a list so they can be picked by position.
file_name_list = [*file_name_images.keys()]

In [18]:
def sent_to_index(input_sent, word_index, max_len):
    """Convert a list of tokens into a list of vocabulary indices, right-padded with PAD.

    Args:
        input_sent: list of token strings. The list is not modified.
        word_index: mapping from token to integer index. It must contain PAD
            whenever padding is added and UNK whenever an unknown token occurs.
        max_len: target length; PAD tokens are appended until it is reached.

    Returns:
        A new list of indices. Tokens missing from ``word_index`` are mapped to
        the index of UNK. A sentence already longer than ``max_len`` is neither
        padded nor truncated, so the result can be longer than ``max_len``.
    """
    missing = max_len - len(input_sent)
    # A non-positive multiplier yields an empty list, i.e. no padding is added.
    padded_tokens = input_sent + [PAD] * missing
    # UNK is only looked up when a token is actually absent from the vocabulary.
    return [
        word_index[token] if token in word_index else word_index[UNK]
        for token in padded_tokens
    ]

In [39]:
def predict_weights(weights, word_index, index_word, max_len):
    """Generate a caption token by token, always taking the highest-scoring word.

    Starting from ``[START]``, the sentence so far is padded and indexed with
    ``sent_to_index`` and passed, together with ``weights``, to the global
    ``model``. The flat argmax of the prediction selects the next word.

    Args:
        weights: per-image input for the model, wrapped into a batch of one.
            (The data-loading cell describes these as VGG 16 weights.)
        word_index: mapping from token to index, forwarded to ``sent_to_index``.
        index_word: mapping from index to token, used to decode the prediction.
        max_len: length the model's sentence input is padded to.

    Returns:
        The list of tokens, beginning with START. Generation stops after END is
        produced or once the list holds more than ``max_len`` tokens, so the
        result can contain up to ``max_len + 1`` entries.
    """
    caption = [START]
    finished = False
    while not finished:
        padded_indices = sent_to_index(caption, word_index, max_len)
        model_inputs = [np.asarray([weights]), np.asarray([padded_indices])]
        best = model.predict(model_inputs).argmax()
        # Indices that have no entry in index_word decode to the UNK token.
        next_word = index_word[best] if best in index_word else UNK
        caption.append(next_word)
        finished = next_word == END or len(caption) > max_len
    return caption

In [44]:
# Display the second file name in the list; it serves as the sample image for captioning.
file_name_list[1]

'COCO_val2014_000000119081.jpg'

In [45]:
# Generate a caption for the image at position 1 of file_name_list, using the
# entry stored for it in file_name_images as the model's image input.
predict_weights(file_name_images[file_name_list[1]], word_index, index_word, max_len)

['<start>',
 'A',
 'cat',
 'is',
 'lounging',
 'on',
 'a',
 'bed',
 'with',
 'a',
 'lamp',
 '.',
 "''",
 '.',
 '.',
 'and',
 'a',
 'pair',
 'of',
 'scissors',
 '.',
 '.',
 'and',
 'a',
 'pair',
 'of',
 'paper',
 '.',
 '.',
 'and',
 'a',
 'pair',
 'of',
 'paper',
 '.',
 'and',
 'a',
 'mouse',
 '.',
 '<unk>',
 'and',
 'a',
 'mouse',
 'control',
 '.',
 "''",
 '.',
 '<unk>',
 'a',
 'mouse',
 '.',
 '<unk>',
 '.',
 "''",
 '.',
 'control']

In [30]:
# NOTE(review): `ret` is not assigned in any cell visible here, and this cell's execution
# count is lower than that of the cells above it — it looks like a leftover debugging
# cell that would fail on a fresh Restart & Run All. Confirm and remove, or define `ret`.
ret

50

In [32]:
# NOTE(review): depends on `ret`, which is not assigned in any cell visible here; this
# appears to be a leftover debugging cell from an earlier kernel state. Confirm and
# remove, or assign `ret` explicitly before decoding its argmax through index_word.
index_word[ret.argmax()]

'A'