Add Flatten to core layers.
PiperOrigin-RevId: 168254118
fchollet authored and tensorflower-gardener committed Sep 11, 2017
1 parent a6223c0 commit 80ed8af
Showing 10 changed files with 249 additions and 47 deletions.
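In short: flatten moves from contrib into core `tf.layers`, and both `tf.contrib.layers.flatten` and the Keras `Flatten` layer now delegate to it. A minimal usage sketch of the new public API, based on the docstrings in this diff (assuming a TF 1.x build that includes this commit):

```
import tensorflow as tf

# Functional form added by this commit.
x = tf.placeholder(shape=(None, 4, 4), dtype=tf.float32)
y = tf.layers.flatten(x)
print(y.get_shape().as_list())  # [None, 16]

# Class form; unknown non-batch dimensions propagate as None.
layer = tf.layers.Flatten()
z = layer.apply(tf.placeholder(shape=(None, 3, None), dtype=tf.float32))
print(z.get_shape().as_list())  # [None, None]
```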
@@ -45,7 +45,7 @@ def test_condition_tensor_asserts(self):
array_ops.placeholder(dtypes.float32, (5, None)),
array_ops.placeholder(dtypes.float32, (5, 1)))

with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (5, 2)),
array_ops.placeholder(dtypes.float32, (5)))
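The new expected message in the test above comes from the core layer's `InputSpec` validation rather than contrib's hand-rolled rank check; a 1-D input to the new flatten now fails like this (a hedged sketch, assuming a TF 1.x build with this commit):

```
import tensorflow as tf

# Rank 1 < min_ndim=2, so the core layer's InputSpec check fires.
x = tf.placeholder(tf.float32, shape=(5,))
try:
    tf.layers.flatten(x)
except ValueError as e:
    print(e)  # "...incompatible with the layer...expected min_ndim=2, found ndim=1"
```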
25 changes: 1 addition & 24 deletions tensorflow/contrib/layers/python/layers/layers.py
@@ -1435,30 +1435,7 @@ def flatten(inputs,
"""
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
inputs_rank = inputs.get_shape().ndims
if (inputs_rank is None) or (inputs_rank < 2):
raise ValueError('Inputs must have a least 2 dimensions.')

inputs_shape = array_ops.shape(inputs)

batch_dim = array_ops.slice(inputs_shape, [0], [1])
spatial_dims = array_ops.slice(inputs_shape, [1], [inputs_rank - 1])

flat_spatial_dim = math_ops.reduce_prod(spatial_dims)
flat_spatial_dim = array_ops.expand_dims(flat_spatial_dim, 0)
flat_shape = array_ops.concat([batch_dim, flat_spatial_dim], 0)

outputs = array_ops.reshape(inputs, flat_shape)

# Attempt to propagate shape information, if it is defined.
input_shape = inputs.get_shape().as_list()
batch_dim, spatial_dims = input_shape[0], input_shape[1:]
if all(spatial_dims):
outputs.set_shape([batch_dim,
functools.reduce(lambda x, y: x * y, spatial_dims)])
else:
outputs.set_shape([batch_dim, None])

outputs = core_layers.flatten(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)


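For reference, the 24 deleted lines assembled the flat shape explicitly, while the core implementation (added later in this diff) gets the same result from a single reshape with an inferred trailing dimension. A side-by-side sketch, illustrative only and not part of the commit:

```
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 4, 5))

# Old contrib approach: build [batch, prod(spatial)] by hand.
inputs_shape = tf.shape(x)
batch_dim = tf.slice(inputs_shape, [0], [1])
spatial_dims = tf.slice(inputs_shape, [1], [2])  # rank - 1 == 2 here
flat_spatial_dim = tf.expand_dims(tf.reduce_prod(spatial_dims), 0)
old = tf.reshape(x, tf.concat([batch_dim, flat_spatial_dim], 0))

# New core approach: let reshape infer the second dimension.
new = tf.reshape(x, (tf.shape(x)[0], -1))
```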
2 changes: 1 addition & 1 deletion tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1399,7 +1399,7 @@ def testInvalidRank(self):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5,)))
with self.assertRaisesRegexp(ValueError,
'must have a least 2 dimensions'):
'incompatible with the layer'):
_layers.flatten(inputs)

def testUnknownLastDim(self):
23 changes: 2 additions & 21 deletions tensorflow/python/keras/_impl/keras/layers/core.py
@@ -456,7 +456,7 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))


class Flatten(Layer):
class Flatten(tf_core_layers.Flatten, Layer):
"""Flattens the input. Does not affect the batch size.
Example:
@@ -472,26 +472,7 @@ class Flatten(Layer):
# now: model.output_shape == (None, 65536)
```
"""

def __init__(self, **kwargs):
super(Flatten, self).__init__(**kwargs)
self.input_spec = InputSpec(min_ndim=3)

def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if not all(input_shape[1:]):
raise ValueError('The shape of the input to "Flatten" '
'is not fully defined '
'(got ' + str(input_shape[1:]) + '. '
'Make sure to pass a complete "input_shape" '
'or "batch_input_shape" argument to the first '
'layer in your model.')
return tensor_shape.TensorShape([input_shape[0], np.prod(input_shape[1:])])

def call(self, inputs):
outputs = K.batch_flatten(inputs)
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
return outputs
pass


class RepeatVector(Layer):
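Two things are worth noting about the rewrite above: the Keras layer now inherits its behavior entirely from the core layer (hence the bare `pass` body), and the inherited `InputSpec` relaxes the old Keras requirement of `min_ndim=3` to `min_ndim=2`, so rank-2 inputs are now accepted. A hedged sketch of the resulting class relationship, assuming a build containing this commit:

```
import tensorflow as tf

layer = tf.keras.layers.Flatten()
# The Keras class is now a subclass of the core layer, as the updated
# golden API file later in this diff also records.
print(isinstance(layer, tf.layers.Flatten))  # True
```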
62 changes: 62 additions & 0 deletions tensorflow/python/layers/core.py
@@ -31,6 +31,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope as vs
@@ -337,6 +338,67 @@ def dropout(inputs,
return layer.apply(inputs, training=training)


class Flatten(base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Examples:
```
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
y = Flatten()(x)
# now `y` has shape `(None, 16)`
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
y = Flatten()(x)
# now `y` has shape `(None, None)`
```
"""

def __init__(self, **kwargs):
super(Flatten, self).__init__(**kwargs)
self.input_spec = base.InputSpec(min_ndim=2)

def call(self, inputs):
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
return outputs

def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
output_shape = [input_shape[0]]
if all(input_shape[1:]):
output_shape += [np.prod(input_shape[1:])]
else:
output_shape += [None]
return tensor_shape.TensorShape(output_shape)


def flatten(inputs, name=None):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Arguments:
inputs: Tensor input.
name: The name of the layer (string).
Returns:
Reshaped tensor.
Examples:
```
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
y = flatten(x)
# now `y` has shape `(None, 16)`
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
y = flatten(x)
# now `y` has shape `(None, None)`
```
"""
layer = Flatten(name=name)
return layer.apply(inputs)


# Aliases

FullyConnected = Dense
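The static-shape rule in `_compute_output_shape` above is simple: keep axis 0, and collapse the remaining axes to their product when they are all known, otherwise to None. A plain-Python restatement for illustration, not part of the commit:

```
import numpy as np

def flatten_static_shape(input_shape):
    # Mirrors Flatten._compute_output_shape: axis 0 passes through; the
    # rest collapse to their product only if every dimension is known.
    head, tail = input_shape[0], input_shape[1:]
    if all(tail):  # any None (or 0) dimension leaves the result unknown
        return [head, int(np.prod(tail))]
    return [head, None]

assert flatten_static_shape([None, 3, 2]) == [None, 6]
assert flatten_static_shape([None, 3, None]) == [None, None]
assert flatten_static_shape([1, 2, 3, 2]) == [1, 12]
```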
51 changes: 51 additions & 0 deletions tensorflow/python/layers/core_test.py
@@ -391,5 +391,56 @@ def testDynamicRate(self):
self.assertAllClose(np.ones((5, 5)), np_output)


class FlattenTest(test.TestCase):

def testCreateFlatten(self):
with self.test_session() as sess:
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((3, 2, 3))})
self.assertEqual(list(np_output.shape), [3, 6])
self.assertEqual(y.get_shape().as_list(), [None, 6])

x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((1, 2, 3, 2))})
self.assertEqual(list(np_output.shape), [1, 12])
self.assertEqual(y.get_shape().as_list(), [1, 12])

def testComputeShape(self):
shape = core_layers.Flatten()._compute_output_shape((1, 2, 3, 2))
self.assertEqual(shape.as_list(), [1, 12])

shape = core_layers.Flatten()._compute_output_shape((None, 3, 2))
self.assertEqual(shape.as_list(), [None, 6])

shape = core_layers.Flatten()._compute_output_shape((None, 3, None))
self.assertEqual(shape.as_list(), [None, None])

def testFunctionalFlatten(self):
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.flatten(x, name='flatten')
self.assertEqual(y.get_shape().as_list(), [None, 6])

def testFlattenValueError(self):
x = array_ops.placeholder(shape=(None,), dtype='float32')
with self.assertRaises(ValueError):
core_layers.Flatten()(x)

def testFlattenUnknownAxes(self):
with self.test_session() as sess:
x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((5, 2, 3))})
self.assertEqual(list(np_output.shape), [5, 6])
self.assertEqual(y.get_shape().as_list(), [5, None])

x = array_ops.placeholder(shape=(5, None, 2), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((5, 3, 2))})
self.assertEqual(list(np_output.shape), [5, 6])
self.assertEqual(y.get_shape().as_list(), [5, None])


if __name__ == '__main__':
test.main()
4 changes: 4 additions & 0 deletions tensorflow/python/layers/layers.py
@@ -18,6 +18,7 @@
@@Dense
@@Dropout
@@Flatten
@@Conv1D
@@Conv2D
@@Conv3D
@@ -39,6 +40,7 @@
@@dense
@@dropout
@@flatten
@@conv1d
@@conv2d
@@conv3d
@@ -71,9 +73,11 @@
# Core layers.
from tensorflow.python.layers.core import Dense
from tensorflow.python.layers.core import Dropout
from tensorflow.python.layers.core import Flatten

from tensorflow.python.layers.core import dense
from tensorflow.python.layers.core import dropout
from tensorflow.python.layers.core import flatten

# Convolutional layers.
from tensorflow.python.layers.convolutional import SeparableConv2D
1 change: 1 addition & 0 deletions tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.Flatten"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Flatten\'>"
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
118 changes: 118 additions & 0 deletions tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
@@ -0,0 +1,118 @@
path: "tensorflow.layers.Flatten"
tf_class {
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
name: "graph"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
}
8 changes: 8 additions & 0 deletions tensorflow/tools/api/golden/tensorflow.layers.pbtxt
@@ -44,6 +44,10 @@ tf_module {
name: "Dropout"
mtype: "<type \'type\'>"
}
member {
name: "Flatten"
mtype: "<type \'type\'>"
}
member {
name: "InputSpec"
mtype: "<type \'type\'>"
@@ -120,6 +124,10 @@ tf_module {
name: "dropout"
argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "flatten"
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "max_pooling1d"
argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], "
