In [1]:
import tensorflow as tf
# InteractiveSession registers itself as the default session, which is what
# lets later cells call `.eval()` on tensors without passing a session.
sess = tf.InteractiveSession()

In [2]:
import melt

tensorflow_version: 0.12.0-rc0


In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip     # pylint: disable=redefined-builtin

from tensorflow.python import shape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

# Module-level alias for rnn_cell's private linear-map helper. attention_decoder
# calls it for the input/attention merge, the attention query projection and
# the output projection.
# TODO(ebrevdo): Remove once _linear is fully deprecated.
linear = rnn_cell._linear  # pylint: disable=protected-access

In [4]:
def attention_decoder(decoder_inputs,
                      initial_state,
                      attention_states,
                      cell,
                      output_size=None,
                      num_heads=1,
                      loop_function=None,
                      dtype=None,
                      scope=None,
                      initial_state_attention=False):
  """RNN decoder with attention for the sequence-to-sequence model.

  In this context "attention" means that, during decoding, the RNN can look up
  information in the additional tensor attention_states, and it does this by
  focusing on a few entries from the tensor. This model has proven to yield
  especially good results in a number of sequence-to-sequence tasks. This
  implementation is based on http://arxiv.org/abs/1412.7449 (see below for
  details). It is recommended for complex sequence-to-sequence tasks.

  This notebook copy is instrumented: it prints the scope, the attention
  variables and the intermediate attention tensors while the graph is built,
  and it creates one extra probe variable named "abc" inside the decoder scope.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor [batch_size x cell.state_size], or a (possibly
      nested) tuple of such tensors; tuple states are flattened and
      concatenated before being used as the attention query.
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: rnn_cell.RNNCell defining the cell function and size.
    output_size: Size of the output vectors; if None, we use cell.output_size.
    num_heads: Number of attention heads that read from attention_states.
    loop_function: If not None, this function will be applied to i-th output
      in order to generate i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    dtype: The dtype to use for the RNN initial state (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors of
        shape [batch_size x output_size]. These represent the generated outputs.
        Output i is computed from input i (which is either the i-th element
        of decoder_inputs or loop_function(output {i-1}, i)) as follows.
        First, we run the cell on a combination of the input and previous
        attention masks:
          cell_output, new_state = cell(linear(input, prev_attn), prev_state).
        Then, we calculate new attention masks:
          new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
        and then we calculate the output:
          output = linear(cell_output, new_attn).
      state: The state of each decoder cell at the final time-step, in the
        same structure the cell returns (a 2D Tensor of shape
        [batch_size x cell.state_size], or a tuple for tuple-state cells).

  Raises:
    ValueError: when num_heads is not positive, there are no inputs, shapes
      of attention_states are not set, or input size cannot be inferred
      from the input.
  """
  if not decoder_inputs:
    raise ValueError("Must provide at least 1 input to attention decoder.")
  if num_heads < 1:
    raise ValueError("With less than 1 heads, use a non-attention decoder.")
  if attention_states.get_shape()[2].value is None:
    raise ValueError("Shape[2] of attention_states must be known: %s"
                     % attention_states.get_shape())
  if output_size is None:
    output_size = cell.output_size

  with variable_scope.variable_scope(
      scope or "attention_decoder", dtype=dtype) as scope:
    print('scope', scope, scope.name)
    # Scope probe only: shows how melt.get_weights names a variable created
    # under this scope. The decoder itself never reads it.
    scope_probe = melt.get_weights('abc', [1,3])
    print(scope_probe, scope_probe.name)
    dtype = scope.dtype

    batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
    attn_length = attention_states.get_shape()[1].value
    if attn_length is None:
      # Static length unknown: fall back to the dynamic shape, using the same
      # array_ops.shape helper as batch_size above.
      attn_length = array_ops.shape(attention_states)[1]
    attn_size = attention_states.get_shape()[2].value

    # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
    hidden = array_ops.reshape(
        attention_states, [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable("AttnW_%d" % a,
                                      [1, 1, attn_size, attention_vec_size])
      print('k', k, k.name)
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))

    state = initial_state

    # Flatten first so the debug print works for both a plain Tensor state and
    # a tuple state, without adding a slice op to the graph.
    print('state.name', nest.flatten(state)[0].name)

    def attention(query):
      """Put attention masks on hidden using hidden_features and query."""
      ds = []  # Results of attention reads will be stored here.
      if nest.is_sequence(query):  # If the query is a tuple, flatten it.
        query_list = nest.flatten(query)
        print('query_list', query_list)
        for q in query_list:  # Check that ndims == 2 if specified.
          ndims = q.get_shape().ndims
          if ndims:
            assert ndims == 2
        query = array_ops.concat(1, query_list)
        print('query', query)
      for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          y = linear(query, attention_vec_size, True)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          print('y', y)
          # Debug view of the pre-activation for *this* head (was pinned to
          # head 0 regardless of `a`).
          print('hidden_features[%d]' % a, hidden_features[a])
          z = hidden_features[a] + y
          print('z',  z)
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(
              v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
          print('s', s)
          # Kept under its own name so the head index `a` is not overwritten.
          attn_weights = nn_ops.softmax(s)
          print('a', attn_weights)
          # Now calculate the attention-weighted vector d.
          d = math_ops.reduce_sum(
              array_ops.reshape(attn_weights, [-1, attn_length, 1, 1]) * hidden,
              [1, 2])
          print('d', d)
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

    outputs = []
    prev = None
    batch_attn_size = array_ops.pack([batch_size, attn_size])
    attns = [array_ops.zeros(batch_attn_size, dtype=dtype)
             for _ in xrange(num_heads)]
    for attn in attns:  # Ensure the second shape of attention vectors is set.
      attn.set_shape([None, attn_size])
    if initial_state_attention:
      attns = attention(initial_state)
    for i, inp in enumerate(decoder_inputs):
      # Every step after the first shares the variables created at step 0.
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      # If loop_function is set, we use it instead of decoder_inputs.
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      # Merge input and previous attentions into one vector of the right size.
      input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)
      x = linear([inp] + attns, input_size, True)
      # Run the RNN.
      cell_output, state = cell(x, state)
      # Run the attention mechanism.
      if i == 0 and initial_state_attention:
        # The attention variables already exist from the initial_state call
        # above, so this first in-loop call must reuse them.
        with variable_scope.variable_scope(variable_scope.get_variable_scope(),
                                           reuse=True):
          attns = attention(state)
      else:
        attns = attention(state)

      with variable_scope.variable_scope("AttnOutputProjection"):
        output = linear([cell_output] + attns, output_size, True)
      if loop_function is not None:
        prev = output
      outputs.append(output)

  return outputs, state


In [5]:
# Recurrent cell shared by the encoder and the attention decoder.
# `cell.state_size` (inspected in a later cell) reports LSTMStateTuple(c=4, h=4),
# so 4 is the hidden size. The second argument presumably enables tuple
# state -- TODO confirm against melt.create_rnn_cell.
cell = melt.create_rnn_cell(4, True)

In [6]:
# Toy embedding table: 5 token ids, 6 dimensions each.
vocab_size = 5
emb_dim = 6
# Only referenced by the commented-out random initialiser on the next line.
init_width = 0.5 / emb_dim
#emb = melt.variable.init_weights_uniform([vocab_size, emb_dim], -init_width, init_width)
# Fixed [vocab_size, emb_dim] values are used in place of the random init
# (presumably so the embedded inputs are repeatable between runs).
emb = tf.constant([[-0.0454044 ,  0.07558767,  0.06434789,  0.04944561,  0.04671062,
        -0.06196741],
       [ 0.04754589, -0.03475843, -0.03286489,  0.00497814, -0.05656481,
        -0.07599609],
       [-0.06159163, -0.00535063, -0.03759231, -0.04672422,  0.01091411,
         0.02889993],
       [-0.05034878, -0.04895053,  0.07128759,  0.04060432,  0.07238931,
         0.03129234],
       [-0.04462979, -0.00026041,  0.03161035, -0.01818546,  0.06576461,
         0.04641552]])

In [7]:
# Encoder input: a batch containing one sequence of five token ids.
seq = tf.constant([[1,1,3,2, 3]])
inputs = tf.nn.embedding_lookup(emb, seq)
# NOTE(review): the "- 1" shortens the length handed to the encoder by one
# step; presumably melt.length returns the token count -- confirm against melt.
seq_length = melt.length(seq) - 1
# encode_feature later serves as the attention memory and state as the
# decoder's initial state. The meaning of encode_method=0 / output_method=3 is
# defined inside melt.rnn.encode and is not visible from this notebook.
encode_feature, state = melt.rnn.encode(
          cell, 
          inputs, 
          seq_length, 
          encode_method=0,
          output_method=3)

In [8]:
# Decoder token ids for a batch of one sequence. Kept as a Python list so the
# number of time steps below is derived from it instead of being typed twice.
decoder_ids = [1, 1, 3, 4]
decoder_seq = tf.constant([decoder_ids])
decoder_inputs = tf.nn.embedding_lookup(emb, decoder_seq)
initial_state = state
attention_state = encode_feature
# Split along the time axis (legacy TF 0.12 argument order: split_dim,
# num_split, value) and drop that axis, giving one 2D tensor per decoder step.
decoder_inputs = [tf.squeeze(step_input, 1)
                  for step_input in tf.split(1, len(decoder_ids), decoder_inputs)]
outputs, final_state = attention_decoder(decoder_inputs, initial_state, attention_state, cell)

scope <tensorflow.python.ops.variable_scope.VariableScope object at 0x8a66d90> attention_decoder
Tensor("attention_decoder/abc/read:0", shape=(1, 3), dtype=float32) attention_decoder/abc:0
k Tensor("attention_decoder/AttnW_0/read:0", shape=(1, 1, 4, 4), dtype=float32) attention_decoder/AttnW_0:0
state.name RNN/while/Exit_2:0
query_list [<tf.Tensor 'attention_decoder/LSTMCell/add_1:0' shape=(?, 4) dtype=float32>, <tf.Tensor 'attention_decoder/LSTMCell/mul_2:0' shape=(?, 4) dtype=float32>]
query Tensor("attention_decoder/concat:0", shape=(?, 8), dtype=float32)
y Tensor("attention_decoder/Attention_0/Reshape:0", shape=(?, 1, 1, 4), dtype=float32)
hidden_features[0] Tensor("attention_decoder/Conv2D:0", shape=(1, 5, 1, 4), dtype=float32)
z Tensor("attention_decoder/Attention_0/add_1:0", shape=(?, 5, 1, 4), dtype=float32)
s Tensor("attention_decoder/Attention_0/Sum:0", shape=(?, 5), dtype=float32)
a Tensor("attention_decoder/Attention_0/Softmax:0", shape=(?, 5), dtype=float32)
d Tensor("atte

In [9]:
# Second graph build: longer sequences (6 encoder tokens, 5 decoder tokens)
# that share the variables created by the first build.
#melt.reuse_variables()
# Puts the current (root) variable scope into reuse mode. Nothing in this
# notebook switches it back, so every get_variable call from here on must find
# an already-created variable.
tf.get_variable_scope().reuse_variables()
seq = tf.constant([[1,1,3,2, 3, 2]])
inputs = tf.nn.embedding_lookup(emb, seq)
seq_length = melt.length(seq) - 1
encode_feature, state = melt.rnn.encode(
          cell, 
          inputs, 
          seq_length, 
          encode_method=0,
          output_method=3)
# NOTE(review): the label is misspelled ('encode_featre'); it is left as is so
# the recorded output still matches.
print('encode_featre', encode_feature)
# Decoder token ids; len(l) below keeps the split count in sync with them.
l = [1,1,3,4, 3]
decoder_seq = tf.constant([l])
decoder_inputs = tf.nn.embedding_lookup(emb, decoder_seq)
initial_state = state
attention_state = encode_feature
decoder_inputs = [tf.squeeze(x, 1) for x in tf.split(1, len(l), decoder_inputs)]
#melt.reuse_variables()

# In the recorded output the variables keep their original names
# (attention_decoder/AttnW_0:0) while the ops get an "_1" suffix
# (attention_decoder_1/...), i.e. new ops over shared weights.
# This call rebinds seq/inputs/state/encode_feature/decoder_inputs/outputs, so
# the inspection cells that follow look at this second graph.
outputs, final_state = attention_decoder(decoder_inputs, initial_state, attention_state, cell)

encode_featre Tensor("RNN_1/transpose:0", shape=(1, 6, 4), dtype=float32)
scope <tensorflow.python.ops.variable_scope.VariableScope object at 0x8ef6150> attention_decoder
Tensor("attention_decoder/abc/read:0", shape=(1, 3), dtype=float32) attention_decoder/abc:0
k Tensor("attention_decoder/AttnW_0/read:0", shape=(1, 1, 4, 4), dtype=float32) attention_decoder/AttnW_0:0
state.name RNN_1/while/Exit_2:0
query_list [<tf.Tensor 'attention_decoder_1/LSTMCell/add_1:0' shape=(?, 4) dtype=float32>, <tf.Tensor 'attention_decoder_1/LSTMCell/mul_2:0' shape=(?, 4) dtype=float32>]
query Tensor("attention_decoder_1/concat:0", shape=(?, 8), dtype=float32)
y Tensor("attention_decoder_1/Attention_0/Reshape:0", shape=(?, 1, 1, 4), dtype=float32)
hidden_features[0] Tensor("attention_decoder_1/Conv2D:0", shape=(1, 6, 1, 4), dtype=float32)
z Tensor("attention_decoder_1/Attention_0/add_1:0", shape=(?, 6, 1, 4), dtype=float32)
s Tensor("attention_decoder_1/Attention_0/Sum:0", shape=(?, 6), dtype=float32)
a Ten

In [10]:
# Group the global and local variable initialisers into a single op and run it
# in the interactive session before any tensor is evaluated.
init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
sess.run(init_op)

In [11]:
# First decoder time step of the most recent graph build; the recorded value
# equals row 1 of `emb`, i.e. the embedding of the first decoder token id (1).
decoder_inputs[0].eval()

array([[ 0.04754589, -0.03475843, -0.03286489,  0.00497814, -0.05656481,
        -0.07599609]], dtype=float32)

In [12]:
# A single decoder step is [batch=1, emb_dim=6].
decoder_inputs[0].eval().shape

(1, 6)

In [13]:
# Embedded encoder tokens. `inputs` was rebound by the second graph build, so
# this is the six-token sequence [1,1,3,2,3,2].
inputs.eval()

array([[[ 0.04754589, -0.03475843, -0.03286489,  0.00497814, -0.05656481,
         -0.07599609],
        [ 0.04754589, -0.03475843, -0.03286489,  0.00497814, -0.05656481,
         -0.07599609],
        [-0.05034878, -0.04895053,  0.07128759,  0.04060432,  0.07238931,
          0.03129234],
        [-0.06159163, -0.00535063, -0.03759231, -0.04672422,  0.01091411,
          0.02889993],
        [-0.05034878, -0.04895053,  0.07128759,  0.04060432,  0.07238931,
          0.03129234],
        [-0.06159163, -0.00535063, -0.03759231, -0.04672422,  0.01091411,
          0.02889993]]], dtype=float32)

In [14]:
# [batch=1, time=6, emb_dim=6].
inputs.eval().shape

(1, 6, 6)

In [15]:
# The fixed embedding table, for cross-checking the lookups shown above.
emb.eval()

array([[-0.0454044 ,  0.07558767,  0.06434789,  0.04944561,  0.04671062,
        -0.06196741],
       [ 0.04754589, -0.03475843, -0.03286489,  0.00497814, -0.05656481,
        -0.07599609],
       [-0.06159163, -0.00535063, -0.03759231, -0.04672422,  0.01091411,
         0.02889993],
       [-0.05034878, -0.04895053,  0.07128759,  0.04060432,  0.07238931,
         0.03129234],
       [-0.04462979, -0.00026041,  0.03161035, -0.01818546,  0.06576461,
         0.04641552]], dtype=float32)

In [16]:
# Per-step encoder features that the decoder attends over ([batch, time, 4]).
# NOTE(review): the last time step in the recorded output is all zeros,
# presumably because seq_length was passed as length - 1 -- confirm against
# melt.rnn.encode.
encode_feature.eval()

array([[[  6.17767451e-03,  -7.36213569e-03,  -1.53179804e-03,
          -2.59350846e-03],
        [  9.83432122e-03,  -1.14350067e-02,  -2.33986042e-03,
          -5.12735406e-03],
        [  1.28925231e-03,  -9.34195053e-03,  -1.46743150e-05,
          -4.34792601e-03],
        [ -1.25635443e-02,  -6.64551603e-03,  -8.03328864e-03,
          -1.25908032e-02],
        [ -1.34159653e-02,  -7.22592324e-03,  -5.32426918e-03,
          -1.06535060e-02],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00]]], dtype=float32)

In [17]:
# Final encoder state of the second build, shown as symbolic tensors (the op
# names carry the RNN_1 prefix).
state

LSTMStateTuple(c=<tf.Tensor 'RNN_1/while/Exit_2:0' shape=(?, 4) dtype=float32>, h=<tf.Tensor 'RNN_1/while/Exit_3:0' shape=(?, 4) dtype=float32>)

In [18]:
# Static sizes of the cell's state components.
cell.state_size

LSTMStateTuple(c=4, h=4)

In [19]:
# First component of the state tuple (the `c` field of the LSTMStateTuple).
state[0].eval()

array([[-0.0268254 , -0.01475926, -0.01074809, -0.02152603]], dtype=float32)

In [20]:
# Second component (the `h` field). In the recorded run it matches the last
# non-zero row of encode_feature.
state[1].eval()

array([[-0.01341597, -0.00722592, -0.00532427, -0.01065351]], dtype=float32)

In [21]:
# Symbolic decoder outputs of the second build: one [batch, 4] tensor per
# decoder step (five here).
outputs

[<tf.Tensor 'attention_decoder_1/AttnOutputProjection/add:0' shape=(?, 4) dtype=float32>,
 <tf.Tensor 'attention_decoder_1/AttnOutputProjection_1/add:0' shape=(?, 4) dtype=float32>,
 <tf.Tensor 'attention_decoder_1/AttnOutputProjection_2/add:0' shape=(?, 4) dtype=float32>,
 <tf.Tensor 'attention_decoder_1/AttnOutputProjection_3/add:0' shape=(?, 4) dtype=float32>,
 <tf.Tensor 'attention_decoder_1/AttnOutputProjection_4/add:0' shape=(?, 4) dtype=float32>]

In [22]:
# Decoder output at the first step.
outputs[0].eval()

array([[-0.00701265, -0.00745942, -0.00434384,  0.00290774]], dtype=float32)

In [23]:
# Decoder output at the fourth step (index 3).
outputs[3].eval()

array([[-0.00697165, -0.00903419, -0.00193713,  0.00177378]], dtype=float32)