In [1]:
import tensorflow as tf
import numpy as np
import helpers

# Start from an empty default graph so re-running this cell does not pile up
# duplicate ops/variables, then open an InteractiveSession, which installs
# itself as the default session used by the rest of the notebook.
tf.reset_default_graph()
sess = tf.InteractiveSession()

  from ._conv import register_converters as _register_converters


In [2]:
# Reserved token ids: 0 pads short sequences, 1 marks end-of-sequence.
PAD, EOS = 0, 1

# Model dimensions.
vocab_size = 10            # rows in the embedding table (token ids 0..9)
input_embedding_size = 20

encoder_hidden_units = 20
# The decoder starts from the encoder's final LSTM state, so both cells
# must have the same number of units.
decoder_hidden_units = encoder_hidden_units

In [3]:
# Placeholders. Each is an int32 matrix of token ids laid out time-major,
# i.e. [max_time, batch_size], because the encoder and decoder dynamic_rnn
# calls use time_major=True. Both dimensions are None so every batch may
# differ in length and size. Graph names match the Python names.
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
decoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_inputs')

In [4]:
# Embedding table shared by the encoder and the decoder: row i holds the
# input_embedding_size-dimensional vector for token id i, initialised
# uniformly between -1 and 1.
embeddings = tf.Variable(
    tf.random_uniform(
        (vocab_size, input_embedding_size),
        minval=-1,
        maxval=1,
        dtype=tf.float32,
    )
)

In [5]:
# Replace every token id with its embedding row; the results are
# [max_time, batch_size, input_embedding_size].
encoder_inputs_embedded = tf.nn.embedding_lookup(params=embeddings, ids=encoder_inputs)
decoder_inputs_embedded = tf.nn.embedding_lookup(params=embeddings, ids=decoder_inputs)

In [6]:
# Encoder: one LSTM layer run over the embedded source sequence. Only its
# final (c, h) state is kept, because that state seeds the decoder, so the
# per-step outputs are dropped immediately.
encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    cell=encoder_cell,
    inputs=encoder_inputs_embedded,
    dtype=tf.float32,  # needed because no initial_state is supplied
    time_major=True,
)
del encoder_outputs


In [7]:
encoder_final_state # LSTMStateTuple: c is the cell (memory) state, h is the hidden state

LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 20) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 20) dtype=float32>)

In [8]:
# Decoder: a second LSTM under its own variable scope ('plain_decoder').
# It starts from the encoder's final state and, during training, reads
# the embedded decoder_inputs (the targets shifted right by one step).
decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
    cell=decoder_cell,
    inputs=decoder_inputs_embedded,
    initial_state=encoder_final_state,
    dtype=tf.float32,
    time_major=True,
    scope='plain_decoder',
)

In [9]:
# Projection: a linear layer (no activation) maps each decoder output to
# vocab_size unnormalised scores, giving [max_time, batch_size, vocab_size].
# The greedy prediction is the arg-max over the last (vocabulary) axis.
decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_size)
decoder_prediction = tf.argmax(decoder_logits, axis=2)

In [10]:
# Inspect the logits tensor; its last dimension should equal vocab_size.
decoder_logits

<tf.Tensor 'fully_connected/BiasAdd:0' shape=(?, ?, 10) dtype=float32>

In [11]:
# Cross-entropy at every (time, batch) position between the one-hot targets
# and the logits. It is averaged into a single scalar and minimised with
# Adam at its default settings.
# NOTE(review): padded positions (PAD = 0) are included in the mean, so the
# model is also trained to predict PAD after EOS; mask them if unwanted.
stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),
    logits=decoder_logits,
)
loss = tf.reduce_mean(stepwise_cross_entropy)
train_op = tf.train.AdamOptimizer().minimize(loss)

In [12]:
# Initialise every variable created so far (embeddings, both LSTMs, the
# projection layer and the optimizer's own variables) before any training.
sess.run(tf.global_variables_initializer())

In [13]:
# Forward-pass smoke test on a tiny hand-made batch. helpers.batch pads the
# ragged lists with PAD and returns them time-major (one column per
# sequence) together with the original lengths.
batch = [[6], [3, 4], [9, 8, 7]]
batch, batch_length = helpers.batch(batch)

In [14]:
# Rows are time steps and columns are sequences; PAD (0) fills the gaps.
print(batch, 'batch length = ',batch_length)

[[6 3 9]
 [0 4 8]
 [0 0 7]] batch length =  [1, 2, 3]


In [15]:
# Decoder input for the smoke test: each of the 3 sequences is a single
# EOS (1) token, padded with PAD out to 4 time steps.
din, dlen = helpers.batch(np.ones((3, 1), dtype=np.int32), max_sequence_length=4)
print(din, 'dlen : ', dlen)

[[1 1 1]
 [0 0 0]
 [0 0 0]
 [0 0 0]] dlen :  [1, 1, 1]


In [16]:
# Greedy predictions from the still-untrained model. The weights are random,
# so the values are meaningless; only the [time, batch] shape matters here.
pred = sess.run(
    decoder_prediction,
    feed_dict={encoder_inputs: batch, decoder_inputs: din},
)
print('decoder predictions \n', str(pred))

decoder predictions 
 [[3 5 3]
 [3 5 6]
 [0 0 6]
 [0 5 5]]


In [17]:
# Training data: an iterator of random batches of batch_size sequences.
# Lengths run from length_from to length_to, and token ids from vocab_lower
# upward (presumably exclusive of vocab_upper, which must not exceed
# vocab_size). Ids 0 and 1 are reserved for PAD and EOS respectively.
batch_size = 100

batches = helpers.random_sequences(
    length_from=3,
    length_to=8,
    vocab_lower=2,
    vocab_upper=10,
    batch_size=batch_size,
)

# Peek at (and consume) one batch to sanity-check the generator.
print('head of the batch:')
for seq in next(batches)[:10]:
    print(seq)

head of the batch:
[7, 3, 6, 4, 8, 5, 2, 8]
[9, 6, 8, 5, 3, 2, 6]
[5, 3, 7, 3, 8, 7, 2]
[9, 5, 4, 6, 4, 7, 9]
[7, 4, 2, 8]
[9, 6, 9, 3, 7]
[8, 5, 6, 3]
[2, 9, 5, 3, 9]
[3, 6, 5, 7, 4]
[6, 7, 8, 9]


In [20]:
def next_feed():
    """Draw the next random batch and build a feed_dict for one training step.

    For an encoder input such as [5, 6, 7], the decoder is trained to echo
    it back:

    * decoder_targets = [5, 6, 7, EOS] -- reproduce the input, then stop;
    * decoder_inputs  = [EOS, 5, 6, 7] -- the targets lagged by one step,
      so each step receives the previous token as input.

    Every array comes from ``helpers.batch``: it is padded with PAD and
    laid out time-major (one column per sequence), matching the
    placeholders.
    """
    sequences = next(batches)
    encoder_batch, _ = helpers.batch(sequences)
    target_batch, _ = helpers.batch([seq + [EOS] for seq in sequences])
    decoder_batch, _ = helpers.batch([[EOS] + seq for seq in sequences])
    return {
        encoder_inputs: encoder_batch,
        decoder_inputs: decoder_batch,
        decoder_targets: target_batch,
    }

In [18]:
# History of per-step training losses, appended to by the training loop.
# It lives in its own cell so that re-running the loop extends this list
# instead of resetting it.
loss_track = []

In [21]:
# Train for max_batches steps and report progress every batches_in_epoch
# steps. Interrupting the kernel stops training cleanly and keeps the
# weights learned so far in `sess`.
max_batches = 3001
batches_in_epoch = 1000

try:
    for batch in range(max_batches):
        fd = next_feed()
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)

        if batch == 0 or batch % batches_in_epoch == 0:
            print('batch {}'.format(batch))
            # Re-evaluated after this step's update, unlike `l` above.
            print('  minibatch loss: {}'.format(sess.run(loss, fd)))
            predict_ = sess.run(decoder_prediction, fd)
            # Transpose the time-major arrays so each row is one sequence,
            # and show only the first three sequences of the batch.
            for i, (inp, pred) in enumerate(zip(fd[encoder_inputs].T, predict_.T)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(inp))
                print('    predicted > {}'.format(pred))
                if i >= 2:
                    break
            print()
except KeyboardInterrupt:
    print('training interrupted')

batch 0
  minibatch loss: 2.2818610668182373
  sample 1:
    input     > [8 3 6 3 5 4 0 0]
    predicted > [5 5 3 3 3 3 3 3 5]
  sample 2:
    input     > [9 3 3 7 8 7 7 0]
    predicted > [3 3 3 3 3 3 1 1 1]
  sample 3:
    input     > [6 4 2 6 7 6 0 0]
    predicted > [3 3 3 3 6 3 6 6 6]

batch 1000
  minibatch loss: 0.34204772114753723
  sample 1:
    input     > [9 9 9 7 5 3 6 2]
    predicted > [9 9 9 7 5 3 6 6 1]
  sample 2:
    input     > [2 5 6 2 4 4 0 0]
    predicted > [2 5 6 4 4 4 1 0 0]
  sample 3:
    input     > [3 4 3 4 8 2 8 4]
    predicted > [3 4 3 4 8 4 4 4 1]

batch 2000
  minibatch loss: 0.1506425142288208
  sample 1:
    input     > [9 7 3 7 5 7 0 0]
    predicted > [9 7 7 7 5 7 1 0 0]
  sample 2:
    input     > [9 7 8 3 5 3 0 0]
    predicted > [9 7 8 3 5 3 1 0 0]
  sample 3:
    input     > [7 4 2 5 2 6 5 0]
    predicted > [7 4 2 5 2 6 5 1 0]

batch 3000
  minibatch loss: 0.13755765557289124
  sample 1:
    input     > [3 8 8 8 2 0 0 0]
    predicted > [8 8 8