Skip to content

Commit

Permalink
Implement tf.nn.atrous_conv2d_transpose. Close bugs tensorflow#4668 and
Browse files Browse the repository at this point in the history
tensorflow#5300.

Change: 140759688
  • Loading branch information
tensorflower-gardener committed Dec 1, 2016
1 parent 4311645 commit 4347945
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 28 deletions.
92 changes: 64 additions & 28 deletions tensorflow/python/kernel_tests/atrous_conv2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,33 @@
import tensorflow as tf


class AtrousConv2DTest(tf.test.TestCase):

def _upsample_filters(self, filters, rate):
"""Upsamples the filters by a factor of rate along the spatial dimensions.
def _upsample_filters(filters, rate):
"""Upsamples the filters by a factor of rate along the spatial dimensions.
Args:
filters: [h, w, in_depth, out_depth]. Original filters.
rate: An int, specifying the upsampling rate.
Returns:
filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with
h_up = h + (h - 1) * (rate - 1)
w_up = w + (w - 1) * (rate - 1)
containing (rate - 1) zeros between consecutive filter values along
the filters' spatial dimensions.
"""
if rate == 1:
return filters
# [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
filters_up = np.transpose(filters, [2, 3, 0, 1])
ker = np.zeros([rate, rate], dtype=np.float32)
ker[0, 0] = 1
filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)]
# [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
filters_up = np.transpose(filters_up, [2, 3, 0, 1])
return filters_up

Args:
filters: [h, w, in_depth, out_depth]. Original filters.
rate: An int, specifying the upsampling rate.

Returns:
filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with
h_up = h + (h - 1) * (rate - 1)
w_up = w + (w - 1) * (rate - 1)
containing (rate - 1) zeros between consecutive filter values along
the filters' spatial dimensions.
"""
if rate == 1:
return filters
# [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
filters_up = np.transpose(filters, [2, 3, 0, 1])
ker = np.zeros([rate, rate])
ker[0, 0] = 1
filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)]
# [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
filters_up = np.transpose(filters_up, [2, 3, 0, 1])
self.assertEqual(np.sum(filters), np.sum(filters_up))
return filters_up
class AtrousConv2DTest(tf.test.TestCase):

def testAtrousConv2DForward(self):
with self.test_session(use_gpu=True):
Expand All @@ -65,14 +65,13 @@ def testAtrousConv2DForward(self):
f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape)

for rate in range(1, 4):
f_up = self._upsample_filters(f, rate)
f_up = _upsample_filters(f, rate)

for padding in ["SAME", "VALID"]:
y1 = tf.nn.atrous_conv2d(x, f, rate, padding=padding)
y2 = tf.nn.conv2d(x, f_up, strides=[1, 1, 1, 1],
padding=padding)
self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2,
atol=1e-2)
self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3)

def testAtrousSequence(self):
"""Tests optimization of sequence of atrous convolutions.
Expand Down Expand Up @@ -150,5 +149,42 @@ def testGradient(self):
self.assertLess(err, err_tolerance)


class AtrousConv2DTransposeTest(tf.test.TestCase):

def testAtrousConv2DTransposeForward(self):
with self.test_session(use_gpu=True):
# Input: [batch, height, width, input_depth]
height = 9
for width in [9, 10]: # Test both odd and even width.
x_shape = [2, height, width, 2]
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)

# Filter: [kernel_height, kernel_width, input_depth, output_depth]
for kernel_height in range(1, 4):
for kernel_width in range(1, 4):
f_shape = [kernel_height, kernel_width, 2, 2]
f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape)

for rate in range(1, 4):
f_up = _upsample_filters(f, rate)
kernel_height_up = (kernel_height +
(kernel_height - 1) * (rate - 1))
kernel_width_up = kernel_width + (kernel_width - 1) * (rate - 1)

for padding in ["SAME", "VALID"]:
if padding == "SAME":
y_shape = [2, height, width, 2]
else:
y_shape = [2,
height + kernel_height_up - 1,
width + kernel_width_up - 1,
2]

y1 = tf.nn.atrous_conv2d_transpose(x, f, y_shape, rate, padding)
y2 = tf.nn.conv2d_transpose(
x, f_up, y_shape, strides=[1, 1, 1, 1], padding=padding)
self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
tf.test.main()
1 change: 1 addition & 0 deletions tensorflow/python/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
@@depthwise_conv2d_native
@@separable_conv2d
@@atrous_conv2d
@@atrous_conv2d_transpose
@@conv2d_transpose
@@conv1d
@@conv3d
Expand Down
145 changes: 145 additions & 0 deletions tensorflow/python/ops/nn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,151 @@ def conv2d_transpose(value,
name=name)


def atrous_conv2d_transpose(value,
filters,
output_shape,
rate,
padding,
name=None):
"""The transpose of `atrous_conv2d`.
This operation is sometimes called "deconvolution" after [Deconvolutional
Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
actually the transpose (gradient) of `atrous_conv2d` rather than an actual
deconvolution.
Args:
value: A 4-D `Tensor` of type `float`. It needs to be in the default `NHWC`
format. Its shape is `[batch, in_height, in_width, in_channels]`.
filters: A 4-D `Tensor` with the same type as `value` and shape
`[filter_height, filter_width, out_channels, in_channels]`. `filters`'
`in_channels` dimension must match that of `value`. Atrous convolution is
equivalent to standard convolution with upsampled filters with effective
height `filter_height + (filter_height - 1) * (rate - 1)` and effective
width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
inserting `rate - 1` zeros along consecutive elements across the
`filters`' spatial dimensions.
output_shape: A 1-D `Tensor` of shape representing the output shape of the
deconvolution op.
rate: A positive int32. The stride with which we sample input values across
the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
name: Optional name for the returned tensor.
Returns:
A `Tensor` with the same type as `value`.
Raises:
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`, or if the `rate` is less
than one, or if the output_shape is not a tensor with 4 elements.
"""
with ops.name_scope(name, "atrous_conv2d_transpose",
[value, filters, output_shape]) as name:
value = ops.convert_to_tensor(value, name="value")
filters = ops.convert_to_tensor(filters, name="filters")
if not value.get_shape()[3].is_compatible_with(filters.get_shape()[3]):
raise ValueError(
"value's input channels does not match filters' input channels, "
"{} != {}".format(value.get_shape()[3], filters.get_shape()[3]))
if rate < 1:
raise ValueError("rate {} cannot be less than one".format(rate))

if rate == 1:
return conv2d_transpose(value,
filters,
output_shape,
strides=[1, 1, 1, 1],
padding=padding,
data_format="NHWC")

output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)):
raise ValueError("output_shape must have shape (4,), got {}"
.format(output_shape_.get_shape()))

if isinstance(output_shape, (list, np.ndarray)):
# output_shape's shape should be == [4] if reached this point.
if not filters.get_shape()[2].is_compatible_with(output_shape[3]):
raise ValueError(
"output_shape does not match filter's output channels, "
"{} != {}".format(output_shape[3], filters.get_shape()[2]))

# We have two padding contributions. The first is used for converting "SAME"
# to "VALID". The second is required so that the height and width of the
# zero-padded value tensor are multiples of rate.

# Padding required to reduce to "VALID" convolution
if padding == "SAME":
# Handle filters whose shape is unknown during graph creation.
if filters.get_shape().is_fully_defined():
filter_shape = filters.get_shape().as_list()
else:
filter_shape = array_ops.shape(filters)
filter_height, filter_width = filter_shape[0], filter_shape[1]

# Spatial dimensions of the filters and the upsampled filters in which we
# introduce (rate - 1) zeros between consecutive filter values.
filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
filter_width_up = filter_width + (filter_width - 1) * (rate - 1)

pad_height = filter_height_up - 1
pad_width = filter_width_up - 1

# When pad_height (pad_width) is odd, we pad more to bottom (right),
# following the same convention as conv2d().
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top
pad_left = pad_width // 2
pad_right = pad_width - pad_left
elif padding == "VALID":
pad_top = 0
pad_bottom = 0
pad_left = 0
pad_right = 0
else:
raise ValueError("padding must be either VALID or SAME:"
" {}".format(padding))

in_height = output_shape[1] + pad_top + pad_bottom
in_width = output_shape[2] + pad_left + pad_right

# More padding so that rate divides the height and width of the input.
pad_bottom_extra = (rate - in_height % rate) % rate
pad_right_extra = (rate - in_width % rate) % rate

# The paddings argument to space_to_batch is just the extra padding
# component.
space_to_batch_pad = [[0, pad_bottom_extra], [0, pad_right_extra]]

value = array_ops.space_to_batch(input=value,
paddings=space_to_batch_pad,
block_size=rate)

input_sizes = [rate * rate * output_shape[0],
(in_height + pad_bottom_extra) // rate,
(in_width + pad_right_extra) // rate,
output_shape[3]]

value = gen_nn_ops.conv2d_backprop_input(input_sizes=input_sizes,
filter=filters,
out_backprop=value,
strides=[1, 1, 1, 1],
padding="VALID",
data_format="NHWC")

# The crops argument to batch_to_space includes both padding components.
batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra],
[pad_left, pad_right + pad_right_extra]]

return array_ops.batch_to_space(input=value,
crops=batch_to_space_crop,
block_size=rate)


def conv3d_transpose(value,
filter,
output_shape,
Expand Down

0 comments on commit 4347945

Please sign in to comment.