In [13]:
import tensorflow as tf
import numpy as np

In [14]:
def reset_graph():
    """Throw away the current default TensorFlow graph and start a fresh one.

    Called at the top of graph-building cells so that re-running a cell does
    not keep stacking new ops on top of the ones built by a previous run.
    """
    tf.reset_default_graph()

In [15]:
def single_rnn_cell(
        num_units,
        unit_type='block',
        forget_bias=1.0,
        dropout=0.0,
        mode=tf.contrib.learn.ModeKeys.TRAIN,
        residual_connection=False,
        device_str=None,
        residual_fn=None):
    """Create an instance of a single RNN cell, optionally wrapped.

    Args:
      num_units: number of units in the LSTM cell.
      unit_type: "lstm" for tf.contrib.rnn.BasicLSTMCell, "block" for
        tf.contrib.rnn.LSTMBlockCell.
      forget_bias: forget_bias passed to the LSTM cell constructor.
      dropout: input dropout rate (= 1 - keep_prob), must be in [0, 1).
        Only applied when mode is TRAIN; forced to 0 during eval and infer.
      mode: a tf.contrib.learn.ModeKeys value.
      residual_connection: if True, wrap the cell in a ResidualWrapper.
      device_str: if given (e.g. "/gpu:0"), place the cell's ops on that
        device via a DeviceWrapper.
      residual_fn: optional function forwarded to ResidualWrapper.

    Returns:
      The (possibly wrapped) RNN cell.

    Raises:
      ValueError: if unit_type is unknown or dropout is outside [0, 1).
    """
    # dropout >= 1 would give keep_prob <= 0, and a negative value was
    # previously ignored silently; reject both up front.
    if not 0.0 <= dropout < 1.0:
        raise ValueError("dropout must be in [0, 1), got %r" % (dropout,))

    # dropout (= 1 - keep_prob) is set to 0 during eval and infer
    dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0

    # Cell type
    if unit_type == "lstm":
        single_cell = tf.contrib.rnn.BasicLSTMCell(
            num_units,
            forget_bias=forget_bias)
    elif unit_type == "block":
        single_cell = tf.contrib.rnn.LSTMBlockCell(
            num_units,
            forget_bias=forget_bias)
    else:
        raise ValueError("Unknown unit type %s!" % unit_type)

    # Input dropout (= 1 - keep_prob); skipped entirely when the rate is 0.
    if dropout > 0.0:
        single_cell = tf.contrib.rnn.DropoutWrapper(
            cell=single_cell, input_keep_prob=(1.0 - dropout))

    # Residual connection around the cell.
    if residual_connection:
        single_cell = tf.contrib.rnn.ResidualWrapper(
            single_cell, residual_fn=residual_fn)

    # Device placement: device_str was previously accepted but never used.
    if device_str:
        single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)

    return single_cell

In [None]:
def build_encoder(self, source, cell_units, vocab_size, embed_size,
                  sequence_length=None):
    """Build an embedding + single-cell RNN encoder.

    This is a module-level function lifted from a model class, so the model
    object is passed explicitly as `self`.

    Args:
      self: object providing `time_major` and, optionally, an `iterator`
        with a `source_sequence_length` attribute.
      source: tensor of token ids to embed (layout presumably follows
        `self.time_major` -- confirm against the caller).
      cell_units: number of units in the encoder cell.
      vocab_size: number of rows in the embedding matrix.
      embed_size: number of columns in the embedding matrix.
      sequence_length: optional per-example source lengths. If None, falls
        back to `self.iterator.source_sequence_length` when `self` has an
        `iterator`; otherwise no length is passed to dynamic_rnn.

    Returns:
      (encoder_outputs, encoder_state) as returned by tf.nn.dynamic_rnn.
    """
    with tf.variable_scope("encoder") as scope:
        dtype = scope.dtype

        # Embedding matrix is pinned to the CPU.
        with tf.device("/cpu:0"):
            # Use the scope dtype so the embedding matches the RNN's dtype.
            embedding_encoder = tf.get_variable(
                "encoder_embedding", [vocab_size, embed_size], dtype)

        # encoder_emb_inp: shape of `source` with a trailing embed_size axis.
        encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)

        cell = single_rnn_cell(cell_units)

        # The original referenced an undefined global `iterator` (NameError);
        # take lengths from the argument or from the model object instead.
        if sequence_length is None:
            iterator = getattr(self, "iterator", None)
            if iterator is not None:
                sequence_length = iterator.source_sequence_length

        # encoder_outputs: [max_time, batch_size, cell_units] if time_major,
        # otherwise [batch_size, max_time, cell_units].
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            cell,
            encoder_emb_inp,
            dtype=dtype,
            sequence_length=sequence_length,
            time_major=self.time_major,
            swap_memory=True)

    return encoder_outputs, encoder_state

In [23]:
reset_graph()

# Graph-level seed so initial weights (and thus the printed outputs) are
# reproducible. It must come *after* reset_graph(), which replaces the
# default graph that the seed is attached to.
SEED = 42
tf.set_random_seed(SEED)

n_steps = 3     # time steps per sequence
n_inputs = 3    # features per time step
n_neurons = 5   # units in the recurrent cell

# Batch-major input [batch_size, n_steps, n_inputs]; batch size left open.
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])

# Default unit_type is 'block' (LSTMBlockCell); dropout=0.0 means the cell
# is not wrapped in a DropoutWrapper.
basic_cell = single_rnn_cell(num_units=n_neurons,
                             mode=tf.contrib.learn.ModeKeys.TRAIN,
                             dropout=0.0)

# outputs: [batch_size, n_steps, n_neurons]; states: final cell state.
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)

init = tf.global_variables_initializer()

In [24]:
# 4 instances x n_steps time steps x n_inputs features. dtype matches the
# float32 placeholder X, so nothing is implicitly cast at feed time.
X_batch = np.array([
    #t=0         t=1        t=2
    [[0, 1, 2], [9, 8, 7], [9, 8, 7]], # instance 0
    [[3, 4, 5], [0, 0, 0], [9, 8, 7]], # instance 1
    [[6, 7, 8], [6, 5, 4], [9, 8, 7]], # instance 2
    [[9, 0, 1], [3, 2, 1], [9, 8, 7]], # instance 3
], dtype=np.float32)

with tf.Session() as sess:
    init.run()
    output_evals = sess.run([outputs], feed_dict={X: X_batch})
    print(output_evals[0])
    print(output_evals[0].shape)

# One n_neurons-sized output per instance and time step.
assert output_evals[0].shape == (X_batch.shape[0], n_steps, n_neurons)

[[[-1.3479695e-02  6.7658715e-02 -3.0765137e-01 -5.4829609e-02
    6.2276922e-02]
  [ 2.9041064e-05  1.2992992e-01 -8.3218265e-01 -2.7392870e-01
    5.0804899e-03]
  [ 3.8163100e-05  1.7623809e-01 -9.2611718e-01 -3.5887638e-01
    6.2055942e-03]]

 [[ 3.7190430e-03  1.5560420e-01 -6.5453416e-01 -1.0468533e-01
    2.2801867e-02]
  [ 1.6352600e-01 -1.1556581e-02 -2.9956692e-01 -1.1687529e-01
    3.5942845e-02]
  [ 2.9767742e-05  1.4950672e-02 -8.6729288e-01 -3.2105729e-01
    3.5315128e-03]]

 [[ 1.2835364e-04  6.2459774e-02 -7.3214710e-01 -6.2867366e-02
    3.1469280e-03]
  [ 1.4442343e-03  1.7277467e-01 -8.4790742e-01 -4.0091449e-01
    1.2158995e-02]
  [ 4.0405059e-05  2.3465310e-01 -9.4015574e-01 -4.2132229e-01
    4.6831854e-03]]

 [[ 7.0376499e-03  1.6003622e-01 -4.3922877e-03 -6.8581295e-01
   -6.2636421e-03]
  [ 5.9277855e-02  3.1569383e-01 -2.7005225e-01 -6.1199915e-01
    1.9020710e-02]
  [ 3.7701953e-05  4.7329491e-01 -8.1621540e-01 -5.7509887e-01
    3.0519522e-03]]]
(4, 3, 5)