In [304]:
%matplotlib inline
import keras
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, LSTM, Input
from keras.callbacks import ReduceLROnPlateau

import tensorflow as tf

import collections
import numpy as np
import re
import string

print("Tensorflow version: {}".format(tf.__version__))
print("Keras version: {}".format(keras.__version__))

Tensorflow version: 1.1.0
Keras version: 2.0.8


## Helper function

In [305]:
def remove_pos_tag(word):
    return word.split('/')[0]

regex = re.compile('[-+]?([0-9]+,)?[0-9]+.?[0-9]*')
punctuation = re.compile('['+string.punctuation+']')

## Load training data & create Lookup Tables

In [315]:
raw_text = open("./data/wonderland.txt").read()
raw_text = raw_text.lower()

words = raw_text.split()
#words = [remove_pos_tag(word) for word in words if not regex.search(word)]
#words = [word for word in words if not punctuation.search(word)]
#words = words[:90000]
vocab = sorted(list(set(words)))

# Lookup tables
word_to_int = dict((c, i) for i, c in enumerate(vocab))
int_to_word = dict((i, c) for i, c in enumerate(vocab))

print("Total number of words: ", len(words))
print("Total number of unique words: ", len(vocab))

Total number of words:  26438
Total number of unique words:  4939


In [316]:
# prepare the dataset of input to output pairs encoded as integers
seq_length = 30
dataX = []
dataY = []
for i in range(0, len(words) - seq_length):
    seq_in = words[i:i + seq_length]
    seq_out = words[i + seq_length]
    dataX.append([word_to_int[word] for word in seq_in])
    dataY.append(word_to_int[seq_out])

print("Total number of patterns: ", len(dataX))

Total number of patterns:  26408


In [317]:
print(' '.join([int_to_word[value] for value in dataX[0]]))
print(int_to_word[dataY[0]])

ï»¿chapter i. down the rabbit-hole alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do: once or twice she
had


In [318]:
# reshape X to be [samples, time steps, features]
X = np.reshape(dataX, (len(dataX), seq_length, 1))

# normalize
X = X / float(len(vocab))
print(X.shape)

# one hot encode the output variable
y = keras.utils.to_categorical(dataY, len(vocab))
print(y.shape)

(26408, 30, 1)
(26408, 4939)


## Define the Model

In [325]:
inp = Input(shape=(X.shape[1], X.shape[2]))
x = LSTM(256)(inp)
x = Dropout(0.2)(x)
output = Dense(y.shape[1], activation ='softmax')(x)

generative_model = Model(inputs=inp, outputs=output)

optimizer = keras.optimizers.RMSprop(lr=0.01)
generative_model.compile(loss='categorical_crossentropy', optimizer=optimizer)

generative_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_13 (InputLayer)        (None, 30, 1)             0         
_________________________________________________________________
lstm_14 (LSTM)               (None, 256)               264192    
_________________________________________________________________
dropout_14 (Dropout)         (None, 256)               0         
_________________________________________________________________
dense_13 (Dense)             (None, 4939)              1269323   
Total params: 1,533,515
Trainable params: 1,533,515
Non-trainable params: 0
_________________________________________________________________


## Train the model

In [326]:
generative_model.fit(X, y, epochs=5, batch_size=64)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x25212501f60>

## Make some predictions

In [334]:
generated_text = []
pattern = dataX[1000][:]

print(pattern)

# generate characters
for i in range(15):
    x = np.reshape(pattern, (1, seq_length, 1))
    x = x / float(len(vocab))
    prediction = generative_model.predict(x, verbose=0)
    index = np.argmax(prediction)
    result = int_to_word[index]
    #print(result)
    pattern.append(index)
    generated_text.append(index)
    pattern = pattern[1:len(pattern)]
print("\nDone.")

[362, 2809, 4391, 666, 2649, 622, 4721, 455, 2566, 4244, 4792, 506, 4669, 2397, 2272, 4368, 4391, 2094, 2340, 3649, 586, 2340, 4502, 362, 1093, 192, 2890, 1476, 506, 4754]

Done.


In [335]:
print(pattern)
print(' '.join([remove_pos_tag(int_to_word[value]) for value in pattern]))
print(' '.join([remove_pos_tag(int_to_word[value]) for value in generated_text]))

[4368, 4391, 2094, 2340, 3649, 586, 2340, 4502, 362, 1093, 192, 2890, 1476, 506, 4754, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244, 4244]
time to hear it say, as it turned a corner, 'oh my ears and whiskers, the the the the the the the the the the the the the the the
the the the the the the the the the the the the the the the
