In [1]:
import tensorflow as tf
# InteractiveSession installs itself as the default session, which is what
# lets the later cells call `Tensor.eval()` without passing a session.
# NOTE(review): TF 1.x API; under TF 2 this is tf.compat.v1.InteractiveSession.
sess = tf.InteractiveSession()

In [2]:
import abc

import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest

In [41]:
def map_structure(func, *structure):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Instrumented copy of `nest.map_structure`: it prints the input structures,
  each structure compared against the first, the flattened structures and
  the zipped entries, so the flattening can be inspected in the cell output.

  Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  `structure[i]`.  All structures in `structure` must have the same arity,
  and the return value will contain the results in the same structure.

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structure: scalar, or tuple or list of constructed scalars and/or other
      tuples/lists, or scalars.  Note: numpy arrays are considered scalars.

  Returns:
    A new structure with the same arity as `structure`, whose values correspond
    to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    location in `structure[i]`.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    ValueError: If no structure is provided or if the structures do not match
      each other by type.
  """
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)

  if not structure:
    raise ValueError("Must provide at least one structure")

  print('structure', structure)
  for other in structure[1:]:
    print('other', other)
    nest.assert_same_structure(structure[0], other)

  flat_structure = [nest.flatten(s) for s in structure]
  print('flat_structure', flat_structure)
  # Materialize the zip: on Python 3 `zip` is a lazy iterator, so printing it
  # would only show `<zip object ...>` instead of the paired entries.
  entries = list(zip(*flat_structure))
  print('entries', entries)

  return nest.pack_sequence_as(
      structure[0], [func(*x) for x in entries])

In [42]:
def _create_zero_outputs(size, dtype, batch_size):
  """Build a structure of zero Tensors, each with a leading `batch_size` dim.

  `size` and `dtype` are parallel structures walked with this notebook's
  instrumented `map_structure`. Each leaf of `size` is paired with the
  matching leaf of `dtype`, and a zero Tensor of shape
  `[batch_size] + leaf_shape` is created for each pair. A list such as
  `[5, 6]` is therefore two leaves (two Tensors), not one 2-D shape.
  """
  def _as_shape_tensor(shape):
    # Tensors are passed through unchanged; any other value is converted
    # with TensorShape and turned into an int32 constant of its dims.
    if isinstance(shape, ops.Tensor):
      return shape
    return constant_op.constant(
        tensor_shape.TensorShape(shape).as_list(),
        dtype=dtypes.int32,
        name="zero_suffix_shape")

  def _zeros_for(shape, leaf_dtype):
    full_shape = array_ops.concat(
        ([batch_size], _as_shape_tensor(shape)), axis=0)
    return array_ops.zeros(full_shape, dtype=leaf_dtype)

  return map_structure(_zeros_for, size, dtype)


In [33]:
nest.map_structure??

In [43]:
# Single size/dtype leaf: one zero Tensor with a leading batch dimension,
# evaluated in the default session.
_create_zero_outputs(size=5, dtype=tf.float32, batch_size=16).eval()

('structure', (5, tf.float32))
('other', tf.float32)
('flat_structure', [[5], [tf.float32]])
('entries', [(5, tf.float32)])


array([[ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.]], dtype=float32)

In [45]:
# A list of sizes with a parallel list of dtypes yields one zero Tensor
# per (size, dtype) pair.
_create_zero_outputs(size=[5, 6],
                     dtype=[tf.float32, tf.float32],
                     batch_size=16)

('structure', ([5, 6], [tf.float32, tf.float32]))
('other', [tf.float32, tf.float32])
('flat_structure', [[5, 6], [tf.float32, tf.float32]])
('entries', [(5, tf.float32), (6, tf.float32)])


[<tf.Tensor 'zeros_8:0' shape=(16, 5) dtype=float32>,
 <tf.Tensor 'zeros_9:0' shape=(16, 6) dtype=float32>]

In [46]:
# Evaluate the first Tensor of the pair (size 5). Each call builds fresh ops
# in the default graph.
_create_zero_outputs(
    size=[5, 6], dtype=[tf.float32, tf.float32], batch_size=16)[0].eval()

('structure', ([5, 6], [tf.float32, tf.float32]))
('other', [tf.float32, tf.float32])
('flat_structure', [[5, 6], [tf.float32, tf.float32]])
('entries', [(5, tf.float32), (6, tf.float32)])


array([[ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.]], dtype=float32)

In [47]:
# Evaluate the second Tensor of the pair (size 6). Each call builds fresh ops
# in the default graph.
_create_zero_outputs(
    size=[5, 6], dtype=[tf.float32, tf.float32], batch_size=16)[1].eval()

('structure', ([5, 6], [tf.float32, tf.float32]))
('other', [tf.float32, tf.float32])
('flat_structure', [[5, 6], [tf.float32, tf.float32]])
('entries', [(5, tf.float32), (6, tf.float32)])


array([[ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)

In [11]:
tf.concat?

In [13]:
# Two 2x3 int matrices; concatenating along axis 0 stacks them into 4x3.
t1 = [[1, 2, 3],
      [4, 5, 6]]
t2 = [[7, 8, 9],
      [10, 11, 12]]
tf.concat([t1, t2], axis=0).eval()

array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]], dtype=int32)

In [14]:
# The old argument order `tf.concat(axis, values)` is rejected by the current
# `tf.concat(values, axis)` signature. Catch and show the error so the
# notebook still runs top to bottom (Restart & Run All).
try:
  tf.concat(0, [t1, t2]).eval()
except ValueError as err:
  print('ValueError: %s' % err)

ValueError: Shapes (2, 2, 3) and () are incompatible

In [16]:
tf.concat_v2??