diff --git a/sonnet/examples/module_with_build_args.py b/sonnet/examples/module_with_build_args.py index e647cf2e..dfe3d73b 100644 --- a/sonnet/examples/module_with_build_args.py +++ b/sonnet/examples/module_with_build_args.py @@ -72,7 +72,7 @@ def main(unused_argv): with tf.Session() as sess: sess.run(tf.global_variables_initializer()) - for _ in xrange(100): + for _ in range(100): sess.run(train_step) # Check that evaluating train_model_outputs twice returns the same value. train_outputs, train_outputs_2 = sess.run([train_model_outputs, diff --git a/sonnet/examples/rnn_shakespeare.py b/sonnet/examples/rnn_shakespeare.py index b4a98769..2fb9a19a 100644 --- a/sonnet/examples/rnn_shakespeare.py +++ b/sonnet/examples/rnn_shakespeare.py @@ -151,7 +151,7 @@ def generate_string(self, initial_logits, initial_state, sequence_length): current_state = initial_state generated_letters = [] - for _ in xrange(sequence_length): + for _ in range(sequence_length): # Sample a character index from distribution. char_index = tf.squeeze(tf.multinomial(current_logits, 1)) char_one_hot = tf.one_hot(char_index, self._output_size, 1.0, 0.0) @@ -270,7 +270,7 @@ def train(num_training_iterations, report_interval, start_iteration = sess.run(global_step) - for train_iteration in xrange(start_iteration, num_training_iterations): + for train_iteration in range(start_iteration, num_training_iterations): if (train_iteration + 1) % report_interval == 0: train_loss_v, valid_loss_v, _ = sess.run( (train_loss, valid_loss, train_step)) diff --git a/sonnet/python/modules/base.py b/sonnet/python/modules/base.py index 84084a31..c18dd061 100644 --- a/sonnet/python/modules/base.py +++ b/sonnet/python/modules/base.py @@ -26,7 +26,6 @@ import abc import collections -import types # Dependency imports import six from sonnet.python.modules import util @@ -130,7 +129,7 @@ def __init__(self, name=None): ValueError: If name is not specified. """ - if name is None or not isinstance(name, types.StringTypes): + if name is None or not isinstance(name, six.string_types): raise ValueError("Name must be a string.") self._connected_subgraphs = [] diff --git a/sonnet/python/modules/base_test.py b/sonnet/python/modules/base_test.py index 0242ae79..83358e19 100644 --- a/sonnet/python/modules/base_test.py +++ b/sonnet/python/modules/base_test.py @@ -22,6 +22,7 @@ from functools import partial import numpy as np +import six from sonnet.python.modules import base import tensorflow as tf @@ -62,8 +63,10 @@ def testInitializerKeys(self): self.assertEqual(keys, {"foo", "bar"}) keys = ModuleWithNoInitializerKeys.get_possible_initializer_keys() self.assertEqual(keys, set()) + msg = ("missing 1 required positional argument" if six.PY3 + else "takes exactly 2 arguments") self.assertRaisesRegexp( - TypeError, "takes exactly 2 arguments", + TypeError, msg, ModuleWithCustomInitializerKeys.get_possible_initializer_keys) keys = ModuleWithCustomInitializerKeys.get_possible_initializer_keys(True) self.assertEqual(keys, {"foo"}) @@ -146,7 +149,7 @@ def testFunctionType(self): with self.assertRaises(TypeError) as cm: base.Module(build="not_a_function") - self.assertEqual(cm.exception.message, "Input 'build' must be callable.") + self.assertEqual(str(cm.exception), "Input 'build' must be callable.") def testSharing(self): batch_size = 3 diff --git a/sonnet/python/modules/basic.py b/sonnet/python/modules/basic.py index 91ef8e4a..3ac2ec1e 100644 --- a/sonnet/python/modules/basic.py +++ b/sonnet/python/modules/basic.py @@ -27,6 +27,7 @@ # Dependency imports import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from sonnet.python.modules import base from sonnet.python.modules import util from sonnet.python.ops import nest diff --git a/sonnet/python/modules/basic_rnn_test.py b/sonnet/python/modules/basic_rnn_test.py index f076f1bf..14e97507 100644 --- a/sonnet/python/modules/basic_rnn_test.py +++ b/sonnet/python/modules/basic_rnn_test.py @@ -22,6 +22,7 @@ # Dependency imports import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin import sonnet as snt from sonnet.testing import parameterized import tensorflow as tf diff --git a/sonnet/python/modules/basic_test.py b/sonnet/python/modules/basic_test.py index 815a751b..b1c32bda 100644 --- a/sonnet/python/modules/basic_test.py +++ b/sonnet/python/modules/basic_test.py @@ -23,6 +23,7 @@ # Dependency imports import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin import sonnet as snt from sonnet.testing import parameterized import tensorflow as tf diff --git a/sonnet/python/modules/batch_norm.py b/sonnet/python/modules/batch_norm.py index eb0cc586..cf8723ad 100644 --- a/sonnet/python/modules/batch_norm.py +++ b/sonnet/python/modules/batch_norm.py @@ -306,7 +306,7 @@ def _infer_fused_data_format(self, input_batch): # Reduce over the second dimension. return "NCHW" else: - raise ValueError("Invalid axis option {:s}. This does not correspond to" + raise ValueError("Invalid axis option {}. This does not correspond to" " either the NHWC format (0, 1, 2) or the NCHW " "(0, 2, 3).".format(axis)) @@ -439,23 +439,23 @@ def _build(self, input_batch, is_training=True, test_local_stats=True): if self._axis is not None: if len(self._axis) > len(input_shape): raise base.IncompatibleShapeError( - "Too many indices specified in axis: len({:s}) > len({:s}).".format( + "Too many indices specified in axis: len({}) > len({}).".format( self._axis, input_shape)) if max(self._axis) >= len(input_shape): raise base.IncompatibleShapeError( "One or more index in axis is too large for " - "input shape: {:s} >= {:d}.".format(self._axis, len(input_shape))) + "input shape: {} >= {:d}.".format(self._axis, len(input_shape))) if min(self._axis) < 0: raise base.IncompatibleShapeError( - "Indices in axis must be non-negative: {:s} < 0.".format( + "Indices in axis must be non-negative: {} < 0.".format( self._axis)) axis = self._axis else: # Reduce over all dimensions except the last. - axis = range(len(input_shape))[:-1] + axis = tuple(range(len(input_shape))[:-1]) # See following for important note on accuracy for dtype=tf.float16 # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L63 diff --git a/sonnet/python/modules/block_matrix.py b/sonnet/python/modules/block_matrix.py index 1e459d9e..b4d5f90e 100644 --- a/sonnet/python/modules/block_matrix.py +++ b/sonnet/python/modules/block_matrix.py @@ -19,6 +19,7 @@ from __future__ import print_function # Dependency imports +from six.moves import xrange # pylint: disable=redefined-builtin from sonnet.python.modules import base import tensorflow as tf diff --git a/sonnet/python/modules/conv_test.py b/sonnet/python/modules/conv_test.py index f4d2d52a..56ff907e 100644 --- a/sonnet/python/modules/conv_test.py +++ b/sonnet/python/modules/conv_test.py @@ -753,7 +753,7 @@ def testMaskErrorInvalidRank(self): with self.assertRaises(snt.Error) as cm: snt.Conv2D(output_channels=4, kernel_shape=3, mask=mask) self.assertEqual( - cm.exception.message, + str(cm.exception), "Invalid mask rank: {}".format(mask.ndim)) def testMaskErrorInvalidType(self): @@ -763,7 +763,7 @@ def testMaskErrorInvalidType(self): with self.assertRaises(TypeError) as cm: snt.Conv2D(output_channels=4, kernel_shape=3, mask=mask) self.assertEqual( - cm.exception.message, "Invalid type for mask: {}".format(type(mask))) + str(cm.exception), "Invalid type for mask: {}".format(type(mask))) def testMaskErrorIncompatibleRank2(self): """Errors are thrown for incompatible rank 2 mask.""" @@ -772,8 +772,8 @@ def testMaskErrorIncompatibleRank2(self): x = tf.constant(0.0, shape=(2, 8, 8, 6)) with self.assertRaises(snt.Error) as cm: snt.Conv2D(output_channels=4, kernel_shape=5, mask=mask)(x) - self.assertEqual( - cm.exception.message, "Invalid mask shape: {}".format(mask.shape)) + self.assertTrue(str(cm.exception).startswith( + "Invalid mask shape: {}".format(mask.shape))) def testMaskErrorIncompatibleRank4(self): """Errors are thrown for incompatible rank 4 mask.""" @@ -782,8 +782,8 @@ def testMaskErrorIncompatibleRank4(self): x = tf.constant(0.0, shape=(2, 8, 8, 6)) with self.assertRaises(snt.Error) as cm: snt.Conv2D(output_channels=4, kernel_shape=5, mask=mask)(x) - self.assertEqual( - cm.exception.message, "Invalid mask shape: {}".format(mask.shape)) + self.assertTrue(str(cm.exception).startswith( + "Invalid mask shape: {}".format(mask.shape))) class Conv2DTransposeTest(parameterized.ParameterizedTestCase, diff --git a/sonnet/python/modules/nets/convnet.py b/sonnet/python/modules/nets/convnet.py index ec61769f..8393491e 100644 --- a/sonnet/python/modules/nets/convnet.py +++ b/sonnet/python/modules/nets/convnet.py @@ -20,6 +20,7 @@ import collections +from six.moves import xrange # pylint: disable=redefined-builtin from sonnet.python.modules import base from sonnet.python.modules import batch_norm from sonnet.python.modules import conv diff --git a/sonnet/python/modules/nets/mlp.py b/sonnet/python/modules/nets/mlp.py index c9febd29..3f2f4cab 100644 --- a/sonnet/python/modules/nets/mlp.py +++ b/sonnet/python/modules/nets/mlp.py @@ -20,6 +20,7 @@ import collections +from six.moves import xrange # pylint: disable=redefined-builtin from sonnet.python.modules import base from sonnet.python.modules import basic from sonnet.python.modules import util diff --git a/sonnet/python/modules/rnn_core.py b/sonnet/python/modules/rnn_core.py index 965a8a5f..fdf0d653 100644 --- a/sonnet/python/modules/rnn_core.py +++ b/sonnet/python/modules/rnn_core.py @@ -29,6 +29,7 @@ # Dependency imports import six +from six.moves import xrange # pylint: disable=redefined-builtin from sonnet.python.modules import base from sonnet.python.modules import basic import tensorflow as tf diff --git a/sonnet/python/modules/sequential_test.py b/sonnet/python/modules/sequential_test.py index d6256042..118a1967 100644 --- a/sonnet/python/modules/sequential_test.py +++ b/sonnet/python/modules/sequential_test.py @@ -19,6 +19,7 @@ from __future__ import print_function # Dependency imports +import six import sonnet as snt import tensorflow as tf @@ -60,7 +61,10 @@ def module1(a, b): def module2(a, b, c): return a, b, c - err_str = r"module2\(\) takes exactly 3 arguments \(2 given\)" + if six.PY3: + err_str = r"module2\(\) missing 1 required positional argument: 'c'" + else: + err_str = r"module2\(\) takes exactly 3 arguments \(2 given\)" with self.assertRaisesRegexp(TypeError, err_str): _, _ = snt.Sequential([module1, module2], name="seq2")(1, 2) diff --git a/sonnet/python/modules/spatial_transformer.py b/sonnet/python/modules/spatial_transformer.py index 74dabbaf..a92088b9 100644 --- a/sonnet/python/modules/spatial_transformer.py +++ b/sonnet/python/modules/spatial_transformer.py @@ -23,6 +23,7 @@ # Dependency imports import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from sonnet.python.modules import base from sonnet.python.modules import basic import tensorflow as tf @@ -419,7 +420,7 @@ def _affine_grid_warper_inverse(inputs): index = iter(range(6)) def get_variable(constraint): if constraint is None: - i = index.next() + i = next(index) return inputs[:, i:i+1] else: return tf.fill(constant_shape, tf.constant(constraint, diff --git a/sonnet/python/modules/util.py b/sonnet/python/modules/util.py index 20cc01bd..997b4e38 100644 --- a/sonnet/python/modules/util.py +++ b/sonnet/python/modules/util.py @@ -21,6 +21,7 @@ import re # Dependency imports +import six import tensorflow as tf @@ -82,7 +83,7 @@ def _check_nested_callables(dictionary, object_name): TypeError: If the dictionary contains something that is not either a dictionary or a callable. """ - for key, entry in dictionary.iteritems(): + for key, entry in six.iteritems(dictionary): if isinstance(entry, dict): _check_nested_callables(entry, object_name) elif not callable(entry): @@ -311,7 +312,7 @@ def get_saver(scope, collections=(tf.GraphKeys.GLOBAL_VARIABLES,), def has_variable_scope(obj): """Determines whether the given object has a variable scope.""" - return hasattr(obj, "variable_scope") or "variable_scope" in dir(obj) + return "variable_scope" in dir(obj) def _format_table(rows): diff --git a/sonnet/python/ops/nest.py b/sonnet/python/ops/nest.py index 4516895a..21afaa8c 100644 --- a/sonnet/python/ops/nest.py +++ b/sonnet/python/ops/nest.py @@ -290,7 +290,7 @@ def map(fn_or_op, *inputs): # pylint: disable=redefined-builtin def _sorted(dict_): """Returns a sorted list from the dict, with error if keys not sortable.""" try: - return sorted(dict_.iterkeys()) + return sorted(six.iterkeys(dict_)) except TypeError: raise TypeError("nest only supports dicts with sortable keys.") @@ -307,7 +307,7 @@ def _iterable_like(instance, args): `args` with the type of `instance`. """ if isinstance(instance, collections.OrderedDict): - return collections.OrderedDict(zip(instance.iterkeys(), args)) + return collections.OrderedDict(zip(six.iterkeys(instance), args)) elif isinstance(instance, dict): return dict(zip(_sorted(instance), args)) elif (isinstance(instance, tuple) and diff --git a/sonnet/python/ops/nest_test.py b/sonnet/python/ops/nest_test.py index 32319a88..546378ab 100644 --- a/sonnet/python/ops/nest_test.py +++ b/sonnet/python/ops/nest_test.py @@ -24,9 +24,12 @@ # Dependency imports import numpy as np +import six from sonnet.python.ops import nest import tensorflow as tf +typekw = "class" if six.PY3 else "type" + class NestTest(tf.test.TestCase): @@ -35,7 +38,7 @@ def testAssertShallowStructure(self): inp_abc = ["a", "b", "c"] with self.assertRaises(ValueError) as cm: nest.assert_shallow_structure(inp_abc, inp_ab) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "The two structures don't have the same sequence length. " "Input structure has length 2, while shallow structure " "has length 3.") @@ -44,10 +47,10 @@ def testAssertShallowStructure(self): inp_ab2 = [[1, 1], [2, 2]] with self.assertRaises(TypeError) as cm: nest.assert_shallow_structure(inp_ab2, inp_ab1) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "The two structures don't have the same sequence type. " - "Input structure has type , while shallow " - "structure has type .") + "Input structure has type <{0} 'tuple'>, while shallow " + "structure has type <{0} 'list'>.".format(typekw)) def testFlattenUpTo(self): # Normal application (Example 1). @@ -123,9 +126,9 @@ def testFlattenUpTo(self): with self.assertRaises(TypeError) as cm: flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(cm.exception.message, - "If shallow structure is a sequence, input must also " - "be a sequence. Input has type: .") + self.assertEqual(str(cm.exception), + "If shallow structure is a sequence, input must also be " + "a sequence. Input has type: <{} 'str'>.".format(typekw)) self.assertEqual(flattened_shallow_tree, shallow_tree) input_tree = "input_tree" @@ -133,9 +136,9 @@ def testFlattenUpTo(self): with self.assertRaises(TypeError) as cm: flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(cm.exception.message, - "If shallow structure is a sequence, input must also " - "be a sequence. Input has type: .") + self.assertEqual(str(cm.exception), + "If shallow structure is a sequence, input must also be " + "a sequence. Input has type: <{} 'str'>.".format(typekw)) self.assertEqual(flattened_shallow_tree, shallow_tree) # Using non-iterable elements. @@ -144,9 +147,9 @@ def testFlattenUpTo(self): with self.assertRaises(TypeError) as cm: flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(cm.exception.message, - "If shallow structure is a sequence, input must also " - "be a sequence. Input has type: .") + self.assertEqual(str(cm.exception), + "If shallow structure is a sequence, input must also be " + "a sequence. Input has type: <{} 'int'>.".format(typekw)) self.assertEqual(flattened_shallow_tree, shallow_tree) input_tree = 0 @@ -154,9 +157,9 @@ def testFlattenUpTo(self): with self.assertRaises(TypeError) as cm: flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(cm.exception.message, - "If shallow structure is a sequence, input must also " - "be a sequence. Input has type: .") + self.assertEqual(str(cm.exception), + "If shallow structure is a sequence, input must also be " + "a sequence. Input has type: <{} 'int'>.".format(typekw)) self.assertEqual(flattened_shallow_tree, shallow_tree) def testMapUpTo(self): diff --git a/sonnet/python/ops/resampler_test.py b/sonnet/python/ops/resampler_test.py index 5a298a6c..186f0160 100644 --- a/sonnet/python/ops/resampler_test.py +++ b/sonnet/python/ops/resampler_test.py @@ -22,6 +22,7 @@ # Dependency imports import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin import sonnet as snt from sonnet.testing import parameterized diff --git a/sonnet/testing/parameterized/parameterized.py b/sonnet/testing/parameterized/parameterized.py index d2e432e7..037b8877 100644 --- a/sonnet/testing/parameterized/parameterized.py +++ b/sonnet/testing/parameterized/parameterized.py @@ -147,6 +147,7 @@ def testSumIsZero(self, arg): import unittest import uuid +import six from tensorflow.python.platform import googletest ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>') @@ -167,13 +168,13 @@ def _StrClass(cls): def _NonStringIterable(obj): return (isinstance(obj, collections.Iterable) and not - isinstance(obj, basestring)) + isinstance(obj, six.string_types)) def _FormatParameterList(testcase_params): if isinstance(testcase_params, collections.Mapping): return ', '.join('%s=%s' % (argname, _CleanRepr(value)) - for argname, value in testcase_params.iteritems()) + for argname, value in six.iteritems(testcase_params)) elif _NonStringIterable(testcase_params): return ', '.join(map(_CleanRepr, testcase_params)) else: @@ -265,7 +266,7 @@ def _ModifyClass(class_object, testcases, naming_type): 'Cannot add parameters to %s,' ' which already has parameterized methods.' % (class_object,)) class_object._id_suffix = id_suffix = {} - for name, obj in class_object.__dict__.items(): + for name, obj in list(six.iteritems(class_object.__dict__)): if (name.startswith(unittest.TestLoader.testMethodPrefix) and isinstance(obj, types.FunctionType)): delattr(class_object, name) @@ -273,7 +274,7 @@ def _ModifyClass(class_object, testcases, naming_type): _UpdateClassDictForParamTestCase( methods, id_suffix, name, _ParameterizedTestIter(obj, testcases, naming_type)) - for name, meth in methods.iteritems(): + for name, meth in six.iteritems(methods): setattr(class_object, name, meth) @@ -353,7 +354,7 @@ class TestGeneratorMetaclass(type): def __new__(mcs, class_name, bases, dct): dct['_id_suffix'] = id_suffix = {} - for name, obj in dct.items(): + for name, obj in list(six.iteritems(dct)): if (name.startswith(unittest.TestLoader.testMethodPrefix) and _NonStringIterable(obj)): iterator = iter(obj) @@ -385,9 +386,9 @@ def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator): id_suffix[new_name] = getattr(func, '__x_extra_id__', '') -class ParameterizedTestCase(googletest.TestCase): +class ParameterizedTestCase( + six.with_metaclass(TestGeneratorMetaclass, googletest.TestCase)): """Base class for test cases using the Parameters decorator.""" - __metaclass__ = TestGeneratorMetaclass def _OriginalName(self): return self._testMethodName.split(_SEPARATOR)[0] diff --git a/sonnet/testing/parameterized/parameterized_test.py b/sonnet/testing/parameterized/parameterized_test.py index a35af9ad..98977480 100644 --- a/sonnet/testing/parameterized/parameterized_test.py +++ b/sonnet/testing/parameterized/parameterized_test.py @@ -19,6 +19,8 @@ import unittest # Dependency imports +import six +from six.moves import xrange # pylint: disable=redefined-builtin from sonnet.testing import parameterized from tensorflow.python.platform import googletest @@ -355,7 +357,9 @@ def testSomething(unused_self, unused_obj): # pylint: disable=invalid-name expected_testcases = [1, 2, 3, 4, 5, 6] self.assertTrue(hasattr(testSomething, 'testcases')) - self.assertItemsEqual(expected_testcases, testSomething.testcases) + assert_items_equal = (self.assertCountEqual if six.PY3 + else self.assertItemsEqual) + assert_items_equal(expected_testcases, testSomething.testcases) def testChainedDecorator(self): ts = unittest.makeSuite(self.ChainedTests)