In [None]:
"""
Example script to generate text from a corpus of text
--By word--
It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.
Based on
https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py
20 epochs should be enough to get decent results.
Uses data generator to avoid loading all the test set into memory.
Saves the weights and model every epoch.
"""

from __future__ import print_function
from keras.callbacks import LambdaCallback, ModelCheckpoint, EarlyStopping
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, LSTM, Bidirectional
import numpy as np
import random
import sys
import io
import os
import re
import itertools
from collections import Counter

In [2]:
# Parameters
seq_len = 10 # includes next word.
step = 2
word_freq_threshold = 2

# NN parameters
batch_size = 128

In [3]:
# Read file
file = open('data/speeches.txt', 'r',encoding='utf-8-sig') 
speeches = file.read()

# Remove text between brackets, such as (inaudible) or (laughter)
speeches = re.sub("[\(\[].*?[\)\]]", "", speeches)
# Remove the speech introductions
speeches = re.sub(r'SPEECH.+?\n', '', speeches)
# Replace multiple periods with a single one.
speeches = re.sub('\.+','. ',speeches)
# different uses of this character
speeches = re.sub('\'','’',speeches)
# Replace new lines with spaces
speeches = re.sub('\n',' ', speeches)

# Treat the following interpunction characters as separate words, so we can generate them.
speeches = re.sub('\. ',' . ', speeches)
speeches = re.sub(', ',' , ', speeches)
speeches = re.sub('\? ',' ? ', speeches)
speeches = re.sub('! ',' ! ', speeches)
speeches = re.sub('; ',' ; ', speeches)
punc = '.,?!;'

# Keep only this set of characters, replace multiple whitespace with single, and convert to lower case.
speeches = re.sub('[^0-9a-zA-Z\.,\?!;’]+', ' ', speeches)
speeches = re.sub('\s+',' ', speeches)
speeches = speeches.lower()

In [4]:
words = speeches.split(' ')
counts = Counter(words)

# Identify the foribdden words, i.e. words that occur less than a certain threshold.
forbidden_words=[]
for word in list(counts):
    if counts[word] < word_freq_threshold:
        forbidden_words.append(word)
        
# Find the indices of the forbidden words        
forbidden_indices = [[i for i, x in enumerate(words) if x == f_word] for f_word in forbidden_words]
forbidden_indices = list(itertools.chain.from_iterable(forbidden_indices))
forbidden_indices.sort()

# Now, create the ranges of words for the sentences. If a range contains one of the indices in
# forbidden_indices, we omit it from the data.
sentence_ranges = [range(i,i+seq_len-1) for i in range(0,len(words)-seq_len,step)]
s = 0
f = 0
while (s < len(sentence_ranges)) & (f < len(forbidden_indices)):
    if forbidden_indices[f] in sentence_ranges[s]:
        sentence_ranges.pop(s)
    else:
        if max(sentence_ranges[s])>forbidden_indices[f]:
            f+=1
        else:
            s+=1
sentences = [[words[y] for y in x] for x in sentence_ranges]

In [5]:
print('Original sentences: ' + str(len([range(i,i+seq_len-1) for i in range(0,len(words)-seq_len,step)])))
print('Truncated sentences: '+ str(len(sentences)))

Original sentences: 92822
Truncated sentences: 84456


In [6]:
sentences

[['', '.', 'thank', 'you', 'so', 'much', '.', 'that’s', 'so'],
 ['thank', 'you', 'so', 'much', '.', 'that’s', 'so', 'nice', '.'],
 ['so', 'much', '.', 'that’s', 'so', 'nice', '.', 'isn’t', 'he'],
 ['.', 'that’s', 'so', 'nice', '.', 'isn’t', 'he', 'a', 'great'],
 ['so', 'nice', '.', 'isn’t', 'he', 'a', 'great', 'guy', '.'],
 ['.', 'isn’t', 'he', 'a', 'great', 'guy', '.', 'he', 'doesn’t'],
 ['he', 'a', 'great', 'guy', '.', 'he', 'doesn’t', 'get', 'a'],
 ['great', 'guy', '.', 'he', 'doesn’t', 'get', 'a', 'fair', 'press'],
 ['.', 'he', 'doesn’t', 'get', 'a', 'fair', 'press', ';', 'he'],
 ['doesn’t', 'get', 'a', 'fair', 'press', ';', 'he', 'doesn’t', 'get'],
 ['a', 'fair', 'press', ';', 'he', 'doesn’t', 'get', 'it', '.'],
 ['press', ';', 'he', 'doesn’t', 'get', 'it', '.', 'it’s', 'just'],
 ['he', 'doesn’t', 'get', 'it', '.', 'it’s', 'just', 'not', 'fair'],
 ['get', 'it', '.', 'it’s', 'just', 'not', 'fair', '.', 'and'],
 ['.', 'it’s', 'just', 'not', 'fair', '.', 'and', 'i', 'have'],
 ['just'

In [15]:
check = sum([x in list(itertools.chain.from_iterable(sentences)) for x in forbidden_words])
print('Forbidden words still in our text, should be zero: ' + str(check))

Forbidden words still in our text, should be zero: 0


In [None]:
# Train test split
random.shuffle(sentences)
X = [x[:-1] for x in sentences]
Y = [x[-1:] for x in sentences]
train_split = int(0.98*len(X))
X_train, Y_train = X[:train_split], Y[:train_split]
X_test, Y_test = X[train_split:], Y[train_split:]

In [None]:
# Data generator for fit and evaluate
# modified from https://github.com/enriqueav/lstm_lyrics/blob/master/lstm_train.py
def generator(sentence_list, next_word_list, batch_size):
    index = 0
    while True:
        x = np.zeros((batch_size, SEQUENCE_LEN, len(words)), dtype=np.bool)
        y = np.zeros((batch_size, len(words)), dtype=np.bool)
        for i in range(batch_size):
            for t, w in enumerate(sentence_list[index % len(sentence_list)]):
                x[i, t, word_indices[w]] = 1
            y[i, word_indices[next_word_list[index % len(sentence_list)]]] = 1
            index = index + 1
        yield x, y

In [38]:
# Function from keras-team/keras/blob/master/examples/lstm_text_generation.py
def sample(preds, temperature=1.0):
    # helper function to sample an index from a probability array
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

# Function from keras-team/keras/blob/master/examples/lstm_text_generation.py
def on_epoch_end(epoch, logs):
    # Function invoked at end of each epoch. Prints generated text.
    print()
    print('----- Generating text after Epoch: %d' % epoch)

    start_index = random.randint(0, len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print('----- diversity:', diversity)

        generated = ''
        sentence = text[start_index: start_index + maxlen]
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)

        for i in range(400):
            x_pred = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.

            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            generated += next_char
            sentence = sentence[1:] + next_char

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()