In [105]:
import logging
import os
from pathlib import Path
from sys import stdout

from keras.models import Sequential
from keras.layers.core import Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM, SimpleRNN
from keras.layers.wrappers import TimeDistributed
from nltk.tokenize import word_tokenize
import numpy as np
import pandas as pd
import progressbar

# Send INFO-and-above log records to stdout so they appear in the cell output.
LOG_FORMAT = '%(asctime)s | %(levelname)s : %(message)s'
logging.basicConfig(level=logging.INFO, stream=stdout, format=LOG_FORMAT)
logger = logging.getLogger(__name__)

In [6]:
# Root of the MSCOCO data. Override with the MSCOCO_DIR environment variable instead of
# editing the notebook; the default is the original author's local mount.
MSCOCO_DIR = Path(os.environ.get('MSCOCO_DIR',
                                 '/mnt/W/Users/Alt/Documents/CMU/11777/data/MSCOCO'))
CAPTIONS_PATH = MSCOCO_DIR / 'annotations' / 'captions_val2014.json'

# typ='series' loads the top-level JSON object as a Series; the next cell reads its
# 'annotations' entry. str() keeps compatibility with older pandas that expect a str path.
df = pd.read_json(str(CAPTIONS_PATH), typ='series')

In [7]:
# Pull the caption text out of every annotation record, dropping trailing whitespace.
captions = [
    annotation['caption'].rstrip()
    for annotation in df['annotations']
]

In [30]:
# Tokenize every caption with nltk's word_tokenize (one token list per caption).
tokenized_captions = list(map(word_tokenize, captions))

In [73]:
# Distinct tokens across all captions.
vocabulary = {word for caption in tokenized_captions for word in caption}
# Deterministic word -> integer id mapping: ids follow the sorted order of the vocabulary.
encoder = {word: code for code, word in enumerate(sorted(vocabulary))}
# Encode each caption as an array of token ids. dtype is pinned so that a caption with
# no tokens still yields an integer array (np.array([]) would default to float64).
X = [np.array([encoder[word] for word in caption], dtype=int)
     for caption in tokenized_captions]

In [74]:
# Vocabulary size (number of distinct tokens).
k = len(vocabulary)
# Identity matrix used as a one-hot lookup table: Ik[ids] gives one one-hot row per id.
# NOTE: this is a dense k x k matrix, so memory grows quadratically with the vocabulary.
# float32 (Keras' default float type) halves that footprint versus NumPy's float64
# default without changing the 0/1 values.
Ik = np.eye(k, dtype=np.float32)

In [114]:
# Inspired by https://chsasank.github.io/spoken-language-understanding.html and https://stackoverflow.com/questions/39142665/keras-lstm-language-model-using-embeddings

# Token-level language model: embed each token id, run a SimpleRNN that emits an output
# at every timestep, then apply a per-timestep softmax over the k-word vocabulary.
model = Sequential([
    Embedding(input_dim=k, output_dim=50),
    SimpleRNN(units=50, return_sequences=True),
    TimeDistributed(Dense(units=k, activation='softmax')),
])

model.compile(optimizer="rmsprop", loss="categorical_crossentropy")

In [115]:
n_epochs = 1
# Number of encoded captions to train on; use len(X) to train on all of them.
n_train_captions = 100

X_subset = X[:n_train_captions]

for i in range(n_epochs):
    logger.info('Training epoch %d', i)
    bar = progressbar.ProgressBar()
    epoch_losses = []
    for caption_repr in bar(X_subset):
        # At least two tokens are needed to form an (input, next-token target) pair.
        if len(caption_repr) > 1:
            # One-sequence batch: inputs are tokens[:-1], targets are one-hot tokens[1:].
            input_ = caption_repr[:-1][np.newaxis, :]
            label = Ik[caption_repr[1:]][np.newaxis, :]
            # train_on_batch returns the batch loss; keep it so progress is visible.
            epoch_losses.append(model.train_on_batch(input_,
                                                     label))
    if epoch_losses:
        logger.info('Epoch %d mean loss: %.4f', i, float(np.mean(epoch_losses)))

2017-10-09 18:00:06,349 | INFO : Training epoch 0


100% (100 of 100) |#######################| Elapsed Time: 0:00:07 Time: 0:00:07
