Skip to content

Commit

Permalink
Conv1DTranspose modules can accept input with undefined batch sizes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 160485877
  • Loading branch information
Deepmind authored and diegolascasas committed Jul 3, 2017
1 parent 71048f7 commit ef86d2e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
19 changes: 11 additions & 8 deletions sonnet/python/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,6 @@ def _build(self, inputs):
of dimensions.
base.IncompatibleShapeError: If the input tensor has an unknown
`input_channels`.
base.UnderspecifiedError: If the input tensor has unknown `batch_size`.
base.IncompatibleShapeError: If `output_shape` is not an integer or
iterable of length 1.
TypeError: If input Tensor dtype is not tf.float32.
Expand All @@ -1328,11 +1327,6 @@ def _build(self, inputs):
"Number of input channels must be known at module build time")
input_channels = self._input_shape[2]

if self._input_shape[0] is None:
raise base.UnderspecifiedError(
"Batch size must be known at module build time")
batch_size = self._input_shape[0]

if self._use_default_output_shape:
self._output_shape = (
lambda: _default_transpose_size(self._input_shape[1:-1], # pylint: disable=g-long-lambda
Expand Down Expand Up @@ -1369,8 +1363,12 @@ def _build(self, inputs):
partitioner=self._partitioners.get("w", None),
regularizer=self._regularizers.get("w", None))

tf_out_shape = ((batch_size, 1,) + self._output_shape +
(self.output_channels,))
batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)
out_shape = (1, self.output_shape[0])
out_channels = (self.output_channels,)
out_shape_tuple = out_shape + out_channels
conv_output_shape = tf.convert_to_tensor(out_shape_tuple)
tf_out_shape = tf.concat([batch_size, conv_output_shape], 0)

# Add an extra dimension to the input - a height of 1.
inputs = tf.expand_dims(inputs, 1)
Expand All @@ -1392,6 +1390,11 @@ def _build(self, inputs):
# Remove the superfluous height dimension to return a 3D tensor.
outputs = tf.squeeze(outputs, [1])

# Set the tensor sizes in order for shape inference.
batch_size_value = inputs.get_shape()[0]
output_shape_value = ((batch_size_value,) + self.output_shape +
(self.output_channels,))
outputs.set_shape(output_shape_value)
return outputs

@property
Expand Down
16 changes: 10 additions & 6 deletions sonnet/python/modules/conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,9 +1403,9 @@ def testKernelsNotSpecified(self):
@parameterized.Parameters(
*zip(out_channels, kernel_shape, padding, use_bias, in_shape, out_shape,
stride_shape))
def testMissingBatchSizeError(self, out_channels, kernel_shape, padding,
use_bias, in_shape, out_shape, stride_shape):
"""Error is thrown if the batch size is unknown at build time."""
def testMissingBatchSize(self, out_channels, kernel_shape, padding,
use_bias, in_shape, out_shape, stride_shape):
"""Check functionality with unknown batch size at build time."""

conv1 = snt.Conv1DTranspose(output_channels=out_channels,
output_shape=out_shape,
Expand All @@ -1417,9 +1417,13 @@ def testMissingBatchSizeError(self, out_channels, kernel_shape, padding,

# Pass in an image with its batch size set to `None`:
image = tf.placeholder(tf.float32, shape=(None,) + in_shape[1:])
error_msg = "Batch size must be known at module build time"
with self.assertRaisesRegexp(snt.UnderspecifiedError, error_msg):
conv1(image)
output = conv1(image)
self.assertTrue(output.get_shape().is_compatible_with(
[None, out_shape, out_channels]))

with self.test_session() as sess:
tf.global_variables_initializer().run()
sess.run(output, feed_dict={image: np.zeros((10,) + in_shape[1:])})

@parameterized.Parameters(
*zip(batch_size, in_length, in_channels, out_length, out_channels,
Expand Down

0 comments on commit ef86d2e

Please sign in to comment.