In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
with open("anna.txt", "r") as f:
    text = f.read()
    
vocab = sorted(set(text))
vocab_encode = {v:i for i,v in enumerate(vocab)}
vocab_decode = {i:v for v,i in vocab_encode.items()}
encoded = np.array([vocab_encode[c] for c in text], dtype=np.int16)

In [3]:
train_batch_size = 8
sequence_length = 128
lstm_feature_depth = 256
lstm_layer_depth = 2
num_classes = len(vocab)

save_interval = 5

In [4]:
text[:100]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [5]:
encoded[:100]

array([31, 64, 57, 72, 76, 61, 74,  1, 16,  0,  0,  0, 36, 57, 72, 72, 81,
        1, 62, 57, 69, 65, 68, 65, 61, 75,  1, 57, 74, 61,  1, 57, 68, 68,
        1, 57, 68, 65, 67, 61, 26,  1, 61, 78, 61, 74, 81,  1, 77, 70, 64,
       57, 72, 72, 81,  1, 62, 57, 69, 65, 68, 81,  1, 65, 75,  1, 77, 70,
       64, 57, 72, 72, 81,  1, 65, 70,  1, 65, 76, 75,  1, 71, 79, 70,  0,
       79, 57, 81, 13,  0,  0, 33, 78, 61, 74, 81, 76, 64, 65, 70], dtype=int16)

In [6]:
target = np.roll(encoded, -1)
target[-1] = 0

chars_per_batch = train_batch_size * sequence_length
num_batches = len(encoded)//chars_per_batch
print("number of batches : {}".format(num_batches))

# discard extra characters
encoded = encoded[:num_batches*chars_per_batch]
target = target[:num_batches*chars_per_batch]

number of batches : 1938


In [7]:
def get_batch(batch_sequences, target_batch_sequences, idx, sequence_length):
    batch = batch_sequences[:,idx*sequence_length:(idx+1)*sequence_length]
    target_batch = target_batch_sequences[:,idx*sequence_length:(idx+1)*sequence_length]
    return batch, target_batch

batch_sequences = encoded.reshape((train_batch_size, -1))
target_batch_sequences = target.reshape((train_batch_size, -1))

#batch, target_batch = get_batch(batch_sequences, target_batch_sequences, 0, sequence_length)
batch, target_batch = get_batch(batch_sequences, target_batch_sequences, 0, 32)

In [8]:
batch, target_batch

(array([[31, 64, 57, 72, 76, 61, 74,  1, 16,  0,  0,  0, 36, 57, 72, 72, 81,
          1, 62, 57, 69, 65, 68, 65, 61, 75,  1, 57, 74, 61,  1, 57],
        [ 1, 77, 72, 71, 70,  1, 64, 61, 74, 13,  0,  0, 29, 76,  1, 44, 61,
         76, 61, 74, 75, 58, 77, 74, 63, 11,  1, 57, 75,  1, 75, 71],
        [57, 76, 64, 61, 74,  1, 58, 61, 62, 71, 74, 61, 11,  1, 57, 70, 60,
          1, 70, 71, 79, 11,  1, 61, 78, 61, 74,  1, 75, 65, 70, 59],
        [79, 57, 81,  0, 79, 64, 65, 68, 61,  1, 37,  7, 69,  1, 65, 70,  1,
         76, 64, 61,  1, 75, 61, 74, 78, 65, 59, 61, 27,  1, 37, 62],
        [11,  1, 64, 57, 70, 60, 61, 60,  1, 64, 61, 74,  1, 64, 65, 75,  1,
         60, 65, 57, 74, 81, 13,  1, 36, 61,  1, 67, 70, 61, 79,  1],
        [68, 61, 80, 61, 81,  1, 29, 68, 61, 80, 57, 70, 60, 74, 71, 78, 65,
         76, 59, 64,  1, 68, 65, 75, 76, 61, 70, 61, 60,  1, 76, 71],
        [60, 11,  1, 79, 64, 65, 59, 64,  1, 76, 64, 61,  1, 64, 71, 74, 75,
         61,  1, 64, 57, 60,  1, 69, 57, 

In [9]:
def decode(predictions):
    dec = [[vocab_decode[i] for i in row] for row in predictions]
    return [''.join(list) for list in dec]

## Graph creation

In [10]:
x = tf.placeholder(tf.int32, [None, None], name="inputs")
y = tf.placeholder(tf.int32, [None, None], name="outputs")

dropout = tf.placeholder(tf.float32, name="dropout")
batch_size = tf.placeholder(tf.int32, [], name="batch_size")

In [11]:
def create_cell(lstm_feature_depth, dropout):
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_feature_depth, activation=tf.nn.tanh)
    lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=1-dropout)
    return lstm_cell

multirnn_cell = tf.nn.rnn_cell.MultiRNNCell([create_cell(lstm_feature_depth, dropout) for _ in range(lstm_layer_depth)])
initial_state = multirnn_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
rnn_out, new_state = tf.nn.dynamic_rnn(cell=multirnn_cell, inputs=tf.one_hot(x, num_classes), initial_state=initial_state)
logits = tf.layers.dense(inputs=rnn_out, units=num_classes)
softmax = tf.nn.softmax(logits)

predictions = tf.argmax(input=softmax, axis=2)

In [12]:
# training graph
y_one_hot = tf.one_hot(y, num_classes)
loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_one_hot)
loss = tf.reduce_mean(loss)

tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 5)
optimize = tf.train.AdamOptimizer(learning_rate=0.001).apply_gradients(zip(grads, tvars))

In [13]:
# training

saver = tf.train.Saver(max_to_keep=10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_start_state = sess.run(initial_state, feed_dict={batch_size:train_batch_size})
    
    counter=0
    for t in range(num_batches):
        x_batch, y_batch = get_batch(batch_sequences, target_batch_sequences, t, sequence_length)
        feed = {
            x:x_batch,
            y:y_batch,
            dropout:0.4,
            batch_size:train_batch_size,
            initial_state:batch_start_state
        }

        batch_loss, batch_start_state, _ = sess.run([loss, new_state, optimize], feed_dict=feed)
        counter += 1
        if(counter%save_interval==0):
            print("Loss : {}".format(batch_loss))
            saver.save(sess, "checkpoints/i{}.ckpt".format(counter))

Loss : 3.6990115642547607
Loss : 3.507460117340088
Loss : 3.347991943359375
Loss : 3.2357261180877686
Loss : 3.166090488433838
Loss : 3.1533255577087402
Loss : 3.154907464981079
Loss : 3.139657735824585
Loss : 3.140908718109131
Loss : 3.2264699935913086
Loss : 3.2220911979675293
Loss : 3.1838788986206055
Loss : 3.191047191619873
Loss : 3.2038235664367676
Loss : 3.2396669387817383
Loss : 3.2552709579467773
Loss : 3.1707069873809814
Loss : 3.221566677093506
Loss : 3.2221131324768066
Loss : 3.145331621170044
Loss : 3.1663479804992676
Loss : 3.1408705711364746
Loss : 3.079702854156494
Loss : 3.0622129440307617
Loss : 3.110764980316162
Loss : 3.134237289428711
Loss : 3.0624570846557617
Loss : 3.0594770908355713
Loss : 3.1090526580810547
Loss : 3.027599811553955
Loss : 3.0185232162475586
Loss : 3.040438413619995
Loss : 3.080812454223633
Loss : 2.9908738136291504
Loss : 2.9409680366516113
Loss : 3.020944595336914
Loss : 2.849097967147827
Loss : 2.9092535972595215
Loss : 2.889892101287842
Loss

KeyboardInterrupt: 

In [14]:
checkpoint = tf.train.latest_checkpoint('checkpoints')
print("latest checkpoint : {}".format(checkpoint))

latest checkpoint : checkpoints/i845.ckpt


In [15]:
# test data
test_seq, test_target = get_batch(np.expand_dims(encoded, axis=0), np.expand_dims(target, axis=0), 0, 256)

In [16]:
with tf.Session() as sess:
    saver.restore(sess, save_path=checkpoint)
    batch_start_state = sess.run(initial_state, feed_dict={batch_size:1})
    feed = {x:test_seq, 
            y:test_target,
            dropout:0,
            batch_size:1,
            initial_state:batch_start_state}
    pred = sess.run(predictions, feed_dict=feed)

INFO:tensorflow:Restoring parameters from checkpoints/i845.ckpt


In [17]:
test_seq, pred

(array([[31, 64, 57, 72, 76, 61, 74,  1, 16,  0,  0,  0, 36, 57, 72, 72, 81,
          1, 62, 57, 69, 65, 68, 65, 61, 75,  1, 57, 74, 61,  1, 57, 68, 68,
          1, 57, 68, 65, 67, 61, 26,  1, 61, 78, 61, 74, 81,  1, 77, 70, 64,
         57, 72, 72, 81,  1, 62, 57, 69, 65, 68, 81,  1, 65, 75,  1, 77, 70,
         64, 57, 72, 72, 81,  1, 65, 70,  1, 65, 76, 75,  1, 71, 79, 70,  0,
         79, 57, 81, 13,  0,  0, 33, 78, 61, 74, 81, 76, 64, 65, 70, 63,  1,
         79, 57, 75,  1, 65, 70,  1, 59, 71, 70, 62, 77, 75, 65, 71, 70,  1,
         65, 70,  1, 76, 64, 61,  1, 43, 58, 68, 71, 70, 75, 67, 81, 75,  7,
          1, 64, 71, 77, 75, 61, 13,  1, 48, 64, 61,  1, 79, 65, 62, 61,  1,
         64, 57, 60,  0, 60, 65, 75, 59, 71, 78, 61, 74, 61, 60,  1, 76, 64,
         57, 76,  1, 76, 64, 61,  1, 64, 77, 75, 58, 57, 70, 60,  1, 79, 57,
         75,  1, 59, 57, 74, 74, 81, 65, 70, 63,  1, 71, 70,  1, 57, 70,  1,
         65, 70, 76, 74, 65, 63, 77, 61,  1, 79, 65, 76, 64,  1, 57,  1, 34,

In [18]:
print(pred.shape)
print("\n-------------------------------------------------------------\n".join(decode(pred)))

(1, 256)
oaiiir toofhet e tor nend an  tnl tnlne  aner  tn enee tor ne an tn eree tn tn eof  ths  
"nere heng ths tn tomdertnn tn the snleng y  sferlsd The sash tes aontered d ahet hhe sas end ahs tonee ng tf tndtn eenetdtash hnpronse wosee aha had te r tntore  ed 


In [25]:
prime = 'happy'
gen = []
with tf.Session() as sess:
    saver.restore(sess, save_path=checkpoint)
    char_start_state = sess.run(initial_state, feed_dict={batch_size:1})
    
    for p in prime:
        feed = {
            x:np.expand_dims(np.expand_dims(vocab_encode[p], axis=0), axis=0),
            dropout:0,
            batch_size:1,
            initial_state:char_start_state
        }
        c, char_start_state = sess.run([predictions, new_state], feed_dict=feed)
        print(decode(c))
    
    for i in range(128):
        feed = {
            x:c,
            dropout:0,
            batch_size:1,
            initial_state:char_start_state
        }
        
        c, char_start_state = sess.run([predictions, new_state], feed_dict=feed)
        gen.append(c[0][0])
gen = np.expand_dims(gen,0)

INFO:tensorflow:Restoring parameters from checkpoints/i845.ckpt
['e']
['s']
['r']
['r']
[' ']


In [26]:
decode(gen)

['and the salle the salled and the salle the salled and the salle the salled and the salle the salled and the salle the salled and']