Basic seq2seq model with Tensorflow
======================

In [2]:
__author__ = 'Nicholas Tomlin'
__version__ = "CSLI Summer 2018 internship"

### Imports 
Tested with Tensorflow 1.8.0. Using the Dense layer for seq2seq inference decoder, which will be described below. We'll need to add the `src/models/` file to our Python path to import the base RNN model.

In [3]:
import numpy as np
import tensorflow as tf
import warnings
import random
from tensorflow.python.layers.core import Dense

In [4]:
import sys
sys.path.append('../src/models/')
from tf_rnn_classifier import TfRNNClassifier

### Basic seq2seq class definition
We build a single graph which includes embeddings, encoding, and two separate decoding functions. One decoding function is used during training, and the other is used for inference (prediction). 

In [11]:
class TfEncoderDecoder(TfRNNClassifier):
    def __init__(self, max_input_length=5, max_output_length=5, num_layers=2, **kwargs):
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length
        self.num_layers = num_layers

        super(TfEncoderDecoder, self).__init__(**kwargs)

    def build_graph(self):
        self._define_embedding()
        self._init_placeholders()
        self._init_embedding()
        self.encoding_layer()
        self.decoding_layer()


    def _init_placeholders(self):
        self.encoder_inputs = tf.placeholder(
            shape=[None, None],
            dtype=tf.int32,
            name="encoder_inputs")

        self.encoder_lengths = tf.placeholder(
            shape=[None,],
            dtype=tf.int32,
            name="encoder_lengths")

        self.decoder_inputs= tf.placeholder(
            shape=[None, None],
            dtype=tf.int32,
            name="decoder_inputs")

        self.decoder_targets = tf.placeholder(
            shape=[None, None],
            dtype=tf.int32,
            name="decoder_targets")

        self.decoder_lengths = tf.placeholder(
            shape=[None,],
            dtype=tf.int32,
            name="decoder_lengths")

    def _init_embedding(self):
        self.embedded_encoder_inputs = tf.nn.embedding_lookup(self.embedding, self.encoder_inputs)
        self.embedded_decoder_inputs = tf.nn.embedding_lookup(self.embedding, self.decoder_inputs)

    def encoding_layer(self):
        encoder_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, activation=self.hidden_activation)
        encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
            cell=encoder_cell,
            inputs=self.embedded_encoder_inputs,
            time_major=True,
            dtype=tf.float32,
            scope="encoding_layer")
        self.encoder_final_state = encoder_final_state

    def decoding_layer(self):
        self.decoding_training()
        self.decoding_inference()

    def decoding_training(self):
        self.decoder_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, activation=self.hidden_activation)
        self.output_layer = Dense(
            self.vocab_size,
            kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))
        
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=self.embedded_decoder_inputs,
            sequence_length=self.decoder_lengths,
            time_major=False)
        
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            self.decoder_cell,
            training_helper,
            self.encoder_final_state,
            self.output_layer)
        
        training_decoder_output = tf.contrib.seq2seq.dynamic_decode(
            training_decoder,
            impute_finished=True,
            maximum_iterations=self.max_output_length)[0]
        
        self.training_outputs = training_decoder_output
        self.training_logits = training_decoder_output.rnn_output
#         self.decoder_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, activation=self.hidden_activation)

#         decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
#             self.decoder_cell,
#             self.embedded_decoder_inputs,
#             initial_state=self.encoder_final_state,
#             time_major=True,
#             dtype=tf.float32,
#             scope="decoding_layer")
#         decoder_logits = tf.contrib.layers.linear(decoder_outputs, self.vocab_size)
        
#         self.training_outputs = decoder_outputs
#         self.training_logits = decoder_logits


    def decoding_inference(self):
        start_tokens = tf.tile(
            input=tf.constant([2], dtype=tf.int32), # TODO: don't hardcode start token like this (2)
            multiples=[self.batch_size],
            name='start_tokens')

        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=self.embedding,
            start_tokens=start_tokens,
            end_token=3) # TODO: don't hardcode end token like this (3)

        inference_decoder = tf.contrib.seq2seq.BasicDecoder(
            self.decoder_cell,
            helper,
            self.encoder_final_state,
            self.output_layer)

        inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(
            inference_decoder,
            impute_finished=True,
            maximum_iterations=self.max_output_length)[0]

        self.inference_decoder_output = inference_decoder_output 
        self.inference_logits = inference_decoder_output.sample_id


    def prepare_output_data(self, y):
        return y


    def get_cost_function(self, **kwargs):
        return tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=self.training_logits,
                labels=tf.one_hot(self.decoder_targets, depth=self.vocab_size, dtype=tf.float32)))


    def predict(self, X):
        X, x_lengths = self._convert_X(X)
        length = X.shape[1]
        X.resize((self.batch_size, length))
        answer_logits = self.sess.run(self.inference_logits, {self.encoder_inputs: X, 
                                      self.decoder_lengths: [5]*self.batch_size, 
                                      self.encoder_lengths: [5]*self.batch_size})[0] 

        return predictions


    def train_dict(self, X, y):
        decoder_inputs = [["<GO>"] + list(seq) for seq in y]
        decoder_targets = [list(seq) + ["<EOS>"] for seq in y]

        encoder_inputs, encoder_lengths = self._convert_X(X)
        decoder_inputs, decoder_lengths = self._convert_X(decoder_inputs)
        decoder_targets, _ = self._convert_X(decoder_targets)
        return {self.encoder_inputs: encoder_inputs,
            self.decoder_inputs: decoder_inputs,
            self.decoder_targets: decoder_targets,
            self.encoder_lengths: encoder_lengths,
            self.decoder_lengths: decoder_lengths}

### Simple test dataset
Generate a dataset of "ab" strings that translates "a" to "b" and vice versa. For example:
 * "aaab" -> "bbba"
 * "bb" -> "aa"

Also need to define the vocab set. The superclass `TfModelBase` will take care of preprocessing.

In [7]:
vocab = ['<PAD>', '$UNK', '<GO>', '<EOS>', 'a', 'b']

train = []
for i in range(100):
    input_string = ""
    output_string = ""
    length = random.randint(1,5)
    for char in range(length):
        if (random.random() > 0.5):
            input_string += "a"
            output_string += "b"
        else:
            input_string += "b"
            output_string += "a"
        train.append([np.asarray(list(input_string)), np.asarray(list(output_string))])

In [8]:
train[:5]

[[array(['b'], dtype='<U1'), array(['a'], dtype='<U1')],
 [array(['b', 'b'], dtype='<U1'), array(['a', 'a'], dtype='<U1')],
 [array(['b', 'b', 'a'], dtype='<U1'), array(['a', 'a', 'b'], dtype='<U1')],
 [array(['b', 'b', 'a', 'a'], dtype='<U1'),
  array(['a', 'a', 'b', 'b'], dtype='<U1')],
 [array(['a'], dtype='<U1'), array(['b'], dtype='<U1')]]

In [9]:
test = [[np.asarray(list('ab')), np.asarray(list('ba'))],
        [np.asarray(list('ba')), np.asarray(list('ab'))]]

Now we can instantiate the class and test it:

In [12]:
seq2seq = TfEncoderDecoder(
    vocab=vocab, max_iter=10, max_length=5, eta=0.1)

X, y = zip(*train);
seq2seq.fit(X, y);

InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [308,50] vs. shape[1] = [5,50]
	 [[Node: decoder/while/BasicDecoderStep/lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_5, decoder/while/Identity_4, decoder/while/BasicDecoderStep/lstm_cell/concat/axis)]]

Caused by op 'decoder/while/BasicDecoderStep/lstm_cell/concat', defined at:
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 486, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 127, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/asyncio/base_events.py", line 422, in run_forever
    self._run_once()
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/asyncio/base_events.py", line 1432, in _run_once
    handle._run()
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/asyncio/events.py", line 145, in _run
    self._callback(*self._args)
  File "/usr/local/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 117, in _handle_events
    handler_func(fileobj, events)
  File "/usr/local/lib/python3.6/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2909, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-12-f322b139874c>", line 5, in <module>
    seq2seq.fit(X, y);
  File "../src/models/tf_model_base.py", line 113, in fit
    self.build_graph()
  File "<ipython-input-11-2e9107f91acc>", line 14, in build_graph
    self.decoding_layer()
  File "<ipython-input-11-2e9107f91acc>", line 58, in decoding_layer
    self.decoding_training()
  File "<ipython-input-11-2e9107f91acc>", line 81, in decoding_training
    maximum_iterations=self.max_output_length)[0]
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 304, in dynamic_decode
    swap_memory=swap_memory)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3224, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2956, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2893, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 249, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 137, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 339, in __call__
    *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 717, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 833, in call
    array_ops.concat([inputs, m_prev], 1), self._kernel)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1189, in concat
    return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 953, in concat_v2
    "ConcatV2", values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [308,50] vs. shape[1] = [5,50]
	 [[Node: decoder/while/BasicDecoderStep/lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_5, decoder/while/Identity_4, decoder/while/BasicDecoderStep/lstm_cell/concat/axis)]]


In [85]:
X_test, _ = zip(*test)
print('\nPredictions:', seq2seq.predict(X_test))

5
Tensor("decoder/transpose_1:0", shape=(1028, ?), dtype=int32)


InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [1028,50] vs. shape[1] = [5,50]
	 [[Node: decoding_layer/while/lstm_cell/concat_1 = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_5, decoder/while/Identity_4, decoding_layer/while/lstm_cell/concat_1/axis)]]

Caused by op 'decoding_layer/while/lstm_cell/concat_1', defined at:
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 486, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 127, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/asyncio/base_events.py", line 422, in run_forever
    self._run_once()
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/asyncio/base_events.py", line 1432, in _run_once
    handle._run()
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/asyncio/events.py", line 145, in _run
    self._callback(*self._args)
  File "/usr/local/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 117, in _handle_events
    handler_func(fileobj, events)
  File "/usr/local/lib/python3.6/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2909, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-84-f322b139874c>", line 5, in <module>
    seq2seq.fit(X, y);
  File "../src/models/tf_model_base.py", line 113, in fit
    self.build_graph()
  File "<ipython-input-83-a2fae4ac11b9>", line 14, in build_graph
    self.decoding_layer()
  File "<ipython-input-83-a2fae4ac11b9>", line 59, in decoding_layer
    self.decoding_inference()
  File "<ipython-input-83-a2fae4ac11b9>", line 101, in decoding_inference
    maximum_iterations=self.max_output_length)[0]
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 304, in dynamic_decode
    swap_memory=swap_memory)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3224, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2956, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2893, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 249, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 137, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 339, in __call__
    *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 717, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 833, in call
    array_ops.concat([inputs, m_prev], 1), self._kernel)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1189, in concat
    return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 953, in concat_v2
    "ConcatV2", values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1028,50] vs. shape[1] = [5,50]
	 [[Node: decoding_layer/while/lstm_cell/concat_1 = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](decoder/while/Identity_5, decoder/while/Identity_4, decoding_layer/while/lstm_cell/concat_1/axis)]]
