In [1]:
import os, sys
import tensorflow as tf
import numpy as np
from data import DataGenerator, process_batch

In [2]:
# Start from an empty default graph so re-running the notebook does not pile
# ops onto a graph left over from an earlier run, then open an interactive
# session that the later sess.run(...) calls use.
tf.reset_default_graph()
sess = tf.InteractiveSession()

In [3]:
# Record the TensorFlow version the notebook was executed with.
tf.__version__

'1.2.1'

In [4]:
# Build the data generator from a pickle file path (relative to the working directory).
# NOTE(review): presumably DataGenerator unpickles this file; unpickling can
# execute arbitrary code, so only point this at a trusted file.
pickle_file = 'content.pkl'
dataGen = DataGenerator(pickle_file)

In [5]:
# Reserved token ids. EOS is appended to the decoder targets and prepended to
# the decoder inputs in next_feed. PAD is not referenced in the visible cells;
# presumably process_batch pads shorter sequences with 0 -- TODO confirm.
PAD = 0
EOS = 1

# Toy setting kept for reference: vocab_size = 10
# NOTE(review): the +1 presumably reserves one id beyond dataGen.vocab_size --
# confirm against DataGenerator.
vocab_size = dataGen.vocab_size + 1
input_embedding_size = 16

# The decoder is initialised with the encoder's final state, so both LSTMs
# use the same number of hidden units.
encoder_hidden_units = 16
decoder_hidden_units = encoder_hidden_units

In [6]:
# int32 token-id placeholders with both dimensions left dynamic. The RNNs that
# consume them are built with time_major=True, so they are fed as
# [max_time, batch_size].
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
decoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_inputs')

In [7]:
# Inspect the placeholder: both dimensions are unknown at graph-construction time.
decoder_inputs

<tf.Tensor 'decoder_inputs:0' shape=(?, ?) dtype=int32>

In [8]:
# Trainable embedding matrix: one row of input_embedding_size floats per token
# id, initialised uniformly in [-1.0, 1.0).
embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -1.0, 1.0), dtype=tf.float32)

In [9]:
# Inspect the variable's static shape (vocab_size x input_embedding_size).
embeddings

<tf.Variable 'Variable:0' shape=(12579, 16) dtype=float32_ref>

In [10]:
# Replace each token id with its embedding row. Encoder and decoder inputs
# share the same embedding matrix.
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)

In [11]:
# Inspect the embedded inputs: the lookup adds a trailing embedding axis.
encoder_inputs_embedded

<tf.Tensor 'embedding_lookup:0' shape=(?, ?, 16) dtype=float32>

In [12]:
# Encoder: a single LSTM cell unrolled over the embedded inputs in time-major
# layout. No sequence_length is passed, so every time step of the fed tensor is
# processed. Only encoder_final_state is used afterwards (it seeds the
# decoder); encoder_outputs is not referenced in the visible cells.
encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)

encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, 
    encoder_inputs_embedded,
    dtype=tf.float32, 
    time_major=True,
)

In [13]:
# Inspect the final state: an LSTMStateTuple holding the cell state (c) and the
# hidden state (h).
encoder_final_state

LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_2:0' shape=(?, 16) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 16) dtype=float32>)

In [14]:
# Decoder: a second LSTM cell unrolled over the embedded decoder inputs,
# starting from the encoder's final state. It is built under its own scope
# ("plain_decoder"), separate from the encoder RNN's default scope.
decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)

decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
    decoder_cell, 
    decoder_inputs_embedded,
    initial_state=encoder_final_state,
    dtype=tf.float32, 
    time_major=True, 
    scope="plain_decoder",
)

In [15]:
# Inspect the decoder outputs: last axis is the decoder hidden size.
decoder_outputs

<tf.Tensor 'plain_decoder/TensorArrayStack/TensorArrayGatherV3:0' shape=(?, ?, 16) dtype=float32>

In [16]:
# Linear projection (no activation) of each decoder output to one logit per
# vocabulary entry.
decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_size)

In [17]:
# Inspect the logits: last axis is now vocab_size wide.
decoder_logits

<tf.Tensor 'fully_connected/BiasAdd:0' shape=(?, ?, 12579) dtype=float32>

In [18]:
# Greedy prediction: index of the largest logit along the vocabulary axis (axis 2).
decoder_prediction = tf.argmax(decoder_logits, 2)

In [19]:
# Inspect the targets placeholder before it is used in the loss.
decoder_targets

<tf.Tensor 'decoder_targets:0' shape=(?, ?) dtype=int32>

In [20]:
# Per-position cross-entropy between the decoder logits and the integer target
# ids. The sparse variant consumes the int32 ids directly, so the graph no
# longer builds a one-hot tensor with a trailing vocab_size axis for every
# position just to feed the dense op. For in-range ids the values are the same.
# Ids must lie in [0, vocab_size); the dense one-hot form silently produced a
# zero loss for out-of-range ids, the sparse op does not.
stepwise_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=decoder_targets,
    logits=decoder_logits,
)

# Mean over every position of the fed targets. No mask is applied, so padded
# positions count toward the average.
loss = tf.reduce_mean(stepwise_cross_entropy)
# Adam with its default hyperparameters.
train_op = tf.train.AdamOptimizer().minimize(loss)

In [21]:
# Initialise every variable created so far. This must run after the optimizer
# is built so that the variables it adds are included.
sess.run(tf.global_variables_initializer())

In [22]:
# Batch source for training; next_feed pulls one item from it with next().
# NOTE(review): presumably each item is a list of batch_size token-id lists --
# confirm against DataGenerator.generate_sequence.
batch_size = 200
batches = dataGen.generate_sequence(batch_size)

In [23]:
# Disabled earlier draft of next_feed, kept as a bare string literal rather
# than as code. Nothing inside it executes; the notebook only echoes the string
# as this cell's output. The draft returned the lengths of the three processed
# arrays instead of a feed dict.
'''
batch_size = 200
batches = dataGen.generate_sequence(batch_size)
def next_feed(batch):
    encoder_inputs_, _ = process_batch(batch)
    decoder_targets_, _ = process_batch([(sequence) + [EOS] for sequence in batch])
    decoder_inputs_, _ = process_batch([[EOS] + (sequence) for sequence in batch])
    return (len(encoder_inputs_), len(decoder_targets_), len(decoder_inputs_))
'''

'\nbatch_size = 200\nbatches = dataGen.generate_sequence(batch_size)\ndef next_feed(batch):\n    encoder_inputs_, _ = process_batch(batch)\n    decoder_targets_, _ = process_batch([(sequence) + [EOS] for sequence in batch])\n    decoder_inputs_, _ = process_batch([[EOS] + (sequence) for sequence in batch])\n    return (len(encoder_inputs_), len(decoder_targets_), len(decoder_inputs_))\n'

In [24]:
def next_feed():
    """Assemble the feed dict for one training step.

    Takes the next item from the module-level ``batches`` iterator and passes
    it through ``process_batch`` three times: as-is for the encoder inputs,
    with EOS appended to every sequence for the decoder targets, and with EOS
    prepended to every sequence for the decoder inputs. ``process_batch``
    returns a pair; only its first element is kept.

    Returns:
        dict mapping the ``encoder_inputs``, ``decoder_inputs`` and
        ``decoder_targets`` placeholders to the corresponding processed arrays.
    """
    sequences = next(batches)

    enc_in, _ = process_batch(sequences)
    # Targets end with EOS; inputs start with EOS, i.e. they are the targets
    # shifted right by one step.
    dec_target, _ = process_batch([seq + [EOS] for seq in sequences])
    dec_in, _ = process_batch([[EOS] + seq for seq in sequences])

    feed = {}
    feed[encoder_inputs] = enc_in
    feed[decoder_inputs] = dec_in
    feed[decoder_targets] = dec_target
    return feed

In [25]:
# Running history of the per-step training loss; the training loop appends to it.
loss_track = []

In [26]:
# Training schedule: total number of minibatches to run, and how often (in
# minibatches) to print a progress report.
max_batches = 3001
batches_in_epoch = 1000

try:
    for batch in range(max_batches):
        fd = next_feed()
        # One optimisation step. `state` is fetched as before but is not used
        # in this cell.
        _, state, l = sess.run([train_op, encoder_final_state, loss], fd)
        loss_track.append(l)
        # batch 0 already satisfies the modulo test, so no separate check for
        # the first step is needed.
        if batch % batches_in_epoch == 0:
            # Re-evaluate on the same feed after the update, fetching the loss
            # and the predictions in a single run instead of two.
            minibatch_loss, predict_ = sess.run([loss, decoder_prediction], fd)
            print('batch {}'.format(batch))
            print('  minibatch loss: {}'.format(minibatch_loss))
            # Feeds are time-major, so transpose to iterate over sequences.
            for i, (inp, pred) in enumerate(zip(fd[encoder_inputs].T, predict_.T)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(inp))
                print('    predicted > {}'.format(pred))
                # Show only the first three sequences of the batch.
                if i >= 2:
                    break
            # print('') gives a blank line under both Python 2 and 3; a bare
            # print() under the Python 2 print statement outputs "()".
            print('')
except KeyboardInterrupt:
    # Allow stopping a long run from the notebook without a traceback; the
    # losses collected so far remain in loss_track.
    print('training interrupted')

batch 0
  minibatch loss: 9.43666934967
  sample 1:
    input     > [  294     8     9 10872     0]
    predicted > [6735 2396 2396 9045 7831  188]
  sample 2:
    input     > [  320    18 12135  3475     0]
    predicted > [ 9189 10339   299  6245  1305  1619]
  sample 3:
    input     > [1729  164   23 9374    0]
    predicted > [ 1427  1427  1427  1652    11 11136]
()
training interrupted
