In [0]:
import keras

# Download (or reuse the cached copy of) the Nietzsche corpus.
filepath = keras.utils.get_file(
    'nietzsche.txt',
    origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')

# Lower-case the whole corpus to shrink the character vocabulary.
# Use a context manager so the handle is closed deterministically, and decode
# explicitly instead of relying on the platform's default encoding.
with open(filepath, encoding='utf-8') as f:
  text = f.read().lower()

In [0]:
import numpy as np

maxlen = 60  # characters of context in each training window
step = 3     # stride between the starts of consecutive windows

sentences = []
predictions = []

# Each sample is a maxlen-character window; its target is the character that
# immediately follows the window, i.e. text[i + maxlen] (not i + maxlen + 1,
# which skipped a character and could index past the end of `text`).
for i in range(0, len(text) - maxlen, step):
  sentences.append(text[i:i + maxlen])
  predictions.append(text[i + maxlen])

# Sorted vocabulary and its char -> column lookup used for one-hot encoding.
chars = sorted(set(text))
char_index = {c: idx for idx, c in enumerate(chars)}

# Builtin `bool` dtype: the `np.bool` alias was deprecated in NumPy 1.20 and
# removed in 1.24.
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)

for i, sentence in enumerate(sentences):
  for j, c in enumerate(sentence):
    x[i, j, char_index[c]] = 1
  y[i, char_index[predictions[i]]] = 1

In [0]:
from keras import models, layers

model = models.Sequential()
# One LSTM layer reading one-hot windows of shape (maxlen, len(chars)).
model.add(layers.LSTM(128, input_shape=(maxlen, len(chars))))
# Softmax over the vocabulary: one probability per candidate next character.
model.add(layers.Dense(len(chars), activation='softmax'))

# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# optimizer implementations no longer accept.
optimizer = keras.optimizers.RMSprop(learning_rate=0.01)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['acc'])

In [0]:
def sample(preds, temperature=1):
  """Draw one index from a probability vector reweighted by `temperature`.

  The probabilities are raised to the power 1/temperature and renormalised,
  then a single index is drawn from the result. Temperatures below 1 sharpen
  the distribution towards the most likely entry; above 1 flatten it.

  Args:
    preds: 1-D sequence of non-negative probabilities (e.g. a softmax row).
    temperature: Positive float controlling how peaked the distribution is.

  Returns:
    The sampled index as an int.

  Raises:
    ValueError: If `temperature` is not positive.
  """
  if temperature <= 0:
    raise ValueError('temperature must be positive, got %r' % (temperature,))
  preds = np.asarray(preds, dtype='float64')
  # log(0) -> -inf is intended: zero-probability entries keep zero weight.
  with np.errstate(divide='ignore'):
    logits = np.log(preds) / temperature
  # Subtract the max before exponentiating so that small temperatures cannot
  # underflow every weight to 0 (which would make the normalisation NaN).
  logits = logits - np.max(logits)
  weights = np.exp(logits)
  probs = weights / np.sum(weights)
  draw = np.random.multinomial(1, probs, 1)
  return int(np.argmax(draw))
  

In [0]:
import sys

num_epochs = 10
max_gen_len = 400
temps = [0.2, 0.5, 0.7, 1]  # sampling temperatures tried after each epoch


def encode_window(window):
  """One-hot encode a text window into a (1, maxlen, len(chars)) batch."""
  encoded = np.zeros((1, maxlen, len(chars)))
  for pos, ch in enumerate(window):
    encoded[0, pos, char_index[ch]] = 1
  return encoded


for epoch in range(num_epochs):
  print('====== epoch %d/%d ======' % (epoch + 1, num_epochs))
  model.fit(x, y, epochs=1, batch_size=128)

  # Pick a random maxlen-character slice of the corpus as the prompt.
  seed_index = np.random.randint(0, len(text) - maxlen - 1)
  seed_text = text[seed_index : seed_index + maxlen]

  for temp in temps:
    # %s, not %d: %d would truncate 0.2 / 0.5 / 0.7 to 0.
    print('------ temp: %s ------' % temp)
    gen_text = seed_text
    sys.stdout.write(gen_text)

    for _ in range(max_gen_len):
      # verbose=0 keeps newer Keras from printing a progress bar per call.
      preds = model.predict(encode_window(gen_text), verbose=0)[0]
      next_char = chars[sample(preds, temp)]
      sys.stdout.write(next_char)
      sys.stdout.flush()
      # Slide the window: append the new character, drop the oldest one.
      gen_text = gen_text[1:] + next_char
    # Terminate the sample so the next header starts on its own line.
    sys.stdout.write('\n')