In [1]:
import tensorflow as tf
from model.seq2seq import basic_seq2seq
import numpy as np
from utils.text_processing import load_dict_from_vocab_file

vocab_file = './data/character_inventory_unk.txt'
traindb_file = './data/training.npz'
testdb_file = './data/testing.npz'
checkpoint_file = './tfmodel/gru_enc/model_%d.tfmodel'
log_dir = './tb'
log_interval = 10

  from ._conv import register_converters as _register_converters


In [2]:
lr = 0.001
l2reg = 0.01
keep_prob=1.0
batch_size_val = 64
vocab = load_dict_from_vocab_file(vocab_file)
vocab_size = len(vocab)
lstm_dim = 500
n_epochs = 100

In [3]:
with tf.name_scope("placeholders"):
    encoder_in = tf.placeholder(tf.int32, [None, None])
    encoder_lens = tf.placeholder(tf.int32, [None])
    batch_size = tf.placeholder(tf.int32)
    
    decoder_in = tf.placeholder(tf.int32, [None, None])
    decoder_lens = tf.placeholder(tf.int32, [None])
    labels = tf.placeholder(tf.int32, [None, None])

In [4]:
with tf.name_scope("model"):
    logits, _ = basic_seq2seq(encoder_in, encoder_lens, decoder_in, decoder_lens,
                                          vocab_size=vocab_size, batch_size=batch_size, lstm_type="gru",
                                          lstm_dim=lstm_dim, keep_prob=keep_prob, max_iterations=101)

In [5]:
with tf.name_scope("loss"):
    #labels_flat = tf.reshape(labels, [-1])
    #logits = tf.reshape(logits, [-1, vocab_size])
    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    
    train_loss = (tf.reduce_sum(crossent) / tf.cast(batch_size, tf.float32))
    
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    reg_loss = l2reg * sum(reg_losses)
    tv = tf.trainable_variables()
    reg_losses.extend([l2reg * tf.nn.l2_loss(v) for v in tv])

    loss = train_loss + reg_loss
    
    with tf.name_scope("logging"):
        tf.summary.scalar("train_loss", train_loss)

In [6]:
with tf.name_scope("optimizer"):
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(lr)
        global_step = tf.Variable(0, trainable=False)
        #learning_rate = tf.train.exponential_decay(lr, global_step,
        #                                   10000, 0.96, staircase=True)
        #optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        gvs = optimizer.compute_gradients(loss)
        capped_gvs = [(tf.clip_by_norm(grad, 5.), var) for grad, var in gvs]
        train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step)
    
        with tf.name_scope("logging"):
            for grad, var in capped_gvs:
                tf.summary.histogram(var.name + "_grads", grad)

INFO:tensorflow:Summary name gru/rnn/gru_cell/gates/kernel:0_grads is illegal; using gru/rnn/gru_cell/gates/kernel_0_grads instead.
INFO:tensorflow:Summary name gru/rnn/gru_cell/gates/bias:0_grads is illegal; using gru/rnn/gru_cell/gates/bias_0_grads instead.
INFO:tensorflow:Summary name gru/rnn/gru_cell/candidate/kernel:0_grads is illegal; using gru/rnn/gru_cell/candidate/kernel_0_grads instead.
INFO:tensorflow:Summary name gru/rnn/gru_cell/candidate/bias:0_grads is illegal; using gru/rnn/gru_cell/candidate/bias_0_grads instead.
INFO:tensorflow:Summary name decoder/conditional_gru_cell/gates/kernel:0_grads is illegal; using decoder/conditional_gru_cell/gates/kernel_0_grads instead.
INFO:tensorflow:Summary name decoder/conditional_gru_cell/candidate/kernel:0_grads is illegal; using decoder/conditional_gru_cell/candidate/kernel_0_grads instead.
INFO:tensorflow:Summary name decoder/dense/kernel:0_grads is illegal; using decoder/dense/kernel_0_grads instead.


In [7]:
with tf.name_scope("logging"):
    valid_loss_ph = tf.placeholder(tf.float32, name="validation_loss")
    
    tf.summary.scalar("Valid_loss", valid_loss_ph)

    #Add histograms for trainable variables
    for v in tf.trainable_variables():
        tf.summary.histogram(v.name, v)    

    log_op = tf.summary.merge_all()

writer = tf.summary.FileWriter(log_dir, graph=tf.get_default_graph())

INFO:tensorflow:Summary name gru/rnn/gru_cell/gates/kernel:0 is illegal; using gru/rnn/gru_cell/gates/kernel_0 instead.
INFO:tensorflow:Summary name gru/rnn/gru_cell/gates/bias:0 is illegal; using gru/rnn/gru_cell/gates/bias_0 instead.
INFO:tensorflow:Summary name gru/rnn/gru_cell/candidate/kernel:0 is illegal; using gru/rnn/gru_cell/candidate/kernel_0 instead.
INFO:tensorflow:Summary name gru/rnn/gru_cell/candidate/bias:0 is illegal; using gru/rnn/gru_cell/candidate/bias_0 instead.
INFO:tensorflow:Summary name decoder/conditional_gru_cell/gates/kernel:0 is illegal; using decoder/conditional_gru_cell/gates/kernel_0 instead.
INFO:tensorflow:Summary name decoder/conditional_gru_cell/candidate/kernel:0 is illegal; using decoder/conditional_gru_cell/candidate/kernel_0 instead.
INFO:tensorflow:Summary name decoder/dense/kernel:0 is illegal; using decoder/dense/kernel_0 instead.


In [None]:
saver = tf.train.Saver()

In [None]:
data = np.load(traindb_file)
encoder_in_batch = data['enc_in']
encoder_len_batch = data['enc_lens']
decoder_in_batch = data['dec_in']
decoder_len_batch = data['dec_lens']
labels_batch = data['labels']

valid_data = np.load(testdb_file)
valid_encoder_in_batch = data['enc_in']
valid_encoder_len_batch = data['enc_lens']
valid_decoder_in_batch = data['dec_in']
valid_decoder_len_batch = data['dec_lens']
valid_labels_batch = data['labels']

n_examples = labels_batch.shape[0]
idx = np.arange(n_examples)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
i = 0
k = 0
min_loss_val = 10000.0
while i < n_epochs:
    print("EPOCH %d"%i)
    j = 0
    np.random.shuffle(idx)
    while j < n_examples:
        curr = idx[j:j+batch_size_val]
        batch_size_curr = len(curr)
        if k % log_interval == 0:
            #Get the validation loss
            valid_loss_val, logits_val = sess.run([loss, logits], feed_dict={encoder_in: valid_encoder_in_batch,
                                                       encoder_lens: valid_encoder_len_batch,
                                                       decoder_in: valid_decoder_in_batch,
                                                       decoder_lens: valid_decoder_len_batch,
                                                       labels: valid_labels_batch,
                                                       batch_size: valid_labels_batch.shape[0]})
            
            
            summary, _, loss_val, logits_val = sess.run([log_op, train_op, loss, logits], feed_dict={encoder_in: encoder_in_batch[curr],
                                                       encoder_lens: encoder_len_batch[curr],
                                                       decoder_in: decoder_in_batch[curr],
                                                       decoder_lens: decoder_len_batch[curr],
                                                       labels: labels_batch[curr],
                                                       batch_size: batch_size_curr,
                                                       valid_loss_ph: valid_loss_val})
            writer.add_summary(summary, k)
            
            if valid_loss_val < min_loss_val:
                print(valid_loss_val)
                min_loss_val = valid_loss_val
                saver.save(sess, checkpoint_file % k)
        else:
            _, loss_val, logits_val = sess.run([train_op, loss, logits], feed_dict={encoder_in: encoder_in_batch[curr],
                                                       encoder_lens: encoder_len_batch[curr],
                                                       decoder_in: decoder_in_batch[curr],
                                                       decoder_lens: decoder_len_batch[curr],
                                                       labels: labels_batch[curr],
                                                       batch_size: batch_size_curr})
        
        
        j += batch_size_val
        k += 1
    i += 1

EPOCH 0
180.83554
180.83554
160.38715
154.25702
143.35489
143.82669
141.94229
141.88965
141.0227
140.97894
139.69888
139.78992
139.78992
139.17722
138.33409
137.90045
138.29399
138.1435
137.08011
138.10593
137.20209
136.37599
136.94136
136.94136
136.56812
135.32007
135.03564
134.87335
134.12837
135.85985
134.54898
133.83307
134.48972
132.49522
132.49522
134.0584
131.14745
132.20755
130.24994
140.44724
133.83612
131.67667
131.04099
130.73166
129.96307
129.96307
129.0737
129.75046
128.50105
126.766815
126.82614
125.187645
127.74845
123.88202
123.72064
121.882996
121.882996
123.44857
120.74601
122.73606
122.51065
121.35245
120.239075
121.3218
117.62918
124.10922
122.41046
121.56034
121.265
119.71651
128.87933
141.30072
123.04866
123.30194
121.666
121.381134
120.52517
120.52517
118.51246
118.76878
113.76194
116.74526
115.560425
111.78072
112.57384
113.751114
110.37503
107.556595
107.556595
108.20497
108.739334
107.64681
110.1952
107.21032
106.74629
105.166916
104.4899
104.47644
102.87871
1

21.041264
22.358234
17.250313
20.205189
17.844234
16.835155
20.744616
22.449299
19.39373
20.635418
17.464256
22.196556
19.13887
