Enable ConvND modules to use float64 tensors.
PiperOrigin-RevId: 252046103
Deepmind authored and fvioladm committed Jul 5, 2019
1 parent f6577ed commit c1d28b8
Showing 2 changed files with 56 additions and 55 deletions.
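In practical terms, ConvND modules previously rejected double-precision inputs in `_verify_inputs`; after this commit they connect cleanly. A minimal sketch of the new behavior, assuming Sonnet 1.x and TF1-style sessions (shapes and module arguments are illustrative, not from the commit):

import numpy as np
import sonnet as snt
import tensorflow as tf

# Hypothetical NHWC input; the dtype is the point of this commit.
inputs = tf.constant(np.random.randn(2, 8, 8, 3), dtype=tf.float64)

# Before this change, connecting a float64 tensor raised:
#   TypeError: Input must have dtype tf.float16, tf.bfloat16 or tf.float32, ...
conv = snt.Conv2D(output_channels=16, kernel_shape=3)
outputs = conv(inputs)  # Variables are created with the input's dtype.

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(outputs).dtype)  # float64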
53 changes: 27 additions & 26 deletions sonnet/python/modules/conv.py
@@ -267,7 +267,7 @@ def _verify_inputs(inputs, channel_index, data_format):
base.UnderspecifiedError: If the channel dimension of `inputs` isn't
defined.
TypeError: If input Tensor dtype is not compatible with either
- `tf.float16`, `tf.bfloat16` or `tf.float32`.
+ `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
"""
# Check shape.
input_shape = tuple(inputs.get_shape().as_list())
@@ -280,10 +280,11 @@ def _verify_inputs(inputs, channel_index, data_format):
# Check type.
if not (tf.float16.is_compatible_with(inputs.dtype) or
tf.bfloat16.is_compatible_with(inputs.dtype) or
-           tf.float32.is_compatible_with(inputs.dtype)):
+           tf.float32.is_compatible_with(inputs.dtype) or
+           tf.float64.is_compatible_with(inputs.dtype)):
      raise TypeError(
-         "Input must have dtype tf.float16, tf.bfloat16 or tf.float32, "
-         "but dtype was {}".format(inputs.dtype))
+         "Input must have dtype tf.float16, tf.bfloat16, tf.float32 or "
+         "tf.float64, but dtype was {}".format(inputs.dtype))

# Check channel dim.
input_channels = input_shape[channel_index]
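Note that the guard uses `tf.DType.is_compatible_with` rather than equality, so reference dtypes (e.g. `tf.float32_ref`) also pass. A standalone sketch of the same pattern (illustrative helper, not the module's exact code path):

import tensorflow as tf

_ALLOWED = (tf.float16, tf.bfloat16, tf.float32, tf.float64)

def check_conv_input_dtype(inputs):  # hypothetical helper
  if not any(d.is_compatible_with(inputs.dtype) for d in _ALLOWED):
    raise TypeError(
        "Input must have dtype tf.float16, tf.bfloat16, tf.float32 or "
        "tf.float64, but dtype was {}".format(inputs.dtype))

check_conv_input_dtype(tf.zeros([1, 5, 5, 1], dtype=tf.float64))  # now passes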
@@ -528,7 +529,7 @@ def _build(self, inputs):
Args:
inputs: A ND Tensor of the same rank as `data_format`, and either of types
- `tf.float16`, `tf.bfloat16` or `tf.float32`.
+ `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
A ND Tensor of shape [batch_size, output_dim_1, output_dim_2, ...,
@@ -545,7 +546,7 @@ def _build(self, inputs):
base.IncompatibleShapeError: If a mask is present and its shape is
incompatible with the shape of the weights.
TypeError: If input Tensor dtype is not compatible with either
- `tf.float16`, `tf.bfloat16` or `tf.float32`.
+ `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
"""
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
@@ -586,7 +587,7 @@ def _pad_input(self, inputs):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
inputs: The `inputs` argument that has had any required padding added.
@@ -630,7 +631,7 @@ def _apply_conv(self, inputs, w):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
w: A weight matrix of the same type as `inputs`.
Returns:
Expand All @@ -649,7 +650,7 @@ def _construct_w(self, inputs):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
w: A weight matrix of the same type as `inputs`.
@@ -1001,11 +1002,11 @@ def _build(self, inputs):
Args:
inputs: A Tensor of shape `data_format` and of type
- `tf.float16`, `tf.bfloat16` or `tf.float32`.
+ `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
- A Tensor of shape `data_format` and of type `tf.float16`, `tf.bfloat16`
- or `tf.float32`.
+ A Tensor of shape `data_format` and of type `tf.float16`, `tf.bfloat16`,
+ `tf.float32` or `tf.float64`.
Raises:
ValueError: If connecting the module into the graph any time after the
@@ -1018,7 +1019,7 @@ def _build(self, inputs):
base.IncompatibleShapeError: If `output_shape` is an iterable and is not
in the format `(out_height, out_width)`.
TypeError: If input Tensor dtype is not compatible with either
- `tf.float16`, `tf.bfloat16` or `tf.float32`.
+ `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
"""
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
@@ -1099,7 +1100,7 @@ def _construct_w(self, inputs):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
w: A weight matrix of the same type as `inputs`.
@@ -1130,7 +1131,7 @@ def _infer_all_output_dims(self, inputs):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
output_shape: A tensor of shape (`batch_size`, `conv_output_shape`).
@@ -1164,10 +1165,10 @@ def _recover_shape_information(self, inputs, outputs):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
outputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`. The output of `inputs` from a transpose
- convolution op.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`. The output of `inputs`
+ from a transpose convolution op.
Returns:
outputs: The passed-in `outputs` with all shape information filled in.
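For context, `_recover_shape_information` exists because transpose convolutions built with a dynamic `output_shape` tensor lose most static shape information; the standard remedy is `set_shape`. A rough sketch of the idea, with illustrative names and shapes:

import tensorflow as tf

x = tf.placeholder(tf.float64, [None, 4, 4, 8])
w = tf.get_variable("w", [3, 3, 16, 8], dtype=tf.float64)
dynamic_shape = tf.stack([tf.shape(x)[0], 8, 8, 16])
y = tf.nn.conv2d_transpose(x, w, dynamic_shape, strides=[1, 2, 2, 1])
print(y.get_shape())           # largely unknown: (?, ?, ?, ?)
y.set_shape([None, 8, 8, 16])  # restore what is statically known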
@@ -2274,7 +2275,7 @@ def _construct_w(self, inputs):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
w: A weight matrix of the same type as `inputs` and of shape
@@ -2299,7 +2300,7 @@ def _apply_conv(self, inputs, w):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
w: A weight matrix of the same type as `inputs`.
Returns:
@@ -2440,7 +2441,7 @@ def _construct_w(self, inputs):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
w: A weight matrix of the same type as `inputs` and of shape
@@ -2471,7 +2472,7 @@ def _apply_conv(self, inputs, w):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
w: A weight matrix of the same type as `inputs`.
Returns:
@@ -2634,7 +2635,7 @@ def _construct_w(self, inputs):
Args:
inputs: A 4D Tensor of shape:
[batch_size, input_height, input_width, input_channels]
- and of type `tf.float16`, `tf.bfloat16` or `tf.float32`.
+ and of type `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
A tuple of two 4D Tensors, each with the same dtype as `inputs`:
@@ -2681,7 +2682,7 @@ def _apply_conv(self, inputs, w):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
w: A tuple of weight matrices of the same type as `inputs`, the first
being the depthwise weight matrix, and the second being the pointwise
weight matrix.
@@ -2860,7 +2861,7 @@ def _construct_w(self, inputs):
Args:
inputs: A 4D Tensor of shape:
[batch_size, input_height, input_width, input_channels]
- and of type `tf.float16`, `tf.bfloat16` or `tf.float32`.
+ and of type `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
A tuple of two 4D Tensors, each with the same dtype as `inputs`:
@@ -2907,7 +2908,7 @@ def _apply_conv(self, inputs, w):
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
- `tf.bfloat16` or `tf.float32`.
+ `tf.bfloat16`, `tf.float32` or `tf.float64`.
w: A tuple of weight matrices of the same type as `inputs`, the first
being the depthwise weight matrix, and the second being the pointwise
weight matrix.
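The tuple structure mirrors `tf.nn.separable_conv2d`, which takes the depthwise and pointwise filters separately. A hedged sketch of the shapes involved (3x3 depthwise kernel, 4 input channels, channel multiplier 2, 8 output channels; all names illustrative):

import tensorflow as tf

w_depthwise = tf.get_variable(
    "w_dw", shape=[3, 3, 4, 2], dtype=tf.float64)      # [h, w, in_ch, multiplier]
w_pointwise = tf.get_variable(
    "w_pw", shape=[1, 1, 4 * 2, 8], dtype=tf.float64)  # [1, 1, in_ch*multiplier, out_ch]

x = tf.zeros([1, 5, 5, 4], dtype=tf.float64)
y = tf.nn.separable_conv2d(x, w_depthwise, w_pointwise,
                           strides=[1, 1, 1, 1], padding="SAME")
print(y.get_shape())  # (1, 5, 5, 8)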
58 changes: 29 additions & 29 deletions sonnet/python/modules/conv_test.py
@@ -29,7 +29,7 @@
from sonnet.python.modules import conv
import tensorflow as tf

- from tensorflow.python.ops import variables
+ from tensorflow.python.ops import variables  # pylint: disable=g-direct-tensorflow-import


def create_constant_initializers(w, b, use_bias):
@@ -281,9 +281,9 @@ def testPartitioners(self, module, num_input_dims, module_kwargs):

self.assertEqual(convolution_t.partitioners, convolution.partitioners)

- @parameterized.parameters(*itertools.product(modules,
-                                               (True, False),
-                                               (tf.float16, tf.float32)))
+ @parameterized.parameters(
+     *itertools.product(modules, (True, False),
+                        (tf.float16, tf.float32, tf.float64)))
def testVariables(self, module_info, use_bias, dtype):
"""The correct number of variables are created."""
module, num_input_dims, module_kwargs = module_info
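Reformatting aside, the substantive change is the third factor: each dtype multiplies the parameterized cases. A small standalone illustration of the expansion (placeholder module names):

import itertools

modules = ("Conv1D", "Conv2D", "Conv3D")  # stand-ins for the test's real tuple
dtypes = ("float16", "float32", "float64")

cases = list(itertools.product(modules, (True, False), dtypes))
print(len(cases))  # 18: every (module, use_bias, dtype) combination runs once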
@@ -619,7 +619,7 @@ def testInputTypeError(self, use_bias):
initializers=create_constant_initializers(
1.0, 1.0, use_bias))

- for dtype in (tf.uint32, tf.float64):
+ for dtype in (tf.uint32, tf.uint64):
x = tf.constant(np.ones([1, 5, 5, 1]), dtype=dtype)
err = "Input must have dtype tf.float.*"
with self.assertRaisesRegexp(TypeError, err):
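Since `tf.float64` is now a legal input dtype, it can no longer serve as the negative example, so the tests substitute `tf.uint64`. A standalone version of the check might look like this (module arguments illustrative):

import numpy as np
import sonnet as snt
import tensorflow as tf

conv = snt.Conv2D(output_channels=1, kernel_shape=3)
for dtype in (tf.uint32, tf.uint64):
  x = tf.constant(np.ones([1, 5, 5, 1]), dtype=dtype)
  try:
    conv(x)  # _verify_inputs rejects the dtype before creating variables
  except TypeError as e:
    print(e)  # Input must have dtype tf.float16, ... but dtype was ...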
@@ -879,9 +879,9 @@ def testClone(self):
self.assertEqual(net.output_channels, clone1.output_channels)
self.assertEqual(net.module_name + "_clone", clone1.module_name)
self.assertEqual("clone2", clone2.module_name)
- self.assertEqual(len(all_vars), 3*len(net_vars))
- self.assertEqual(len(net_vars), len(clone1_vars))
- self.assertEqual(len(net_vars), len(clone2_vars))
+ self.assertLen(all_vars, 3*len(net_vars))
+ self.assertLen(net_vars, len(clone1_vars))
+ self.assertLen(net_vars, len(clone2_vars))
self.assertEqual(net_out.get_shape().as_list(),
clone1_out.get_shape().as_list())
self.assertEqual(net_out.get_shape().as_list(),
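`assertLen` comes from absl's test base classes (which `tf.test.TestCase` extends) and gives a clearer failure message, reporting the container rather than two bare integers. A minimal sketch:

from absl.testing import absltest

class LenExample(absltest.TestCase):

  def test_len(self):
    net_vars = ["w", "b"]
    # Equivalent to assertEqual(len(net_vars), 2), with better error output.
    self.assertLen(net_vars, 2)

if __name__ == "__main__":
  absltest.main()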
@@ -1254,7 +1254,7 @@ def testTransposeNHWC(self, use_bias, use_output_shape):
# transpose convolution!)
err = "Variables in conv2_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
- self.assertEqual(conv2_transpose.output_shape, conv2.input_shape)
+ _ = conv2.input_shape
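Note the reason for this change: inside `assertRaisesRegexp`, evaluating `conv2.input_shape` raises `snt.NotConnectedError` before any comparison can run, so the old `assertEqual` never actually compared anything. The bare property access states the test's intent directly. The same fix is applied to the other transpose tests below.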

@parameterized.named_parameters(
("WithBiasWithOutputShape", True, True),
@@ -1307,7 +1307,7 @@ def testTransposeNCHW(self, use_bias, use_output_shape):
# transpose convolution!)
err = "Variables in conv2_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
- self.assertEqual(conv2_transpose.output_shape, conv2.input_shape)
+ _ = conv2.input_shape


class Conv1DTest(parameterized.TestCase, tf.test.TestCase):
@@ -1481,7 +1481,7 @@ def testInputTypeError(self, use_bias):
initializers=create_constant_initializers(
1.0, 1.0, use_bias))

- for dtype in (tf.uint32, tf.float64):
+ for dtype in (tf.uint32, tf.uint64):
x = tf.constant(np.ones([1, 5, 1]), dtype=dtype)
err = "Input must have dtype tf.float.*"
with self.assertRaisesRegexp(TypeError, err):
@@ -1763,9 +1763,9 @@ def testClone(self):
self.assertEqual(net.output_channels, clone1.output_channels)
self.assertEqual(net.module_name + "_clone", clone1.module_name)
self.assertEqual("clone2", clone2.module_name)
- self.assertEqual(len(all_vars), 3*len(net_vars))
- self.assertEqual(len(net_vars), len(clone1_vars))
- self.assertEqual(len(net_vars), len(clone2_vars))
+ self.assertLen(all_vars, 3*len(net_vars))
+ self.assertLen(net_vars, len(clone1_vars))
+ self.assertLen(net_vars, len(clone2_vars))
self.assertEqual(net_out.get_shape().as_list(),
clone1_out.get_shape().as_list())
self.assertEqual(net_out.get_shape().as_list(),
@@ -1973,7 +1973,7 @@ def testInputTypeError(self, batch_size, in_length, in_channels, out_channels,
name="conv1",
use_bias=use_bias)

- for dtype in (tf.uint32, tf.float64):
+ for dtype in (tf.uint32, tf.uint64):
x = tf.constant(np.ones([batch_size, in_length,
in_channels]), dtype=dtype)
err = "Input must have dtype tf.float.*"
@@ -2065,7 +2065,7 @@ def testTransposeNWC(self, batch_size, in_length, in_channels, out_channels,
# transpose convolution!)
err = "Variables in conv1_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
- self.assertEqual(conv1_transpose.output_shape, conv1.input_shape)
+ _ = conv1.input_shape

@parameterized.parameters(
*zip(batch_size, in_length, in_channels, out_channels, kernel_shape,
@@ -2112,7 +2112,7 @@ def testTransposeNCW(self, batch_size, in_length, in_channels, out_channels,
# transpose convolution!)
err = "Variables in conv1_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
- self.assertEqual(conv1_transpose.output_shape, conv1.input_shape)
+ _ = conv1.input_shape

def testInitializerMutation(self):
"""Test that initializers are not mutated."""
@@ -2281,9 +2281,9 @@ def testClone(self):
self.assertEqual(net.output_channels, clone1.output_channels)
self.assertEqual(net.module_name + "_clone", clone1.module_name)
self.assertEqual("clone2", clone2.module_name)
- self.assertEqual(len(all_vars), 3*len(net_vars))
- self.assertEqual(len(net_vars), len(clone1_vars))
- self.assertEqual(len(net_vars), len(clone2_vars))
+ self.assertLen(all_vars, 3*len(net_vars))
+ self.assertLen(net_vars, len(clone1_vars))
+ self.assertLen(net_vars, len(clone2_vars))
self.assertEqual(net_out.get_shape().as_list(),
clone1_out.get_shape().as_list())
self.assertEqual(net_out.get_shape().as_list(),
@@ -2481,7 +2481,7 @@ def testInputTypeError(self, use_bias):
use_bias=use_bias,
initializers=create_constant_initializers(1.0, 1.0, use_bias))

- for dtype in (tf.uint32, tf.float64):
+ for dtype in (tf.uint32, tf.uint64):
x = tf.constant(np.ones([1, 5, 5, 1]), dtype=dtype)
err = "Input must have dtype tf.float.*"
with self.assertRaisesRegexp(TypeError, err):
@@ -2820,7 +2820,7 @@ def testInputTypeError(self, use_bias):
initializers=create_separable_constant_initializers(
1.0, 1.0, 1.0, use_bias))

- for dtype in (tf.uint32, tf.float64):
+ for dtype in (tf.uint32, tf.uint64):
x = tf.constant(np.ones([1, 5, 5, 1]), dtype=dtype)
err = "Input must have dtype tf.float.*"
with self.assertRaisesRegexp(TypeError, err):
@@ -3251,7 +3251,7 @@ def testInputTypeError(self, use_bias):
initializers=create_separable_constant_initializers(
1.0, 1.0, 1.0, use_bias))

- for dtype in (tf.uint32, tf.float64):
+ for dtype in (tf.uint32, tf.uint64):
x = tf.constant(np.ones([1, 5, 1]), dtype=dtype)
err = "Input must have dtype tf.float.*"
with self.assertRaisesRegexp(TypeError, err):
@@ -3681,7 +3681,7 @@ def testInputTypeError(self):
"b": tf.constant_initializer(1.0),
})

- for dtype in (tf.uint32, tf.float64):
+ for dtype in (tf.uint32, tf.uint64):
x = tf.constant(np.ones([1, 5, 5, 5, 1]), dtype=dtype)
self.assertRaisesRegexp(TypeError, "Input must have dtype tf.float.*",
conv1, x)
@@ -4067,9 +4067,9 @@ def testClone(self):
self.assertEqual(net.output_channels, clone1.output_channels)
self.assertEqual(net.module_name + "_clone", clone1.module_name)
self.assertEqual("clone2", clone2.module_name)
- self.assertEqual(len(all_vars), 3*len(net_vars))
- self.assertEqual(len(net_vars), len(clone1_vars))
- self.assertEqual(len(net_vars), len(clone2_vars))
+ self.assertLen(all_vars, 3*len(net_vars))
+ self.assertLen(net_vars, len(clone1_vars))
+ self.assertLen(net_vars, len(clone2_vars))
self.assertEqual(net_out.get_shape().as_list(),
clone1_out.get_shape().as_list())
self.assertEqual(net_out.get_shape().as_list(),
@@ -4294,7 +4294,7 @@ def testTransposeNDHWC(self, use_bias):
# transpose convolution!)
err = "Variables in conv3_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
- self.assertEqual(conv3_transpose.output_shape, conv3.input_shape)
+ _ = conv3.input_shape

@parameterized.named_parameters(
("WithBias", True),
@@ -4340,7 +4340,7 @@ def testTransposeNCDHW(self, use_bias):
# transpose convolution!)
err = "Variables in conv3_transpose_transpose not instantiated yet"
with self.assertRaisesRegexp(snt.NotConnectedError, err):
- self.assertEqual(conv3_transpose.output_shape, conv3.input_shape)
+ _ = conv3.input_shape

if __name__ == "__main__":
tf.test.main()
