Refactor out inputs verification.
Remove duplicate code that performs the same checks on the `inputs` tensor in every Convolution class.

PiperOrigin-RevId: 183230711
Deepmind authored and diegolascasas committed Jan 29, 2018
1 parent 2a8e6e9 commit 438b320
Showing 1 changed file with 39 additions and 64 deletions.
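For readers who want the shape of the change without reading the full diff, below is a minimal, self-contained sketch of the pattern this commit introduces. It is illustrative only: the `_ExampleConv` class and the stand-in error classes are not part of Sonnet; the real helper and error types live in sonnet/python/modules/conv.py and sonnet/python/modules/base.py.

import tensorflow as tf  # TF 1.x, as Sonnet used at the time


class IncompatibleShapeError(ValueError):
  """Stand-in for Sonnet's base.IncompatibleShapeError."""


class UnderspecifiedError(ValueError):
  """Stand-in for Sonnet's base.UnderspecifiedError."""


def _verify_inputs(inputs, channel_index, data_format):
  """Single home for the checks each Convolution class used to repeat."""
  input_shape = tuple(inputs.get_shape().as_list())
  # Rank must match the data_format string, e.g. 4 for "NHWC".
  if len(input_shape) != len(data_format):
    raise IncompatibleShapeError(
        "Expected rank {} for data_format {}, got shape {}.".format(
            len(data_format), data_format, input_shape))
  # Only float16/float32 inputs are supported.
  if not (tf.float16.is_compatible_with(inputs.dtype) or
          tf.float32.is_compatible_with(inputs.dtype)):
    raise TypeError(
        "Expected dtype tf.float16 or tf.float32, got {}.".format(inputs.dtype))
  # The channel count must be statically known so weight shapes can be built.
  if input_shape[channel_index] is None:
    raise UnderspecifiedError("Number of input channels must be known.")


class _ExampleConv(object):
  """Toy module: _build() delegates all input validation to the helper."""

  def __init__(self, data_format="NHWC", channel_index=3):
    self._data_format = data_format
    self._channel_index = channel_index

  def _build(self, inputs):
    # Before this commit, these checks were written out inline in every class.
    _verify_inputs(inputs, self._channel_index, self._data_format)
    self._input_shape = tuple(inputs.get_shape().as_list())
    self._input_channels = self._input_shape[self._channel_index]
    return inputs  # a real module would construct weights and convolve here

The diff below makes exactly this move: `_verify_inputs_dtype` is generalised into `_verify_inputs`, and each `_build` replaces its inline rank/dtype/channel checks with a single call.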
103 changes: 39 additions & 64 deletions sonnet/python/modules/conv.py
@@ -178,14 +178,43 @@ def _fill_and_one_pad_stride(stride, n, data_format=DATA_FORMAT_NHWC):
" positive integers of size {}".format(stride, type(stride), n))


def _verify_inputs_dtype(inputs):
"""Verifies that the inputs are of a supported floating point type."""
def _verify_inputs(inputs, channel_index, data_format):
  """Verifies `inputs` is semantically correct.

  Args:
    inputs: An input tensor provided by the user.
    channel_index: The index of the channel dimension.
    data_format: The format of the data in `inputs`.

  Raises:
    base.IncompatibleShapeError: If the shape of `inputs` doesn't match
      `data_format`.
    base.UnderspecifiedError: If the channel dimension of `inputs` isn't
      defined.
    TypeError: If input Tensor dtype is not compatible with either
      `tf.float16` or `tf.float32`.
  """
  # Check shape.
  input_shape = tuple(inputs.get_shape().as_list())
  if len(input_shape) != len(data_format):
    raise base.IncompatibleShapeError((
        "Input Tensor must have rank {} corresponding to "
        "data_format {}, but instead was {}.").format(
            len(data_format), data_format, input_shape))

  # Check type.
  if not (tf.float16.is_compatible_with(inputs.dtype) or
          tf.float32.is_compatible_with(inputs.dtype)):
    raise TypeError(
        "Input must have dtype tf.float16 or tf.float32, but dtype was {}"
        .format(inputs.dtype))

  # Check channel dim.
  input_channels = input_shape[channel_index]
  if input_channels is None:
    raise base.UnderspecifiedError(
        "Number of input channels must be known at module build time")


def create_weight_initializer(fan_in_shape, dtype=tf.float32):
"""Returns a default initializer for the weights of a convolutional module."""
@@ -396,20 +425,9 @@ def _build(self, inputs):
TypeError: If input Tensor dtype is not compatible with either
`tf.float16` or `tf.float32`.
"""
# Handle input whose shape is unknown during graph creation.
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
if len(self._input_shape) != len(self._data_format):
raise base.IncompatibleShapeError((
"Input Tensor must have rank {} corresponding to "
"data_format {}, but instead was {}.").format(
len(self._data_format), self._data_format, self._input_shape))

self._input_channels = self._input_shape[self._channel_index]
if self._input_channels is None:
raise base.UnderspecifiedError(
"Number of input channels must be known at module build time")

_verify_inputs_dtype(inputs)

self._w = self._construct_w(inputs)

@@ -833,20 +851,9 @@ def _build(self, inputs):
TypeError: If input Tensor dtype is not compatible with either
`tf.float16` or `tf.float32`.
"""
# Handle input whose shape is unknown during graph creation.
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
if len(self._input_shape) != len(self._data_format):
raise base.IncompatibleShapeError((
"Input Tensor must have rank {} corresponding to "
"data_format {}, but instead was {}.").format(
len(self._data_format), self._data_format, self._input_shape))

self._input_channels = self._input_shape[self._channel_index]
if self._input_channels is None:
raise base.UnderspecifiedError(
"Number of input channels must be known at module build time")

_verify_inputs_dtype(inputs)

# First, figure out what the non-(N,C) dims will be.
if self._use_default_output_shape:
@@ -2043,8 +2050,8 @@ def __init__(self, kernel_shape, stride=1, padding=SAME, use_bias=True,
self._regularizers = util.check_regularizers(
regularizers, self.possible_keys)

self._input_shape = None # Determined in build() from the input.
self._input_channels = None # Determined in build() from the input.
self._data_format = "NHWC"
self._channel_index = 3

@classmethod
def get_possible_initializer_keys(cls, use_bias=True):
@@ -2074,19 +2081,9 @@ def _build(self, inputs):
TypeError: If input Tensor dtype is not compatible with either
`tf.float16` or `tf.float32`.
"""
# Handle input whose shape is unknown during graph creation.
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
if len(self._input_shape) != 4:
raise base.IncompatibleShapeError((
"Input Tensor must have rank 4, but instead was {}.").format(
self._input_shape))

self._input_channels = self._input_shape[-1]
if self._input_channels is None:
raise base.UnderspecifiedError(
"Number of input channels must be known at module build time")

_verify_inputs_dtype(inputs)
self._input_channels = self._input_shape[self._channel_index]

weight_shape = (
self._kernel_shape[0],
@@ -2366,20 +2363,9 @@ def _build(self, inputs):
TypeError: If input Tensor dtype is not compatible with either
`tf.float16` or `tf.float32`.
"""
# Handle input whose shape is unknown during graph creation.
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
if len(self._input_shape) != len(self._data_format):
raise base.IncompatibleShapeError((
"Input Tensor must have rank {} corresponding to "
"data_format {}, but instead was {}.").format(
len(self._data_format), self._data_format, self._input_shape))

self._input_channels = self._input_shape[self._channel_index]
if self._input_channels is None:
raise base.UnderspecifiedError(
"Number of input channels must be known at module build time")

_verify_inputs_dtype(inputs)

# For depthwise conv, output_channels = in_channels * channel_multiplier.
# By default, depthwise conv applies a different filter to every input
@@ -2676,20 +2662,9 @@ def _build(self, inputs):
TypeError: If input Tensor dtype is not compatible with either
`tf.float16` or `tf.float32`.
"""
# Handle input whose shape is unknown during graph creation.
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
if len(self._input_shape) != len(self._data_format):
raise base.IncompatibleShapeError((
"Input Tensor must have rank {} corresponding to "
"data_format {}, but instead was {}.").format(
len(self._data_format), self._data_format, self._input_shape))

self._input_channels = self._input_shape[self._channel_index]
if self._input_channels is None:
raise base.UnderspecifiedError(
"Number of input channels must be known at module build time")

_verify_inputs_dtype(inputs)

depthwise_weight_shape = (self._kernel_shape[0], self._kernel_shape[1],
self._input_channels, self._channel_multiplier)
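As a quick behavioural check of the consolidated helper, it can be exercised directly. Illustration only: `_verify_inputs` is a private function, the import path assumes a Sonnet checkout that includes this commit, and the placeholder shapes are arbitrary. The calls that would raise are commented out so the snippet runs end to end.

import tensorflow as tf
from sonnet.python.modules import conv

# Rank-4 NHWC input with a known channel count: all three checks pass.
ok = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])
conv._verify_inputs(ok, channel_index=3, data_format="NHWC")

# Rank mismatch with data_format -> base.IncompatibleShapeError.
bad_rank = tf.placeholder(tf.float32, shape=[None, 28, 28])
# conv._verify_inputs(bad_rank, channel_index=3, data_format="NHWC")

# Statically unknown channel dimension -> base.UnderspecifiedError.
no_channels = tf.placeholder(tf.float32, shape=[None, 28, 28, None])
# conv._verify_inputs(no_channels, channel_index=3, data_format="NHWC")

# Unsupported dtype -> TypeError.
ints = tf.placeholder(tf.int32, shape=[None, 28, 28, 3])
# conv._verify_inputs(ints, channel_index=3, data_format="NHWC")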
