In [6]:
# Third-party imports for the whole notebook.
import numpy as np
import tensorflow as tf

# The bare `.eval()` calls in later cells need a default session; an
# InteractiveSession installs itself as the default when constructed.
sess = tf.InteractiveSession()

In [7]:
# Number of valid time steps per batch entry; an entry of 0 is already finished
# before the first decoding step.
sequence_length = [3, 4, 3, 1, 0]
# There is exactly one batch entry per sequence length, so derive the batch
# size instead of repeating it as a literal that could drift out of sync.
batch_size = len(sequence_length)
# Size of the time axis of the randomly generated inputs.
max_time = 8
# Feature size of each input step.
input_depth = 7
# Number of units handed to the LSTM cell.
cell_depth = 10
# Longest sequence in the batch.
max_out = max(sequence_length)

In [8]:
max_out

4

In [9]:
time_major = True
# Draw from a seeded, private generator so the inputs (and every output that
# depends on them) are reproducible after a kernel restart, without touching
# NumPy's global random state.
rng = np.random.RandomState(0)
# Time-major inputs are laid out as (max_time, batch_size, input_depth);
# batch-major inputs swap the first two axes.
if time_major:
  input_shape = (max_time, batch_size, input_depth)
else:
  input_shape = (batch_size, max_time, input_depth)
inputs = rng.randn(*input_shape).astype(np.float32)
inputs

array([[[ -1.63054931e+00,   4.56274897e-01,  -6.83186412e-01,
          -1.53135717e+00,   9.34157789e-01,  -3.99784036e-02,
          -6.25981092e-01],
        [  2.29047582e-01,   3.22035164e-01,  -7.21430242e-01,
          -5.32526672e-01,  -1.24769616e+00,   1.15178835e+00,
          -1.41922069e+00],
        [ -7.27448583e-01,  -1.41406834e+00,  -7.90897757e-02,
           2.60723799e-01,  -1.00787592e+00,  -1.44772291e-01,
          -2.69150082e-02],
        [ -1.21435535e+00,  -4.91854787e-01,   2.48019919e-01,
           3.17332000e-01,   2.45850340e-01,   6.85159326e-01,
          -2.48433605e-01],
        [  6.08725607e-01,   4.89112556e-01,  -7.40392148e-01,
          -7.26322711e-01,  -1.24961877e+00,   7.94758022e-01,
          -4.07460570e-01]],

       [[  8.10555160e-01,   7.84145951e-01,   1.07380617e+00,
           3.41595896e-02,  -2.10963845e+00,   3.98175985e-01,
           4.18650508e-02],
        [  3.23609382e-01,   9.17592719e-02,   9.60837483e-01,
          -

In [10]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest

# Public names defined by this cell.
__all__ = [
    "Decoder",
    "dynamic_decode_rnn",
]


def _transpose_batch_time(x):
  """Swap the leading two (batch and time) axes of a Tensor.

  Whatever is known statically about the shape is carried over to the result.

  Args:
    x: A tensor of rank 2 or higher.

  Returns:
    `x` with its first two dimensions exchanged.

  Raises:
    ValueError: if the static rank of `x` is known and is less than 2.
  """
  static_shape = x.get_shape()
  rank_is_known = static_shape.ndims is not None
  if rank_is_known and static_shape.ndims < 2:
    raise ValueError(
        "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
        (x, static_shape))

  # Permutation [1, 0, 2, ..., rank - 1], built from the dynamic rank so that
  # the trailing axes are kept in place whatever the rank is.
  dynamic_rank = array_ops.rank(x)
  trailing_axes = math_ops.range(2, dynamic_rank)
  permutation = array_ops.concat_v2(([1, 0], trailing_axes), axis=0)
  transposed = array_ops.transpose(x, permutation)

  # Re-attach the static shape with its first two dimensions swapped.
  swapped_leading = tensor_shape.TensorShape(
      [static_shape[1].value, static_shape[0].value])
  transposed.set_shape(swapped_leading.concatenate(static_shape[2:]))
  return transposed


@six.add_metaclass(abc.ABCMeta)
class Decoder(object):
  """An RNN Decoder abstract interface object.

  Subclasses describe their outputs through the `batch_size`, `output_size` and
  `output_dtype` properties and implement `initialize` and `step`, which
  `dynamic_decode_rnn` drives.
  """

  @property
  def batch_size(self):
    """The batch size (leading dimension) of the outputs produced by `step`."""
    raise NotImplementedError

  @property
  def output_size(self):
    """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s]."""
    raise NotImplementedError

  @property
  def output_dtype(self):
    """A (possibly nested tuple of...) dtype[s]."""
    raise NotImplementedError

  @abc.abstractmethod
  def initialize(self, name=None):
    """Called before any decoding iterations.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, first_inputs, initial_state)`.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def step(self, time, inputs, state):
    """Called per step of decoding (but only once for dynamic decoding).

    Args:
      time: Scalar `int32` tensor.
      inputs: Input (possibly nested tuple of) tensor[s] for this time step.
      state: State (possibly nested tuple of) tensor[s] from previous time step.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    raise NotImplementedError


def _create_zero_outputs(size, dtype, batch_size):
  """Build an all-zeros Tensor structure with a leading `batch_size` dimension.

  `size` and `dtype` are parallel structures; each pair yields one zeros Tensor
  of shape `[batch_size] + size`.
  """

  def _suffix_shape(component_size):
    # A Tensor is used as the shape vector directly; anything else is
    # normalized through TensorShape into an int32 constant.
    if isinstance(component_size, ops.Tensor):
      return component_size
    return constant_op.constant(
        tensor_shape.TensorShape(component_size).as_list(),
        dtype=dtypes.int32,
        name="zero_suffix_shape")

  def _zeros_for(component_size, component_dtype):
    suffix = _suffix_shape(component_size)
    full_shape = array_ops.concat(([batch_size], suffix), axis=0)
    return array_ops.zeros(full_shape, dtype=component_dtype)

  return nest.map_structure(_zeros_for, size, dtype)


def dynamic_decode_rnn(decoder,
                       output_time_major=False,
                       parallel_iterations=32,
                       swap_memory=False):
  """Perform dynamic decoding with `decoder`.

  Runs `decoder.step` inside a `while_loop` until every batch entry reports
  finished, writing each step's outputs into `TensorArray`s that are stacked
  once the loop exits.

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.

  Returns:
    `(final_outputs, final_state)`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
  """
  if not isinstance(decoder, Decoder):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  # Zeros with the same structure as the decoder outputs; they replace the
  # emitted values of batch entries that were already finished.
  zero_outputs = _create_zero_outputs(decoder.output_size, decoder.output_dtype,
                                      decoder.batch_size)

  initial_finished, initial_inputs, initial_state = decoder.initialize()
  initial_time = constant_op.constant(0, dtype=dtypes.int32)

  def _shape(batch_size, from_shape):
    """Static element shape `[batch_size] + from_shape` for an output array.

    Returns a fully unknown shape when `from_shape` is not a `TensorShape`.
    """
    if not isinstance(from_shape, tensor_shape.TensorShape):
      return tensor_shape.TensorShape(None)
    else:
      # Use the batch size's static value for the leading dimension.
      batch_size = tensor_util.constant_value(
          ops.convert_to_tensor(
              batch_size, name="batch_size"))
      return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

  def _create_ta(s, d):
    """Create an empty, dynamically sized TensorArray for one output component."""
    return tensor_array_ops.TensorArray(
        dtype=d, size=0, dynamic_size=True,
        element_shape=_shape(decoder.batch_size, s))

  # One TensorArray per component of the decoder's output structure.
  initial_outputs_ta = nest.map_structure(
      _create_ta, decoder.output_size, decoder.output_dtype)

  def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                finished):
    """Keep looping while at least one batch entry is not finished."""
    return math_ops.logical_not(math_ops.reduce_all(finished))

  def body(time, outputs_ta, state, inputs, finished):
    """Internal while_loop body.

    Args:
      time: scalar int32 tensor.
      outputs_ta: structure of TensorArray.
      state: (structure of) state tensors and TensorArrays.
      inputs: (structure of) input tensors.
      finished: 1-D bool tensor.

    Returns:
      `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
    """
    (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
        time, inputs, state)
    # Once an entry is finished it stays finished, whatever the decoder reports.
    next_finished = math_ops.logical_or(decoder_finished, finished)

    # The loop variables must keep the same structure from one iteration to
    # the next.
    nest.assert_same_structure(state, decoder_state)
    nest.assert_same_structure(outputs_ta, next_outputs)
    nest.assert_same_structure(inputs, next_inputs)

    # Zero out output values past finish
    # (`finished` is the flag from before this step, so the output of the step
    # on which an entry finishes is still emitted).
    emit = nest.map_structure(
        lambda out, zero: array_ops.where(finished, zero, out), next_outputs,
        zero_outputs)

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
      # TensorArrays held in the state are passed through as returned by the
      # decoder; tensors keep their previous value for finished entries.
      return (new if isinstance(cur, tensor_array_ops.TensorArray) else
              array_ops.where(finished, cur, new))

    next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    outputs_ta, emit)
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished)

  res = control_flow_ops.while_loop(
      condition,
      body,
      loop_vars=[
          initial_time, initial_outputs_ta, initial_state, initial_inputs,
          initial_finished
      ],
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  # `res` follows the order of `loop_vars`: index 1 holds the output
  # TensorArrays and index 2 the final state.
  final_outputs_ta = res[1]
  final_state = res[2]

  # Stacking yields one tensor per output component with time as the leading
  # axis; swap to batch major unless time major output was requested.
  final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
  if not output_time_major:
    final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

  return final_outputs, final_state

In [11]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import six

from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest

# Public names defined by this cell.
__all__ = [
    "Sampler",
    "SamplingDecoderOutput",
    "BasicSamplingDecoder",
    "BasicTrainingSampler",
]

# NOTE: this rebinds the notebook-level `_transpose_batch_time` (also defined
# with `def` in an earlier cell) to the helper from the imported `decoder`
# module, so every later lookup of that global resolves to this one.
_transpose_batch_time = decoder._transpose_batch_time  # pylint: disable=protected-access
@six.add_metaclass(abc.ABCMeta)
class Sampler(object):
  """Abstract interface for the object a sampling decoder asks for its inputs.

  `BasicSamplingDecoder` calls `initialize` once before decoding and `sample`
  after every cell step.
  """

  @property
  def batch_size(self):
    """The batch size of the inputs returned by `initialize` and `sample`."""
    # Fail loudly, like `Decoder.batch_size` does, instead of silently
    # returning None when a subclass forgets to override this property.
    raise NotImplementedError("batch_size has not been implemented")

  @abc.abstractmethod
  def initialize(self):
    """Called once before decoding starts.

    Returns:
      `(finished, first_inputs)`.
    """
    pass

  @abc.abstractmethod
  def sample(self, time, outputs, state):
    """Called after each cell step to pick the next inputs.

    Args:
      time: Scalar `int32` tensor.
      outputs: The cell outputs of the current step.
      state: The cell state after the current step.

    Returns:
      `(sample_id, finished, next_inputs)`.
    """
    pass


class SamplingDecoderOutput(
    collections.namedtuple("SamplingDecoderOutput", "rnn_output sample_id")):
  """Output of one decoding step: the RNN cell output and the sample id."""


class BasicSamplingDecoder(Decoder):
  """Decoder that runs an RNN cell and lets a sampler choose the next inputs."""

  def __init__(self, cell, sampler, initial_state):
    """Initialize BasicSamplingDecoder.

    Args:
      cell: An `RNNCell` instance.
      sampler: A `Sampler` instance.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell` or `sampler`
        is not an instance of `Sampler`.
    """
    # Checked in this order so that a bad `cell` is reported first.
    type_checks = (
        (cell, core_rnn_cell.RNNCell, "cell must be an RNNCell, received: %s"),
        (sampler, Sampler, "sampler must be a Sampler, received: %s"),
    )
    for value, expected_type, message in type_checks:
      if not isinstance(value, expected_type):
        raise TypeError(message % type(value))

    self._cell = cell
    self._sampler = sampler
    self._initial_state = initial_state

  @property
  def batch_size(self):
    """The batch size, as reported by the sampler."""
    return self._sampler.batch_size

  @property
  def output_size(self):
    """The cell's output size paired with a scalar shape for the sample id."""
    scalar_id_shape = tensor_shape.TensorShape([])
    return SamplingDecoderOutput(
        rnn_output=self._cell.output_size, sample_id=scalar_id_shape)

  @property
  def output_dtype(self):
    """Dtypes matching `output_size`.

    The cell output is assumed to share the dtype of the first component of
    the initial state; the sample id is int32.
    """
    first_state_component = nest.flatten(self._initial_state)[0]
    cell_dtype = first_state_component.dtype
    rnn_output_dtypes = nest.map_structure(
        lambda _: cell_dtype, self._cell.output_size)
    return SamplingDecoderOutput(
        rnn_output=rnn_output_dtypes, sample_id=dtypes.int32)

  def initialize(self, name=None):
    """Return the sampler's `(finished, first_inputs)` plus the initial state."""
    finished_and_inputs = self._sampler.initialize()
    return finished_and_inputs + (self._initial_state,)

  def step(self, time, inputs, state):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    cell_outputs, cell_state = self._cell(inputs, state)
    # The sampler decides what is fed next and which entries are finished.
    sample_id, finished, next_inputs = self._sampler.sample(
        time=time, outputs=cell_outputs, state=cell_state)
    step_outputs = SamplingDecoderOutput(
        rnn_output=cell_outputs, sample_id=sample_id)
    return (step_outputs, cell_state, next_inputs, finished)

class BasicTrainingSampler(Sampler):
  """Training-time sampler: replays the given inputs and samples nothing."""

  def __init__(self, inputs, sequence_length, time_major=False):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    inputs = ops.convert_to_tensor(inputs, name="inputs")
    # Inputs are handled time major internally so that each TensorArray entry
    # is one time step.
    if not time_major:
      inputs = nest.map_structure(_transpose_batch_time, inputs)

    def _to_tensor_array(tensor):
      num_steps = array_ops.shape(tensor)[0]
      step_shape = tensor.get_shape()[1:]
      array = tensor_array_ops.TensorArray(
          dtype=tensor.dtype, size=num_steps, element_shape=step_shape)
      return array.unstack(tensor)

    self._input_tas = nest.map_structure(_to_tensor_array, inputs)

    sequence_length = ops.convert_to_tensor(
        sequence_length, name="sequence_length")
    length_shape = sequence_length.get_shape()
    if length_shape.ndims != 1:
      raise ValueError(
          "Expected sequence_length to be a vector, but received shape: %s" %
          length_shape)
    self._sequence_length = sequence_length

    def _zeros_like_first_step(tensor):
      return array_ops.zeros_like(tensor[0, :])

    # Fed as the next inputs once every batch entry has finished.
    self._zero_inputs = nest.map_structure(_zeros_like_first_step, inputs)
    self._batch_size = array_ops.size(sequence_length)

  @property
  def batch_size(self):
    """Number of entries in `sequence_length`, as a scalar tensor."""
    return self._batch_size

  def initialize(self):
    """Return `(finished, first_inputs)` for the start of decoding."""
    finished = math_ops.equal(0, self._sequence_length)
    all_finished = math_ops.reduce_all(finished)

    def _first_step_inputs():
      return nest.map_structure(lambda ta: ta.read(0), self._input_tas)

    # Only read from the input arrays when some entry still has steps to run.
    next_inputs = control_flow_ops.cond(
        all_finished, lambda: self._zero_inputs, _first_step_inputs)
    return (finished, next_inputs)

  def sample(self, time, **unused_kwargs):
    """Return `(sample_id, finished, next_inputs)` for the step after `time`.

    The sample ids are all -1; the next inputs are read from the stored inputs
    at `time + 1` unless every entry is finished.
    """
    next_time = time + 1
    finished = (next_time >= self._sequence_length)
    all_finished = math_ops.reduce_all(finished)
    placeholder_id = constant_op.constant(-1)
    sample_id = array_ops.tile([placeholder_id], [self._batch_size])

    def _next_step_inputs():
      return nest.map_structure(
          lambda ta: ta.read(next_time), self._input_tas)

    next_inputs = control_flow_ops.cond(
        all_finished, lambda: self._zero_inputs, _next_step_inputs)
    return (sample_id, finished, next_inputs)

In [12]:
# LSTM cell to decode with, plus a sampler that replays `inputs` step by step.
cell = core_rnn_cell.LSTMCell(num_units=cell_depth)
sampler = BasicTrainingSampler(
    inputs=inputs, sequence_length=sequence_length, time_major=time_major)

In [13]:
sampler

<__main__.BasicTrainingSampler at 0x4315810>

In [14]:
sampler._sequence_length

<tf.Tensor 'sequence_length:0' shape=(5,) dtype=int32>

In [15]:
sampler._input_tas.read(7).eval()

array([[ 0.13792519,  1.66770911,  2.19124269, -0.51016611,  0.99686748,
         1.76268172,  1.88533998],
       [ 1.31990552,  0.58950371, -1.52749979,  0.41269293,  0.40469393,
         0.417514  ,  0.84630972],
       [ 0.29173258, -1.76017773, -0.62109584, -0.66316473,  0.41828188,
         0.38076872,  0.34394884],
       [-0.78217489,  2.46541548, -0.86606914,  1.17914689, -0.08871825,
        -0.89538765,  1.38839793],
       [-2.44864416,  0.95273793,  0.16966906, -1.02566195,  0.5478496 ,
        -0.36753356,  1.13364804]], dtype=float32)

In [16]:
array_ops.shape(inputs).eval()

array([8, 5, 7], dtype=int32)

In [17]:
sampler._zero_inputs.eval()

array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)

In [18]:
sampler.batch_size.eval()

5

In [19]:
finished = math_ops.equal(0, sampler._sequence_length)

In [20]:
finished.eval()

array([False, False, False, False,  True], dtype=bool)

In [21]:
all_finished = math_ops.reduce_all(finished)

In [22]:
all_finished.eval()

False

In [23]:
sample_id = array_ops.tile([constant_op.constant(-1)], [sampler._batch_size])

In [24]:
sample_id.eval()

array([-1, -1, -1, -1, -1], dtype=int32)

In [25]:
# Start decoding from the cell's zero state for the whole batch.
my_decoder = BasicSamplingDecoder(
    cell,
    sampler,
    cell.zero_state(batch_size, dtypes.float32))

# Ask for the outputs in the same layout (time or batch major) as `inputs`.
final_outputs, final_state = dynamic_decode_rnn(
    decoder=my_decoder, output_time_major=time_major)

In [26]:
final_outputs

SamplingDecoderOutput(rnn_output=<tf.Tensor 'TensorArrayStack/TensorArrayGatherV3:0' shape=(?, 5, 10) dtype=float32>, sample_id=<tf.Tensor 'TensorArrayStack_1/TensorArrayGatherV3:0' shape=(?, 5) dtype=int32>)

In [27]:
final_outputs.rnn_output[0]

<tf.Tensor 'strided_slice_2:0' shape=(5, 10) dtype=float32>

In [28]:
final_state

LSTMStateTuple(c=<tf.Tensor 'while/Exit_3:0' shape=(5, 10) dtype=float32>, h=<tf.Tensor 'while/Exit_4:0' shape=(5, 10) dtype=float32>)

In [29]:
final_state.c

<tf.Tensor 'while/Exit_3:0' shape=(5, 10) dtype=float32>