Select the correct input channel and stride values within Conv{1,2,3}DTranspose.transpose() when the data_format is NC*.

Add more tests for Conv*Transpose transpose functionality.

PiperOrigin-RevId: 181724912
Deepmind authored and diegolascasas committed Jan 29, 2018
1 parent 7865f86 commit 315719e
Showing 2 changed files with 274 additions and 15 deletions.
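The fix hinges on where a data format keeps its spatial strides: channel-last formats (NWC, NHWC, NDHWC) store them at positions 1:-1 of the full stride tuple, while channel-first formats (NCW, NCHW, NCDHW) store them at positions 2:. Below is a minimal standalone sketch of that slicing rule; `spatial_stride` is a hypothetical helper name for illustration, not Sonnet code.

# Standalone illustration of the stride-slicing rule applied in this commit;
# `spatial_stride` is a hypothetical helper, not part of Sonnet.
def spatial_stride(full_stride, data_format):
  """Returns the spatial strides from a full per-dimension stride tuple."""
  if data_format.endswith("C"):   # channel-last: NWC, NHWC, NDHWC
    return full_stride[1:-1]
  else:                           # channel-first: NCW, NCHW, NCDHW
    return full_stride[2:]

assert spatial_stride((1, 2, 3, 1), "NHWC") == (2, 3)
assert spatial_stride((1, 1, 2, 3), "NCHW") == (2, 3)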
44 changes: 36 additions & 8 deletions sonnet/python/modules/conv.py
@@ -629,6 +629,12 @@ def input_shape(self):
self._ensure_is_connected()
return self._input_shape

@property
def input_channels(self):
"""Returns the number of input channels."""
self._ensure_is_connected()
return self._input_channels

def clone(self, name=None):
"""Returns a cloned `_ConvND` module.
@@ -1110,6 +1116,12 @@ def input_shape(self):
self._ensure_is_connected()
return self._input_shape

@property
def input_channels(self):
"""Returns the number of input channels."""
self._ensure_is_connected()
return self._input_channels


class Conv1D(_ConvND, base.Transposable):
"""1D convolution module, including optional bias.
@@ -1337,12 +1349,17 @@ def transpose(self, name=None):
Returns:
`Conv1D` module.
"""

if name is None:
name = self.module_name + "_transpose"
return Conv1D(output_channels=lambda: self.input_shape[-1],

if self._data_format == DATA_FORMAT_NWC:
stride = self.stride[1:-1]
else: # self._data_format == DATA_FORMAT_NCW
stride = self.stride[2:]

return Conv1D(output_channels=lambda: self.input_channels,
kernel_shape=self.kernel_shape,
stride=(self._stride[2],),
stride=stride,
padding=self.padding,
use_bias=self._use_bias,
initializers=self.initializers,
@@ -1675,9 +1692,15 @@ def transpose(self, name=None):
"""
if name is None:
name = self.module_name + "_transpose"
return Conv2D(output_channels=lambda: self.input_shape[-1],

if self._data_format == DATA_FORMAT_NHWC:
stride = self.stride[1:-1]
else: # self._data_format == DATA_FORMAT_NCHW
stride = self.stride[2:]

return Conv2D(output_channels=lambda: self.input_channels,
kernel_shape=self.kernel_shape,
stride=self.stride[1:-1],
stride=stride,
padding=self.padding,
use_bias=self._use_bias,
initializers=self.initializers,
@@ -1904,12 +1927,17 @@ def __init__(self, output_channels, output_shape=None, kernel_shape=None,
# Implement Transposable interface
def transpose(self, name=None):
"""Returns transposed Conv3DTranspose module, i.e. a Conv3D module."""

if name is None:
name = self.module_name + "_transpose"
return Conv3D(output_channels=lambda: self.input_shape[-1],

if self._data_format == DATA_FORMAT_NDHWC:
stride = self.stride[1:-1]
else: # self._data_format == DATA_FORMAT_NCDHW
stride = self.stride[2:]

return Conv3D(output_channels=lambda: self.input_channels,
kernel_shape=self.kernel_shape,
stride=self.stride[1:-1],
stride=stride,
padding=self.padding,
use_bias=self._use_bias,
initializers=self.initializers,
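A minimal usage sketch of the corrected behaviour follows, mirroring the new NCHW test further down (Sonnet 1.x / TF 1.x graph mode assumed; the literal "NCHW" string is used here in place of the conv.DATA_FORMAT_NCHW constant).

import numpy as np
import sonnet as snt
import tensorflow as tf

conv2_transpose = snt.Conv2DTranspose(
    output_channels=5,
    output_shape=(5, 4),
    kernel_shape=3,
    padding=snt.VALID,
    stride=1,
    data_format="NCHW",          # channel-first layout exercised by the fix
    name="conv2_transpose")
conv2 = conv2_transpose.transpose()

# Connect with an NCHW input: (batch, channels, height, width).
x = tf.constant(np.random.randn(32, 4, 2, 3), dtype=np.float32)
conv2_transpose(x)

# After connection, the forward module reports the original input channels.
print(conv2.output_channels)     # 4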
245 changes: 238 additions & 7 deletions sonnet/python/modules/conv_test.py
@@ -1116,6 +1116,100 @@ def testInitializerMutation(self):

self.assertAllEqual(initializers, initializers_copy)

@parameterized.named_parameters(
("WithBias", True),
("WithoutBias", False))
def testTransposeNHWC(self, use_bias):
"""Test transpose for NHWC format."""

conv2_transpose = snt.Conv2DTranspose(
output_channels=5,
output_shape=(5, 4),
kernel_shape=3,
padding=snt.VALID,
stride=1,
name="conv2_transpose",
use_bias=use_bias,
data_format=conv.DATA_FORMAT_NHWC)
conv2 = conv2_transpose.transpose()

# Check kernel shapes, strides and padding match.
self.assertEqual(conv2_transpose.kernel_shape, conv2.kernel_shape)
self.assertEqual((1,) + conv2_transpose.stride[1:3] + (1,), conv2.stride)
self.assertEqual(conv2_transpose.padding, conv2.padding)

# Before conv2_transpose is connected, we cannot know how many
# `output_channels` conv2 should have.
err = "Variables in conv2_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
_ = conv2.output_channels

# After connection the number of `output_channels` is known.
batch_size = 32
in_height = 2
in_width = 3
in_channels = 4
x = tf.constant(np.random.randn(batch_size, in_height, in_width,
in_channels),
dtype=np.float32)
conv2_transpose(x)
self.assertEqual(in_channels, conv2.output_channels)

# However, even after connection, the `input_shape` of the forward
# convolution is not known until it is itself connected (i.e. it can be
# connected to a different shape input from the `output_shape` of the
# transpose convolution!)
err = "Variables in conv2_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
self.assertEqual(conv2_transpose.output_shape, conv2.input_shape)

@parameterized.named_parameters(
("WithBias", True),
("WithoutBias", False))
def testTransposeNCHW(self, use_bias):
"""Test transpose for NCHW format."""

conv2_transpose = snt.Conv2DTranspose(
output_channels=5,
output_shape=(5, 4),
kernel_shape=3,
padding=snt.VALID,
stride=1,
name="conv2_transpose",
use_bias=use_bias,
data_format=conv.DATA_FORMAT_NCHW)
conv2 = conv2_transpose.transpose()

# Check kernel shapes, strides and padding match.
self.assertEqual(conv2_transpose.kernel_shape, conv2.kernel_shape)
self.assertEqual((1, 1) + conv2_transpose.stride[2:], conv2.stride)
self.assertEqual(conv2_transpose.padding, conv2.padding)

# Before conv2_transpose is connected, we cannot know how many
# `output_channels` conv2 should have.
err = "Variables in conv2_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
_ = conv2.output_channels

# After connection the number of `output_channels` is known.
batch_size = 32
in_height = 2
in_width = 3
in_channels = 4
x = tf.constant(np.random.randn(batch_size, in_channels, in_height,
in_width),
dtype=np.float32)
conv2_transpose(x)
self.assertEqual(in_channels, conv2.output_channels)

# However, even after connection, the `input_shape` of the forward
# convolution is not known until it is itself connected (i.e. it can be
# connected to a different shape input from the `output_shape` of the
# transpose convolution!)
err = "Variables in conv2_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
self.assertEqual(conv2_transpose.output_shape, conv2.input_shape)


class Conv1DTest(parameterized.TestCase, tf.test.TestCase):

@@ -1845,9 +1939,10 @@ def testSharing(self, batch_size, in_length, in_channels, out_channels,
@parameterized.parameters(
*zip(batch_size, in_length, in_channels, out_channels, kernel_shape,
padding, use_bias, out_shape, stride_shape))
def testTranspose(self, batch_size, in_length, in_channels, out_channels,
kernel_shape, padding, use_bias, out_shape, stride_shape):
"""Test transpose."""
def testTransposeNWC(self, batch_size, in_length, in_channels, out_channels,
kernel_shape, padding, use_bias, out_shape,
stride_shape):
"""Test transpose for NWC format."""

conv1_transpose = snt.Conv1DTranspose(
output_channels=out_channels,
@@ -1856,12 +1951,13 @@ def testTranspose(self, batch_size, in_length, in_channels, out_channels,
padding=padding,
stride=stride_shape,
name="conv1_transpose",
use_bias=use_bias)
use_bias=use_bias,
data_format=conv.DATA_FORMAT_NWC)
conv1 = conv1_transpose.transpose()

# Check kernel shapes, strides and padding match.
self.assertEqual(conv1_transpose.kernel_shape, conv1.kernel_shape)
self.assertEqual((1, conv1_transpose.stride[2], 1), conv1.stride)
self.assertEqual((1, conv1_transpose.stride[1], 1), conv1.stride)
self.assertEqual(conv1_transpose.padding, conv1.padding)

# Before conv1_transpose is connected, we cannot know how many
@@ -1884,6 +1980,50 @@ def testTranspose(self, batch_size, in_length, in_channels, out_channels,
with self.assertRaisesRegexp(snt.NotConnectedError, err):
self.assertEqual(conv1_transpose.output_shape, conv1.input_shape)

@parameterized.parameters(
*zip(batch_size, in_length, in_channels, out_channels, kernel_shape,
padding, use_bias, out_shape, stride_shape))
def testTransposeNCW(self, batch_size, in_length, in_channels, out_channels,
kernel_shape, padding, use_bias, out_shape,
stride_shape):
"""Test transpose for NCW format."""

conv1_transpose = snt.Conv1DTranspose(
output_channels=out_channels,
output_shape=out_shape,
kernel_shape=kernel_shape,
padding=padding,
stride=stride_shape,
name="conv1_transpose",
use_bias=use_bias,
data_format=conv.DATA_FORMAT_NCW)
conv1 = conv1_transpose.transpose()

# Check kernel shapes, strides and padding match.
self.assertEqual(conv1_transpose.kernel_shape, conv1.kernel_shape)
self.assertEqual((1, 1, conv1_transpose.stride[2]), conv1.stride)
self.assertEqual(conv1_transpose.padding, conv1.padding)

# Before conv1_transpose is connected, we cannot know how many
# `output_channels` conv1 should have.
err = "Variables in conv1_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
conv1.output_channels # pylint: disable=pointless-statement

# After connection the number of `output_channels` is known.
x = tf.constant(np.random.randn(batch_size, in_channels, in_length),
dtype=np.float32)
conv1_transpose(x)
self.assertEqual(in_channels, conv1.output_channels)

# However, even after connection, the `input_shape` of the forward
# convolution is not known until it is itself connected (i.e. it can be
# connected to a different shape input from the `output_shape` of the
# transpose convolution!)
err = "Variables in conv1_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
self.assertEqual(conv1_transpose.output_shape, conv1.input_shape)

def testInitializerMutation(self):
"""Test that initializers are not mutated."""

@@ -3416,8 +3556,8 @@ def setUp(self):
self.kernel_shape_h = 5
self.kernel_shape_w = 7
self.stride_d = 1
self.stride_h = 1
self.stride_w = 1
self.stride_h = 2
self.stride_w = 3
self.padding = snt.SAME

self.in_shape = (self.batch_size, self.in_depth, self.in_height,
@@ -3544,6 +3684,97 @@ def testTransposition(self, use_bias):
self.assertEqual(net_transposed_output.get_shape(),
input_to_net.get_shape())

@parameterized.named_parameters(
("WithBias", True),
("WithoutBias", False))
def testTransposeNDHWC(self, use_bias):
"""Test transpose for NDHWC format."""

conv3_transpose = snt.Conv3DTranspose(
output_channels=self.out_channels,
output_shape=self.out_shape,
kernel_shape=self.kernel_shape,
padding=self.padding,
stride=self.strides,
name="conv3_transpose",
use_bias=use_bias,
data_format=conv.DATA_FORMAT_NDHWC)
conv3 = conv3_transpose.transpose()

# Check kernel shapes, strides and padding match.
self.assertEqual(conv3_transpose.kernel_shape, conv3.kernel_shape)
self.assertEqual((1,) + self.strides + (1,), conv3.stride)
self.assertEqual(conv3_transpose.padding, conv3.padding)

# Before conv3_transpose is connected, we cannot know how many
# `output_channels` conv3 should have.
err = "Variables in conv3_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
_ = conv3.output_channels

# After connection the number of `output_channels` is known.
x = tf.constant(np.random.randn(self.batch_size,
self.in_depth,
self.in_height,
self.in_width,
self.in_channels),
dtype=np.float32)
conv3_transpose(x)
self.assertEqual(self.in_channels, conv3.output_channels)

# However, even after connection, the `input_shape` of the forward
# convolution is not known until it is itself connected (i.e. it can be
# connected to a different shape input from the `output_shape` of the
# transpose convolution!)
err = "Variables in conv3_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
self.assertEqual(conv3_transpose.output_shape, conv3.input_shape)

@parameterized.named_parameters(
("WithBias", True),
("WithoutBias", False))
def testTransposeNCDHW(self, use_bias):
"""Test transpose for NCDHW format."""

conv3_transpose = snt.Conv3DTranspose(
output_channels=self.out_channels,
output_shape=self.out_shape,
kernel_shape=self.kernel_shape,
padding=self.padding,
stride=self.strides,
name="conv3_transpose",
use_bias=use_bias,
data_format=conv.DATA_FORMAT_NCDHW)
conv3 = conv3_transpose.transpose()

# Check kernel shapes, strides and padding match.
self.assertEqual(conv3_transpose.kernel_shape, conv3.kernel_shape)
self.assertEqual((1, 1) + self.strides, conv3.stride)
self.assertEqual(conv3_transpose.padding, conv3.padding)

# Before conv3_transpose is connected, we cannot know how many
# `output_channels` conv3 should have.
err = "Variables in conv3_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
_ = conv3.output_channels

# After connection the number of `output_channels` is known.
x = tf.constant(np.random.randn(self.batch_size,
self.in_channels,
self.in_depth,
self.in_height,
self.in_width),
dtype=np.float32)
conv3_transpose(x)
self.assertEqual(self.in_channels, conv3.output_channels)

# However, even after connection, the `input_shape` of the forward
# convolution is not known until it is itself connected (i.e. it can be
# connected to a different shape input from the `output_shape` of the
# transpose convolution!)
err = "Variables in conv3_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
self.assertEqual(conv3_transpose.output_shape, conv3.input_shape)

if __name__ == "__main__":
tf.test.main()
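For context on the `output_channels` half of the fix, a plain-Python illustration with hypothetical shapes (not Sonnet code): under a channel-first layout the last axis of the input shape is a spatial dimension, so the old `input_shape[-1]` no longer names the channel count, which is why transpose() now goes through the new `input_channels` property instead.

# Hypothetical shapes illustrating why input_shape[-1] breaks for NC* layouts.
nwc_shape = (32, 7, 4)   # (batch, width, channels): last axis is channels
ncw_shape = (32, 4, 7)   # (batch, channels, width): last axis is width

assert nwc_shape[-1] == 4   # old rule, correct for channel-last
assert ncw_shape[-1] == 7   # old rule picks the width under channel-first
assert ncw_shape[1] == 4    # the actual channel count, what input_channels returns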
