In [34]:
import keras
import numpy as np


path = keras.utils.get_file(
    'nietzsche.txt',
    origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt'
    )

with open(path, mode='tr', encoding='utf-8') as file:
    text = file.read().lower()
print('Corpus length:', len(text))

Corpus length: 600893


In [35]:
N_chars = 60
N_step = 3


sentences = []
next_chars = []
for i in range(0, len(text) - N_chars, N_step):
    sentences.append(text[i: i + N_chars])
    next_chars.append(text[i + N_chars])

chars = sorted(list(set(text)))
char_indices = dict((char, chars.index(char)) for char in chars)

X = np.zeros((len(sentences), N_chars, len(chars)))
y = np.zeros((len(sentences), len(chars)))
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        X[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

In [36]:
from keras import layers


model = keras.models.Sequential()
model.add(layers.LSTM(128, input_shape=(N_chars, len(chars))))
model.add(layers.Dense(len(chars), activation='softmax'))

model.summary()

optimizer = keras.optimizers.RMSprop(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm_2 (LSTM)               (None, 128)               95232     
                                                                 
 dense_2 (Dense)             (None, 57)                7353      
                                                                 
Total params: 102585 (400.72 KB)
Trainable params: 102585 (400.72 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [48]:
def sample_index(prob, temperature=1.0):
    prob = np.exp(np.log((prob.astype(np.float64) + 1e-12) / temperature))
    prob = prob / np.sum(prob)
    return np.argmax(np.random.multinomial(1, prob, 1))

In [None]:
import random
import sys


for epoch in range(1, 60):
    print('epoch', epoch)
    model.fit(X, y, batch_size=128, epochs=1)
    start_index = random.randint(0, len(text) - N_chars - 1)
    generated_text = text[start_index: start_index + N_chars]
    print('--- Generating with seed: "' + generated_text + '"')
    for temperature in [0.2, 0.5, 1.0, 1.2]:
        print('------ temperature:', temperature)
        sys.stdout.write(generated_text)
        for i in range(400):
            sampled = np.zeros((1, N_chars, len(chars)))
            for t, char in enumerate(generated_text):
                sampled[0, t, char_indices[char]] = 1.
                preds = model.predict(sampled, verbose=0)[0]
                next_index = sample(preds, temperature)
                next_char = chars[next_index]
                generated_text += next_char
                generated_text = generated_text[1:]
                sys.stdout.write(next_char)

In [38]:
model.fit(X, y, batch_size=128, epochs=40)

Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


<keras.src.callbacks.History at 0x27d4f20c390>

In [49]:
import sys
import numpy as np


temperature = 0.5
N_tokens = 400


start_idx = np.random.randint(0, len(text) - N_chars - 1)
generated_text = text[start_idx: start_idx + N_chars]
sys.stdout.write(generated_text)

for _ in range(N_tokens):
    X_sampled = np.zeros((1, N_chars, len(chars)))
    for char_idx, char in enumerate(generated_text):
        X_sampled[0, t, char_indices[char]] = 1.
        prob = model.predict(X_sampled, verbose=0)[0]
        next_char = chars[sample_index(prob, temperature)]
        generated_text += next_char
        generated_text = generated_text[1:]
        sys.stdout.write(next_char)

 their
best soldiers, and likewise, alas! their first and pram x_'________äääääääääääääääääääääääääää]]]ää]]äää]]ääääääänh__3(((((hhhh((h(hh(hh((h((((((((h((((hh55565_55_55_55555__eeä'e[[[_[[8_[[[8_[[['[_[8[8_8__[[[8[[[[[8[[ä8äääääää_ää_ääävvvs?=32121]]]1]]]]]1]

KeyboardInterrupt: 