# Train the phylogenetic embeddings (seq2seq)

Import libraries and set up global constants (file paths and logging settings).

In [None]:
import tensorflow as tf
from model.seq2seq import basic_seq2seq
import numpy as np
from utils.text_processing import load_dict_from_vocab_file

# Input data: the vocabulary file plus the training and testing .npz archives.
vocab_file = "./data/character_inventory_unk.txt"
traindb_file = "./data/training.npz"
testdb_file = "./data/testing.npz"

# Checkpoint path template; the %d slot receives an integer step counter.
checkpoint_file = "./tfmodel/gru_enc/model_%d.tfmodel"

# TensorBoard output directory and the logging period, in training steps.
log_dir = "./tb"
log_interval = 10

Define model and training constants

In [None]:
# Optimisation settings.
lr = 0.001              # Adam learning rate
l2reg = 0.01            # weight of the L2 penalty in the loss
keep_prob = 1.0         # forwarded to basic_seq2seq (presumably a dropout keep probability -- confirm)
batch_size_val = 64     # examples per training minibatch
n_epochs = 100

# Model size. The vocabulary size is passed to basic_seq2seq.
vocab = load_dict_from_vocab_file(vocab_file)
vocab_size = len(vocab)
lstm_dim = 500          # recurrent state size passed as lstm_dim

Define placeholders (model inputs)

In [None]:
with tf.name_scope("placeholders"):
    # Encoder inputs: a 2-D int32 batch (assumed [batch, time] -- confirm)
    # and one length per sequence.
    encoder_in = tf.placeholder(dtype=tf.int32, shape=[None, None])
    encoder_lens = tf.placeholder(dtype=tf.int32, shape=[None])
    # Number of examples in the current feed; used to normalise the loss.
    batch_size = tf.placeholder(dtype=tf.int32)

    # Decoder inputs, their lengths, and the integer targets scored by the loss.
    decoder_in = tf.placeholder(dtype=tf.int32, shape=[None, None])
    decoder_lens = tf.placeholder(dtype=tf.int32, shape=[None])
    labels = tf.placeholder(dtype=tf.int32, shape=[None, None])

Build the seq2seq model and get its output logits.

In [None]:
with tf.name_scope("model"):
    # GRU-based encoder/decoder. Only the logits are kept; the second value
    # returned by basic_seq2seq is discarded.
    logits, _ = basic_seq2seq(
        encoder_in,
        encoder_lens,
        decoder_in,
        decoder_lens,
        vocab_size=vocab_size,
        batch_size=batch_size,
        lstm_type="gru",
        lstm_dim=lstm_dim,
        keep_prob=keep_prob,
        max_iterations=101,
    )

Define the loss (cross-entropy plus L2 regularization).

In [None]:
with tf.name_scope("loss"):
    # Per-position cross-entropy of the decoder logits against `labels`.
    # NOTE(review): no padding mask is applied, so every position in `labels`
    # contributes. If targets are padded, consider weighting `crossent` with
    # tf.sequence_mask(decoder_lens, tf.shape(labels)[1], dtype=tf.float32);
    # first confirm that basic_seq2seq's logits length matches `labels`.
    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)

    # Summed over every position, then averaged over the examples in the batch.
    train_loss = tf.reduce_sum(crossent) / tf.cast(batch_size, tf.float32)

    # L2 term = l2reg * (regularizers registered in the graph collection
    #                    + tf.nn.l2_loss(w) for every trainable variable).
    # Previously the per-variable terms were appended to `reg_losses` only
    # *after* `reg_loss` had been summed, so they never reached `loss`.
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    tv = tf.trainable_variables()
    reg_losses = reg_losses + [tf.nn.l2_loss(v) for v in tv]
    if reg_losses:
        reg_loss = l2reg * tf.add_n(reg_losses)
    else:
        reg_loss = tf.constant(0.0)

    loss = train_loss + reg_loss

    with tf.name_scope("logging"):
        tf.summary.scalar("train_loss", train_loss)
        tf.summary.scalar("reg_loss", reg_loss)

Define the optimizer (Adam, gradient clipping)

In [None]:
with tf.name_scope("optimizer"):
    # Run any ops registered in UPDATE_OPS before the training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(lr)
        global_step = tf.Variable(0, trainable=False)
        gvs = optimizer.compute_gradients(loss)
        # compute_gradients yields a None gradient for any variable that does
        # not affect `loss`; tf.clip_by_norm and tf.summary.histogram both fail
        # on None, so drop those pairs. Each remaining gradient is clipped
        # to L2 norm 5 independently (per variable, not globally).
        capped_gvs = [(tf.clip_by_norm(grad, 5.), var)
                      for grad, var in gvs if grad is not None]
        train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step)

        with tf.name_scope("logging"):
            for grad, var in capped_gvs:
                # var.op.name omits the ':0' output suffix; ':' is not a legal
                # summary-name character.
                tf.summary.histogram(var.op.name + "_grads", grad)

Finish up with some logging hooks

In [None]:
with tf.name_scope("logging"):
    # The validation loss is computed by a separate sess.run and fed back in
    # through this placeholder so that it can be recorded as a summary.
    valid_loss_ph = tf.placeholder(tf.float32, name="validation_loss")
    tf.summary.scalar("Valid_loss", valid_loss_ph)

    # Histogram of each trainable variable's values. v.op.name omits the
    # ':0' output suffix; ':' is not a legal summary-name character.
    for v in tf.trainable_variables():
        tf.summary.histogram(v.op.name, v)

    # One op that evaluates every summary defined so far. It depends on
    # valid_loss_ph, so that placeholder must be fed whenever log_op is run.
    log_op = tf.summary.merge_all()

writer = tf.summary.FileWriter(log_dir, graph=tf.get_default_graph())

In [None]:
# Checkpoint saver over the graph's saveable variables (default var_list).
# Created after the optimizer so that Adam's slot variables are included.
# NOTE(review): older checkpoints are pruned per Saver's default max_to_keep
# (5 in TF1 -- confirm for the TF version in use).
saver = tf.train.Saver()

Main training loop. Every `log_interval` steps, we compute the loss on the held-out test set and save the weights whenever that validation loss reaches a new minimum.

In [None]:
import os

# Load the training split. np.load returns a lazily-read NpzFile that keeps
# the archive open; the `with` block closes it once every array has been read.
with np.load(traindb_file) as data:
    encoder_in_batch = data['enc_in']
    encoder_len_batch = data['enc_lens']
    decoder_in_batch = data['dec_in']
    decoder_len_batch = data['dec_lens']
    labels_batch = data['labels']

# Load the held-out split. (Bug fix: these used to be read from `data`, the
# training archive, so the "validation" loss was really a training loss.)
with np.load(testdb_file) as valid_data:
    valid_encoder_in_batch = valid_data['enc_in']
    valid_encoder_len_batch = valid_data['enc_lens']
    valid_decoder_in_batch = valid_data['dec_in']
    valid_decoder_len_batch = valid_data['dec_lens']
    valid_labels_batch = valid_data['labels']

n_examples = labels_batch.shape[0]
idx = np.arange(n_examples)

# The whole test set is scored in a single feed. It never changes, so build
# the feed dict once rather than on every logging step.
valid_feed = {encoder_in: valid_encoder_in_batch,
              encoder_lens: valid_encoder_len_batch,
              decoder_in: valid_decoder_in_batch,
              decoder_lens: valid_decoder_len_batch,
              labels: valid_labels_batch,
              batch_size: valid_labels_batch.shape[0]}

# tf.train.Saver.save refuses to write into a directory that does not exist.
checkpoint_dir = os.path.dirname(checkpoint_file)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
k = 0                   # step counter across epochs (summary step / checkpoint suffix)
min_loss_val = np.inf   # best validation loss so far; the first evaluation always saves
for i in range(n_epochs):
    print("EPOCH %d" % i)
    np.random.shuffle(idx)
    for j in range(0, n_examples, batch_size_val):
        curr = idx[j:j + batch_size_val]
        batch_size_curr = len(curr)  # the last batch of an epoch may be short
        train_feed = {encoder_in: encoder_in_batch[curr],
                      encoder_lens: encoder_len_batch[curr],
                      decoder_in: decoder_in_batch[curr],
                      decoder_lens: decoder_len_batch[curr],
                      labels: labels_batch[curr],
                      batch_size: batch_size_curr}
        if k % log_interval == 0:
            # Score the held-out set, then take a training step that also
            # evaluates all summaries (log_op needs valid_loss_ph fed).
            valid_loss_val = sess.run(loss, feed_dict=valid_feed)
            train_feed[valid_loss_ph] = valid_loss_val
            summary, _, loss_val = sess.run([log_op, train_op, loss],
                                            feed_dict=train_feed)
            writer.add_summary(summary, k)

            # Keep the weights whenever the validation loss reaches a new minimum.
            if valid_loss_val < min_loss_val:
                print(valid_loss_val)
                min_loss_val = valid_loss_val
                saver.save(sess, checkpoint_file % k)
        else:
            _, loss_val = sess.run([train_op, loss], feed_dict=train_feed)
        k += 1

writer.flush()