In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time

np.set_printoptions(precision=4, linewidth=200)

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Record the TensorFlow version alongside the results. The cells below use the
# TF 1.x graph-mode API (tf.reset_default_graph, tf.Session,
# tf.train.start_queue_runners), so the version matters for reproducibility.
print(tf.__version__)

1.3.0


In [3]:
from utils.reader import ptb_raw_data
from utils.batcher import ptb_batcher
from utils.conditional_scope import cond_name_scope, cond_variable_scope
from utils.unrolled_rnn import make_rnn_variables
from utils.unrolled_rnn import make_rnn_outputs
from utils.unrolled_rnn import make_summary_nodes
from utils.unrolled_rnn import make_placeholders
from utils.unrolled_rnn import make_train_op
from utils.batcher import generate_epoch

In [4]:
# Load the three data splits plus the vocabulary size from disk.
# NOTE(review): presumably the Penn Treebank corpus (the loader is named
# `ptb_raw_data`); the path is relative to the notebook's working directory.
(X_train, X_val, X_test, vocab_size) = ptb_raw_data(
    'bigdata/simple-examples/data'
)

In [5]:
# --- Model dimensions (passed to make_rnn_variables / make_rnn_outputs) ---
EMBEDDING_SIZE = 64
HIDDEN_SIZE = 256

# --- Batching (passed to make_placeholders and generate_epoch) ---
BATCH_SIZE = 32
NUM_STEPS = 16

# --- Training schedule ---
# The learning rate stays at INITIAL_LR for the first NUM_EPOCHS_INIT_LR
# epochs, then is multiplied by LR_DECAY_RATE at the start of every later
# epoch, for NUM_EPOCHS_TOTAL epochs overall.
NUM_EPOCHS_INIT_LR = 3
NUM_EPOCHS_TOTAL = 8
INITIAL_LR = 5.0
LR_DECAY_RATE = 0.75

# Value fed to the 'max_norm' placeholder of the train op
# (presumably the gradient-clipping threshold -- confirm in utils.unrolled_rnn).
MAX_NORM = 0.5

In [6]:
# Reset the global default graph before building, so this cell can be re-run
# in the same kernel.
tf.reset_default_graph()

# Feed points; the keys used in this notebook are 'inputs', 'targets',
# 'learning_rate' and 'max_norm'.
placeholders = make_placeholders(num_steps=NUM_STEPS, batch_size=BATCH_SIZE)

# Trainable variables of the RNN.
rnn_vars = make_rnn_variables(
    initializer_scale=0.2,
    hidden_size=HIDDEN_SIZE,
    embedding_size=EMBEDDING_SIZE,
    vocab_size=vocab_size,
)

# Forward pass over the input sequence; 'logits' is consumed below.
rnn_outputs = make_rnn_outputs(
    rnn_variables=rnn_vars,
    input_sequence=placeholders['inputs'],
    num_steps=NUM_STEPS,
    batch_size=BATCH_SIZE,
    hidden_size=HIDDEN_SIZE,
    vocab_size=vocab_size,
)

# Metric nodes; the training cell reads 'loss' and 'num_correct_predictions'.
summary_nodes = make_summary_nodes(
    logits=rnn_outputs['logits'],
    targets=placeholders['targets'],
)

# Train-op nodes built from the loss, the learning-rate placeholder and the
# max-norm placeholder; the training cell reads 'gradient_global_norm'.
training_nodes = make_train_op(
    summary_nodes['loss'], placeholders['learning_rate'], placeholders['max_norm']
)

In [7]:
# Train the model and report perplexity after every epoch.
#
# Per epoch:
#   1. Apply the learning-rate schedule (constant for the first
#      NUM_EPOCHS_INIT_LR epochs, then multiplied by LR_DECAY_RATE each epoch).
#   2. Run one pass over X_train, fetching both the metric nodes and the
#      train-op nodes, and print progress periodically.
#   3. Run one pass over X_val and one over X_test fetching only the metric
#      nodes, and print exp(mean per-batch loss) as the perplexity.
#
# NOTE(review): test perplexity is printed every epoch for monitoring; it
# should not be used to choose hyperparameters or the stopping epoch.

# Everything fetched on a training step: metrics plus the train-op nodes.
training_outputs = {**summary_nodes, **training_nodes}

# Print training progress once every this many batches.
LOG_EVERY_N_BATCHES = 64

with tf.Session() as sess:

    # Bookkeeping
    run_id = time.time()
    writer = tf.summary.FileWriter('logs/{0}'.format(run_id), sess.graph)
    coord = tf.train.Coordinator()
    # Stays empty if start-up fails, so the cleanup in `finally` is still valid.
    threads = []

    try:
        sess.run(tf.global_variables_initializer())
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        learning_rate = INITIAL_LR
        max_norm = MAX_NORM
        for i in range(NUM_EPOCHS_TOTAL):
            if i >= NUM_EPOCHS_INIT_LR:
                learning_rate *= LR_DECAY_RATE

            # --- Training pass ---
            for batch_idx, (inputs, targets) in enumerate(generate_epoch(X_train, BATCH_SIZE, NUM_STEPS)):
                outputs = sess.run(
                    training_outputs,
                    feed_dict={
                        placeholders['inputs']: inputs,
                        placeholders['targets']: targets,
                        placeholders['learning_rate']: learning_rate,
                        placeholders['max_norm']: max_norm,
                    }
                )
                if (batch_idx + 1) % LOG_EVERY_N_BATCHES == 0:
                    print('step: {0}    loss: {1}    gradient norm: {2}     correct words: {3}'.format(
                        batch_idx+1,
                        outputs['loss'],
                        outputs['gradient_global_norm'],
                        outputs['num_correct_predictions'],
                    ))

            # --- Evaluation passes (no train op is fetched, so no weights change) ---
            for split_name, split_data in (('validation', X_val), ('test', X_test)):
                total_loss, total_batches = 0, 0
                for inputs, targets in generate_epoch(split_data, BATCH_SIZE, NUM_STEPS):
                    outputs = sess.run(
                        summary_nodes,
                        feed_dict={
                            placeholders['inputs']: inputs,
                            placeholders['targets']: targets
                        },
                    )
                    total_loss += outputs['loss']
                    total_batches += 1
                # Perplexity = exp of the mean per-batch loss
                # (assumes 'loss' is a mean cross-entropy in nats -- TODO confirm).
                print('{0} perplexity:'.format(split_name), np.exp(total_loss / total_batches))

    finally:
        # Bookkeeping: always release the event writer and stop/join any
        # queue-runner threads, even if training raised or was interrupted.
        writer.close()
        coord.request_stop()
        coord.join(threads)

step: 64    loss: 7.442488670349121    gradient norm: 0.6903454065322876     correct words: 38
step: 128    loss: 8.22467041015625    gradient norm: 3.464698314666748     correct words: 47
step: 192    loss: 7.208061695098877    gradient norm: 2.0766000747680664     correct words: 49
step: 256    loss: 6.89780855178833    gradient norm: 0.7592272758483887     correct words: 36
step: 320    loss: 6.7141642570495605    gradient norm: 0.8033626675605774     correct words: 30
step: 384    loss: 6.479608535766602    gradient norm: 0.5325131416320801     correct words: 46
step: 448    loss: 7.2704291343688965    gradient norm: 2.3638126850128174     correct words: 34
step: 512    loss: 6.724634170532227    gradient norm: 1.2482047080993652     correct words: 53
step: 576    loss: 6.448917865753174    gradient norm: 1.2443323135375977     correct words: 66
step: 640    loss: 6.110818862915039    gradient norm: 0.7929134368896484     correct words: 59
step: 704    loss: 6.256288051605225    gr

validation perplexity: 241.55689653
test perplexity: 236.897568478
step: 64    loss: 5.1805500984191895    gradient norm: 0.49780505895614624     correct words: 106
step: 128    loss: 5.205382347106934    gradient norm: 0.4922258257865906     correct words: 92
step: 192    loss: 5.188063621520996    gradient norm: 0.7560275197029114     correct words: 99
step: 256    loss: 5.139129161834717    gradient norm: 0.569052517414093     correct words: 93
step: 320    loss: 4.870426654815674    gradient norm: 0.57475346326828     correct words: 112
step: 384    loss: 4.969931602478027    gradient norm: 0.5993033051490784     correct words: 104
step: 448    loss: 5.3567914962768555    gradient norm: 0.9596911668777466     correct words: 89
step: 512    loss: 5.09026575088501    gradient norm: 0.5136498212814331     correct words: 101
step: 576    loss: 5.0604681968688965    gradient norm: 0.6595025658607483     correct words: 98
step: 640    loss: 4.831742286682129    gradient norm: 0.719275295

step: 1792    loss: 4.583516597747803    gradient norm: 0.608253002166748     correct words: 113
validation perplexity: 149.82412332
test perplexity: 142.552388792
step: 64    loss: 4.627676486968994    gradient norm: 0.5337792634963989     correct words: 126
step: 128    loss: 4.83149528503418    gradient norm: 0.6529021263122559     correct words: 119
step: 192    loss: 4.6521220207214355    gradient norm: 0.5601725578308105     correct words: 121
step: 256    loss: 4.542318820953369    gradient norm: 0.5724175572395325     correct words: 123
step: 320    loss: 4.438754558563232    gradient norm: 0.5586279630661011     correct words: 140
step: 384    loss: 4.4756855964660645    gradient norm: 0.5435066819190979     correct words: 133
step: 448    loss: 4.813286304473877    gradient norm: 0.6050848364830017     correct words: 105
step: 512    loss: 4.719724655151367    gradient norm: 0.5664400458335876     correct words: 115
step: 576    loss: 4.490606784820557    gradient norm: 0.552