In [1]:
import tensorflow as tf
import numpy as np
from time import time

### Process Input

* Read input text.
* Split intro training and validation.
* Split each into chunks of `series_size` characters.
* Batch up int random batches of input.

In [2]:
encoding_depth = 128 # nr of chars in ascii

with open("inputs/shakespeare.txt") as file:
    shakespeare = file.read().encode("ascii")
print("Shakespeare writings: {:.2f}MB".format(len(shakespeare)/(1<<20)))

Shakespeare writings: 4.36MB


In [3]:
train_frac = 0.95
valid_idx = int(len(shakespeare)*train_frac)
train, valid = shakespeare[:valid_idx], shakespeare[valid_idx:]

In [4]:
def make_io(chars):
    chars = list(map(int, chars))
    x = chars[:-1]
    y = chars[1:] # predict next character
    
    return x, y

In [5]:
series_size = 64 # chunk of characters to process at a time
def chunk_to_series(io):
    x, y = io
    return [
        (x[i:i+series_size], y[i:i+series_size])
        for i in range(0, len(x)-series_size, series_size)
    ]

In [6]:
series = chunk_to_series(make_io(train))

In [7]:
transpose = lambda l: list(zip(*l))

In [8]:
batch_size = 4

def chunk_to_batches(series):
    """Return batches of randomly shuffled series."""

    batches_xy = [
        series[i:i+batch_size]
        
        for i in range(0, len(series)-batch_size, batch_size)]
    
    return [transpose(batch_xy) for batch_xy in batches_xy]

batches = chunk_to_batches(series)

In [9]:
def epoch(text):
    return [transpose(chunk_to_series(make_io(text)))]

def random_batched_epoch(text):
    series = chunk_to_series(make_io(text))
    
    series = series.copy()
    np.random.shuffle(series)
    
    return chunk_to_batches(series)

In [10]:
valid_x, valid_y = epoch(valid)[0]

### Define the RNN

Stack a few LSTM layers, then one more fully connected layer on top to the output. Hot-encode the input character and predict the next character in the series.

In [11]:
# Size of each lstm layer.
lstm_layers = [128, 64, 64]

# Ignore cost of first part of char series because the lstm memory is zero.
ignore_front = 16

# Reset batch size and series size so we accept any form of input.
inputs_shape = [None, None] # [batch_size, series_size]

graph = tf.Graph()

with graph.as_default():
    keep_prob = tf.placeholder(tf.float32, name="keep_prop")

    with tf.name_scope("inputs"):
        x = tf.placeholder(tf.int32, inputs_shape, name="x")
        hot_x = tf.one_hot(x, encoding_depth)

    with tf.name_scope("targets"):
        y = tf.placeholder(tf.int32, inputs_shape, name="y")
        hot_y = tf.one_hot(y, encoding_depth)


    with tf.name_scope("RNN"):
        def drop_lstm(lstm_size):
            lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
            return tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)

        rnn = tf.contrib.rnn.MultiRNNCell([drop_lstm(s) for s in lstm_layers])

        rnn_out, final_state = tf.nn.dynamic_rnn(rnn, hot_x, dtype=tf.float32)

    with tf.name_scope("fully_connected"):
        weight = tf.Variable(tf.truncated_normal([lstm_layers[-1], encoding_depth], stddev=0.1), name="weight")
        bias = tf.Variable(tf.constant(0.0, shape=[encoding_depth]), name="bias")

        conn_out = tf.tensordot(rnn_out, weight, axes=[[2], [0]]) + bias

    with tf.name_scope("output"):
        prediction = tf.nn.softmax(conn_out, dim=-1)

        # Pick best char
        predicted_y = tf.argmax(prediction, axis=-1)

    with tf.name_scope("cost"):
        char_cross_entropy = -tf.reduce_sum(hot_y * tf.log(prediction), axis=-1)

        batch_cross_entropy = tf.reduce_mean(char_cross_entropy[..., ignore_front:])


    with tf.name_scope("optimizer"):
        optimizer = tf.train.AdamOptimizer().minimize(batch_cross_entropy)


    # Summarize
    tf.summary.histogram("expected_char", y)

    tf.summary.histogram("prediction", prediction)
    tf.summary.histogram('max_prediction', tf.reduce_max(prediction, axis=-1))

    tf.summary.histogram("weight", weight)
    tf.summary.histogram("bias", bias)

    tf.summary.histogram("predicted_char", predicted_y)
    tf.summary.histogram("char_cross_entropy", char_cross_entropy)

    tf.summary.scalar("train_cross_entropy", batch_cross_entropy)

    summary = tf.summary.merge_all()

    validate_summary = tf.summary.scalar("validate_cross_entropy", batch_cross_entropy)

    # Details
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

In [None]:
run_id = 0

In [None]:
epochs = 20

run_id += 1

restore = None
with tf.Session(graph=graph) as sess:
    
    if restore:
        saver.restore(sess, restore)
    else:
        sess.run(init)
    
    writer = tf.summary.FileWriter("./logs/{}".format(run_id), sess.graph)

    i = 0
    for epoch in range(epochs):
        for in_x, out_y in random_batched_epoch(train):
            
            sess.run(optimizer, feed_dict={x: in_x, y: out_y, keep_prob: 0.9})


            if i % 200 == 0:
                s, train_cost, p_y = sess.run([summary, batch_cross_entropy, predicted_y], feed_dict={
                    x: in_x, y: out_y, keep_prob: 1.0})
                
                vs, valid_cost = sess.run([validate_summary, batch_cross_entropy], feed_dict={
                    x: valid_x, y: valid_y, keep_prob: 1.0})

                print("Batch {:>2}".format(i))
                print("  train_cross_entropy = {:>10.5f}".format(train_cost))
                print("  valid_cross_entropy = {:>10.5f}".format(valid_cost))
                print("  ")
                print("  Expected : {}".format(bytes(out_y[0][-50:]).decode("ascii").replace("\n", "|")))
                print("  Predicted: {}".format(bytes(p_y[0][-50:]).decode("ascii").replace("\n", "|")))
                print("---")
                
                writer.add_summary(s, i)
                writer.add_summary(vs, i)
                
            
            i += 1
            
            if i % 5000 == 0:
                save_path = saver.save(sess, "./saves/run_{}_batch_{}_epoch_{}.ckpt".format(run_id, i, epoch))
                print("")
                print("Model saved in file: {}.".format(save_path))
                print("===")
                
                restore = save_path

Batch  0
  train_cross_entropy =    4.84508
  valid_cross_entropy =    4.84666
  
  Expected : e, and you weigh this well;|Therefore still bear t
  Predicted:                                                                    e       e       U       U       U       U       U                                                               e       e       U       U       U       U       e       e       e       e       e                                                                                                                
---
Batch 200
  train_cross_entropy =    3.22829
  valid_cross_entropy =    3.31272
  
  Expected : herein thy counsel and consent is wanting.|Richard
  Predicted:                                                                                                                                                                                                                                                                                     