Permalink
Browse files

Add LSTM text generation example

  • Loading branch information...
fchollet committed Jun 16, 2015
1 parent bf4822f commit d2b229df2ea0bab712379c418115bc44508bc6f9
Showing with 103 additions and 0 deletions.
  1. +103 −0 examples/lstm_text_generation.py
@@ -0,0 +1,103 @@
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.recurrent import LSTM
from keras.datasets.data_utils import get_file
import numpy as np
import random, sys
'''
Example script to generate text from Nietzsche's writings.
At least 20 epochs are required before the generated text
starts sounding coherent.
It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.
If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
'''
path = get_file('nietzsche.txt', origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt")
text = open(path).read().lower()
print('corpus length:', len(text))
chars = set(text)
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
# cut the text in semi-redundant sequences of 30 characters
maxlen = 30
step = 5
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
sentences.append(text[i : i + maxlen])
next_chars.append(text[i + maxlen])
print('nb sequences:', len(sentences))
print('Vectorization...')
X = np.zeros((len(sentences), maxlen, len(chars)))
y = np.zeros((len(sentences), len(chars)))
for i, sentence in enumerate(sentences):
for t, char in enumerate(sentence):
X[i, t, char_indices[char]] = 1.
y[i, char_indices[next_chars[i]]] = 1.
# build the model: 2 stacked LSTM
print('Build model...')
model = Sequential()
model.add(LSTM(len(chars), 512, return_sequences=True))
model.add(LSTM(512, 512, return_sequences=False))
model.add(Dense(512, len(chars)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
# helper function to sample an index from a probability array
def sample(a, diversity=0.75):
if random.random() > diversity:
return np.argmax(a)
while 1:
i = random.randint(0, len(a)-1)
if a[i] > random.random():
return i
# train the model, output generated text after each epoch
for iteration in range(1, 30):
print()
print('-' * 50)
print('Iteration', iteration)
model.fit(X, y, batch_size=128, nb_epoch=1)
start_index = random.randint(0, len(text) - maxlen - 1)
for diversity in [0.4, 0.7, 1.]:
print()
print('----- diversity:', diversity)
generated = ''
sentence = text[start_index : start_index + maxlen]
generated += sentence
print('----- Generating with seed: "' + sentence + '"')
sys.stdout.write(generated)
for iteration in range(400):
x = np.zeros((1, maxlen, len(chars)))
for t, char in enumerate(sentence):
x[0, t, char_indices[char]] = 1.
preds = model.predict(x, verbose=0)[0]
next_index = sample(preds, diversity)
next_char = indices_char[next_index]
generated += next_char
sentence = sentence[1:] + next_char
sys.stdout.write(next_char)
sys.stdout.flush()
print()

0 comments on commit d2b229d

Please sign in to comment.