In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time

np.set_printoptions(precision=4, linewidth=200)

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
print(tf.__version__)

1.3.0


In [3]:
from utils.reader import ptb_raw_data
from utils.batcher import ptb_batcher
from utils.conditional_scope import cond_name_scope, cond_variable_scope
from utils.unrolled_rnn import make_rnn_variables
from utils.unrolled_rnn import make_rnn_outputs
from utils.unrolled_rnn import make_summary_nodes
from utils.unrolled_rnn import make_placeholders
from utils.unrolled_rnn import make_train_op
from utils.batcher import generate_epoch

In [4]:
X_train, X_val, X_test, vocab_size = ptb_raw_data('bigdata/simple-examples/data')

In [8]:
EMBEDDING_SIZE=64
HIDDEN_SIZE=256
BATCH_SIZE=16
NUM_STEPS=16
NUM_EPOCHS_INIT_LR=5
NUM_EPOCHS_TOTAL=10
INITIAL_LR=1.
LR_DECAY_RATE=0.5

In [18]:
tf.reset_default_graph()
placeholders = make_placeholders(
    batch_size=BATCH_SIZE,
    num_steps=NUM_STEPS
)
rnn_vars = make_rnn_variables(
    vocab_size=vocab_size,
    embedding_size=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
)
rnn_outputs = make_rnn_outputs(
    input_sequence=placeholders['inputs'],
    vocab_size=vocab_size,
    hidden_size=HIDDEN_SIZE,
    batch_size=BATCH_SIZE,
    num_steps=NUM_STEPS,
    rnn_variables=rnn_vars
)
summary_nodes = make_summary_nodes(
    targets=placeholders['targets'],
    logits=rnn_outputs['logits'],
)
train_op = make_train_op(
    summary_nodes['loss'],
    placeholders['learning_rate'],
)

In [21]:
training_outputs = {**summary_nodes, 'train_op': train_op}
with tf.Session() as sess:

    # Bookkeeping
    run_id = time.time()
    writer = tf.summary.FileWriter('logs/{0}'.format(run_id), sess.graph)
    coord = tf.train.Coordinator()
    sess.run(tf.global_variables_initializer())
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    learning_rate = INITIAL_LR
    for i in range(NUM_EPOCHS_TOTAL):
        if i >= NUM_EPOCHS_INIT_LR:
            learning_rate *= LR_DECAY_RATE
        for batch_idx, (inputs, targets) in enumerate(generate_epoch(X_train, BATCH_SIZE, NUM_STEPS)):
            outputs = sess.run(
                training_outputs,
                feed_dict={
                    placeholders['inputs']: inputs,
                    placeholders['targets']: targets,
                    placeholders['learning_rate']: learning_rate
                }
            )
            if (batch_idx % 128 == 127):
                print('step: {0}    loss: {1}    correct words: {2}'.format(
                    batch_idx+1,
                    outputs['loss'],
                    outputs['num_correct_predictions']
                ))
                
        total_loss, total_steps = 0, 0
        for inputs, targets in generate_epoch(X_val, BATCH_SIZE, NUM_STEPS):
            outputs = sess.run(
                summary_nodes,
                feed_dict={
                    placeholders['inputs']: inputs,
                    placeholders['targets']: targets
                },
            )
            total_loss += outputs['loss']
            total_steps += NUM_STEPS
        
        print('validation perplexity:', np.exp(total_loss / total_steps)) 

    # Bookkeeping        
    writer.close()
    coord.request_stop()
    coord.join(threads)

step: 128    loss: 111.75    correct words: 22
step: 256    loss: 99.5    correct words: 23
step: 384    loss: 96.375    correct words: 40
step: 512    loss: 97.125    correct words: 30
step: 640    loss: 94.75    correct words: 35
step: 768    loss: 90.8125    correct words: 30
step: 896    loss: 90.25    correct words: 42
step: 1024    loss: 91.375    correct words: 44
step: 1152    loss: 91.625    correct words: 39
step: 1280    loss: 93.5625    correct words: 36
step: 1408    loss: 88.125    correct words: 46
step: 1536    loss: 91.9375    correct words: 40
step: 1664    loss: 86.5    correct words: 46
step: 1792    loss: 89.5    correct words: 42
step: 1920    loss: 80.25    correct words: 54
step: 2048    loss: 85.875    correct words: 45
step: 2176    loss: 82.5    correct words: 56
step: 2304    loss: 90.6875    correct words: 38
step: 2432    loss: 81.875    correct words: 47
step: 2560    loss: 87.125    correct words: 33
step: 2688    loss: 85.75    correct words: 45
step: 2

step: 256    loss: 81.0    correct words: 41
step: 384    loss: 84.0625    correct words: 38
step: 512    loss: 83.6875    correct words: 45
step: 640    loss: 79.375    correct words: 57
step: 768    loss: 81.9375    correct words: 46
step: 896    loss: 78.875    correct words: 65
step: 1024    loss: 79.375    correct words: 62
step: 1152    loss: 79.625    correct words: 51
step: 1280    loss: 82.75    correct words: 40
step: 1408    loss: 78.125    correct words: 50
step: 1536    loss: 85.625    correct words: 47
step: 1664    loss: 77.625    correct words: 46
step: 1792    loss: 81.9375    correct words: 46
step: 1920    loss: 74.125    correct words: 53
step: 2048    loss: 80.375    correct words: 46
step: 2176    loss: 76.25    correct words: 60
step: 2304    loss: 85.3125    correct words: 41
step: 2432    loss: 75.0    correct words: 60
step: 2560    loss: 81.0    correct words: 35
step: 2688    loss: 82.375    correct words: 49
step: 2816    loss: 86.6875    correct words: 40
