In [1]:
#Seq2Seq

In [2]:
import numpy as np        #For Mathematical Operations
import tensorflow as tf   #For ML
from tensorflow.contrib.rnn import LSTMStateTuple
import os #For fetching from directory

In [3]:
#Helper class to generate random batch of different sequence lengths
class Helper(object):
    
    def __init__(self, batch_size):
        self.batch_size = batch_size
        
    def generate_batch(self):
        batch=[]
        for _ in range(self.batch_size):
            size = np.random.randint(low=10,high=20)
            batch.append(np.random.randint(low=0,high=10,size=size))
        max_len = np.max([len(seq) for seq in batch ])
        return batch,max_len

In [4]:
#Reset graph
tf.reset_default_graph()

In [5]:
#Set Session
sess = tf.InteractiveSession()

In [6]:
#Constants
PAD = 0  #Padding at the end of each sequence
EOS = 1  #Token indicating end of sequence
n_batches = 3000 #Number of batches in epoch
batch_size= 50 #Batch_size
enc_vocab_size = 10 # vocab size for encoder inputs
dec_vocab_size = enc_vocab_size*2 - 1
embed_size = 20 #embedding size
encoder_hidden_units = 20 #Number of encoder hidden units
decoder_hidden_units = encoder_hidden_units*2 #Number of decoder hidden units

In [7]:
#Define placeholders
with tf.variable_scope('placeholders'):
    encoder_inputs = tf.placeholder(shape=(batch_size,None),dtype=tf.int32,
                                   name="encoder_inputs")
    decoder_inputs = tf.placeholder(shape=(batch_size,None),dtype=tf.int32,
                                   name="decoder_inputs")
    decoder_targets = tf.placeholder(shape=(batch_size,None),dtype=tf.int32,
                                    name='decoder_targets')

In [8]:
#Define embeddings
with tf.name_scope('embeddings'):
    enc_embed_matrix = tf.Variable(tf.random_uniform((enc_vocab_size,embed_size),-1,1),
                               dtype=tf.float32,name="enc_embed_matrix")
    dec_embed_matrix = tf.Variable(tf.random_uniform((dec_vocab_size,embed_size),-1,1),
                               dtype=tf.float32,name="dec_embed_matrix")
    encoder_embeddings = tf.nn.embedding_lookup(enc_embed_matrix,encoder_inputs)
    decoder_embeddings = tf.nn.embedding_lookup(dec_embed_matrix,decoder_inputs)

In [9]:
#Define encoder
with tf.variable_scope('encoder'):
    encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_hidden_units)
    encoder_initial_state = encoder_cell.zero_state(batch_size,tf.float32)
    ((encoder_outputs_fw,encoder_outputs_bw),
     (encoder_states_fw,encoder_states_bw)) = tf.nn.bidirectional_dynamic_rnn(
                                                cell_fw=encoder_cell,
                                                cell_bw=encoder_cell,
                                                inputs=encoder_embeddings,
                                                initial_state_fw=encoder_initial_state,
                                                initial_state_bw=encoder_initial_state)

In [10]:
encoder_final_states_c = tf.concat([encoder_states_fw.c,encoder_states_bw.c],1)
encoder_final_states_h = tf.concat([encoder_states_fw.h,encoder_states_bw.h],1)
encoder_final_states = LSTMStateTuple(c=encoder_final_states_c,
                                     h=encoder_final_states_h)

In [11]:
encoder_outputs = tf.concat([encoder_outputs_fw,encoder_outputs_bw],2)

In [12]:
#Define decoder
with tf.variable_scope('decoder'):
    decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(decoder_hidden_units)
    decoder_initial_state = encoder_final_states
    ((decoder_outputs_fw,decoder_outputs_bw),
    (decoder_states_fw,decoder_states_bw)) = tf.nn.bidirectional_dynamic_rnn(
                                            cell_fw=decoder_cell,
                                            cell_bw=decoder_cell,
                                            inputs=decoder_embeddings,
                                            initial_state_fw=decoder_initial_state,
                                            initial_state_bw=decoder_initial_state)

In [13]:
decoder_outputs = tf.concat([decoder_outputs_fw,decoder_outputs_bw],2)

In [14]:
decoder_logits = tf.contrib.layers.fully_connected(decoder_outputs,dec_vocab_size)

In [15]:
entropy = tf.nn.softmax_cross_entropy_with_logits(logits=decoder_logits,
                                                 labels=tf.cast(tf.one_hot(decoder_targets,dec_vocab_size),
                                                               tf.float32))

In [16]:
decoder_prediction = tf.argmax(decoder_logits,2)

In [17]:
loss = tf.reduce_mean(entropy)

In [18]:
with tf.name_scope('summaries'):
    tf.summary.scalar('loss',loss)
    tf.summary.histogram('loss', loss)
    summary_op = tf.summary.merge_all()

In [19]:
optimizer = tf.train.AdamOptimizer().minimize(loss)

In [20]:
helper = Helper(batch_size)

In [21]:
def next():
    batch,max_len = helper.generate_batch()
    encoder_inputs_ = [np.append(np.append(seq,[EOS]),[PAD]*(max_len-len(seq))) for seq in batch]
    decoder_inputs_ = [np.append(np.append([EOS],seq*2),[PAD]*(max_len-len(seq))) for seq in batch]
    decoder_targets_ = [np.append(np.append(seq*2,[EOS]),[PAD]*(max_len-len(seq))) for seq in batch]
    return {encoder_inputs:encoder_inputs_,
           decoder_inputs:decoder_inputs_,
           decoder_targets:decoder_targets_}

In [22]:
sess.run(tf.global_variables_initializer())
loss_sum = 0.0
batches_in_epoch = 1000 
saver = tf.train.Saver()
#For tensorboard visualizations
writer = tf.summary.FileWriter('/graphs/seq2seq2', sess.graph)
#Check if checkpoint present 
ckpt = tf.train.get_checkpoint_state(os.path.dirname('/checkpoints/seq2seq2/checkpoint'))
#Restore the latest checkpoint if present
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
try:
    for batch in range(n_batches):
        fd = next()
        _, loss_val,summary = sess.run([optimizer, loss,summary_op], fd)
        loss_sum += loss_val

        if batch == 0 or batch % batches_in_epoch == 0:
            sess.run(tf.shape(decoder_outputs),fd)
            print('batch {}'.format(batch))
            print('  minibatch loss: {}'.format(loss_val))
            predict_ = sess.run(decoder_prediction, fd)
            for i, (inp,dec, pred) in enumerate(zip(fd[encoder_inputs],fd[decoder_inputs], predict_)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(inp))
                print('    decoder input  > {}'.format(dec))
                print('    predicted > {}'.format(pred))
                if i >= 2:
                    break
                print()
            saver.save(sess, '/checkpoints/seq2seq2/seq2seq2', batch)
               
except KeyboardInterrupt:
    print('training interrupted')

batch 0
  minibatch loss: 2.9513587951660156
  sample 1:
    input     > [2 3 4 6 1 0 8 3 4 6 6 2 0 4 1 6 2 1 0 0]
    decoder input  > [ 1  4  6  8 12  2  0 16  6  8 12 12  4  0  8  2 12  4  0  0]
    predicted > [18 15 15 15 12 15 15 12 15 12 12 12 12 12 12 12 12 12 12  7]

  sample 2:
    input     > [0 9 5 6 9 1 7 4 3 4 2 3 1 0 0 0 0 0 0 0]
    decoder input  > [ 1  0 18 10 12 18  2 14  8  6  8  4  6  0  0  0  0  0  0  0]
    predicted > [18 18  0  0 12 10 10 10 12 13 13 14 15 15 15 15  0  0  0  0]

  sample 3:
    input     > [3 8 2 6 5 7 7 0 1 8 4 6 4 3 4 1 0 0 0 0]
    decoder input  > [ 1  6 16  4 12 10 14 14  0  2 16  8 12  8  6  8  0  0  0  0]
    predicted > [18  5  5 18  0  0 10 10 15 15  3 12 12 12 12 12 14 14 14 10]
batch 1000
  minibatch loss: 0.01646534726023674
  sample 1:
    input     > [ 4.  1.  3.  8.  5.  0.  4.  4.  9.  6.  7.  8.  7.  8.  9.  3.  3.  2.
  2.  1.]
    decoder input  > [  1.   8.   2.   6.  16.  10.   0.   8.   8.  18.  12.  14.  16.  14.  16.
  1