Skip to content

Commit

Permalink
fixes for python3 compatibility
Browse files Browse the repository at this point in the history
GitOrigin-RevId=083515c1b58437e98c4ebd5935bd791d31a3a007
PiperOrigin-RevId: 154559781
  • Loading branch information
bfredl authored and adria-p committed Apr 28, 2017
1 parent 697c20f commit 578e336
Show file tree
Hide file tree
Showing 21 changed files with 72 additions and 48 deletions.
2 changes: 1 addition & 1 deletion sonnet/examples/module_with_build_args.py
Expand Up @@ -72,7 +72,7 @@ def main(unused_argv):

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in xrange(100):
for _ in range(100):
sess.run(train_step)
# Check that evaluating train_model_outputs twice returns the same value.
train_outputs, train_outputs_2 = sess.run([train_model_outputs,
Expand Down
4 changes: 2 additions & 2 deletions sonnet/examples/rnn_shakespeare.py
Expand Up @@ -151,7 +151,7 @@ def generate_string(self, initial_logits, initial_state, sequence_length):
current_state = initial_state

generated_letters = []
for _ in xrange(sequence_length):
for _ in range(sequence_length):
# Sample a character index from distribution.
char_index = tf.squeeze(tf.multinomial(current_logits, 1))
char_one_hot = tf.one_hot(char_index, self._output_size, 1.0, 0.0)
Expand Down Expand Up @@ -270,7 +270,7 @@ def train(num_training_iterations, report_interval,

start_iteration = sess.run(global_step)

for train_iteration in xrange(start_iteration, num_training_iterations):
for train_iteration in range(start_iteration, num_training_iterations):
if (train_iteration + 1) % report_interval == 0:
train_loss_v, valid_loss_v, _ = sess.run(
(train_loss, valid_loss, train_step))
Expand Down
3 changes: 1 addition & 2 deletions sonnet/python/modules/base.py
Expand Up @@ -26,7 +26,6 @@

import abc
import collections
import types
# Dependency imports
import six
from sonnet.python.modules import util
Expand Down Expand Up @@ -130,7 +129,7 @@ def __init__(self, name=None):
ValueError: If name is not specified.
"""

if name is None or not isinstance(name, types.StringTypes):
if name is None or not isinstance(name, six.string_types):
raise ValueError("Name must be a string.")

self._connected_subgraphs = []
Expand Down
7 changes: 5 additions & 2 deletions sonnet/python/modules/base_test.py
Expand Up @@ -22,6 +22,7 @@
from functools import partial

import numpy as np
import six
from sonnet.python.modules import base
import tensorflow as tf

Expand Down Expand Up @@ -62,8 +63,10 @@ def testInitializerKeys(self):
self.assertEqual(keys, {"foo", "bar"})
keys = ModuleWithNoInitializerKeys.get_possible_initializer_keys()
self.assertEqual(keys, set())
msg = ("missing 1 required positional argument" if six.PY3
else "takes exactly 2 arguments")
self.assertRaisesRegexp(
TypeError, "takes exactly 2 arguments",
TypeError, msg,
ModuleWithCustomInitializerKeys.get_possible_initializer_keys)
keys = ModuleWithCustomInitializerKeys.get_possible_initializer_keys(True)
self.assertEqual(keys, {"foo"})
Expand Down Expand Up @@ -146,7 +149,7 @@ def testFunctionType(self):
with self.assertRaises(TypeError) as cm:
base.Module(build="not_a_function")

self.assertEqual(cm.exception.message, "Input 'build' must be callable.")
self.assertEqual(str(cm.exception), "Input 'build' must be callable.")

def testSharing(self):
batch_size = 3
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/modules/basic.py
Expand Up @@ -27,6 +27,7 @@
# Dependency imports

import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from sonnet.python.modules import base
from sonnet.python.modules import util
from sonnet.python.ops import nest
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/modules/basic_rnn_test.py
Expand Up @@ -22,6 +22,7 @@

# Dependency imports
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import sonnet as snt
from sonnet.testing import parameterized
import tensorflow as tf
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/modules/basic_test.py
Expand Up @@ -23,6 +23,7 @@
# Dependency imports

import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import sonnet as snt
from sonnet.testing import parameterized
import tensorflow as tf
Expand Down
10 changes: 5 additions & 5 deletions sonnet/python/modules/batch_norm.py
Expand Up @@ -306,7 +306,7 @@ def _infer_fused_data_format(self, input_batch):
# Reduce over the second dimension.
return "NCHW"
else:
raise ValueError("Invalid axis option {:s}. This does not correspond to"
raise ValueError("Invalid axis option {}. This does not correspond to"
" either the NHWC format (0, 1, 2) or the NCHW "
"(0, 2, 3).".format(axis))

Expand Down Expand Up @@ -439,23 +439,23 @@ def _build(self, input_batch, is_training=True, test_local_stats=True):
if self._axis is not None:
if len(self._axis) > len(input_shape):
raise base.IncompatibleShapeError(
"Too many indices specified in axis: len({:s}) > len({:s}).".format(
"Too many indices specified in axis: len({}) > len({}).".format(
self._axis, input_shape))

if max(self._axis) >= len(input_shape):
raise base.IncompatibleShapeError(
"One or more index in axis is too large for "
"input shape: {:s} >= {:d}.".format(self._axis, len(input_shape)))
"input shape: {} >= {:d}.".format(self._axis, len(input_shape)))

if min(self._axis) < 0:
raise base.IncompatibleShapeError(
"Indices in axis must be non-negative: {:s} < 0.".format(
"Indices in axis must be non-negative: {} < 0.".format(
self._axis))

axis = self._axis
else:
# Reduce over all dimensions except the last.
axis = range(len(input_shape))[:-1]
axis = tuple(range(len(input_shape))[:-1])

# See following for important note on accuracy for dtype=tf.float16
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L63
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/modules/block_matrix.py
Expand Up @@ -19,6 +19,7 @@
from __future__ import print_function

# Dependency imports
from six.moves import xrange # pylint: disable=redefined-builtin
from sonnet.python.modules import base
import tensorflow as tf

Expand Down
12 changes: 6 additions & 6 deletions sonnet/python/modules/conv_test.py
Expand Up @@ -753,7 +753,7 @@ def testMaskErrorInvalidRank(self):
with self.assertRaises(snt.Error) as cm:
snt.Conv2D(output_channels=4, kernel_shape=3, mask=mask)
self.assertEqual(
cm.exception.message,
str(cm.exception),
"Invalid mask rank: {}".format(mask.ndim))

def testMaskErrorInvalidType(self):
Expand All @@ -763,7 +763,7 @@ def testMaskErrorInvalidType(self):
with self.assertRaises(TypeError) as cm:
snt.Conv2D(output_channels=4, kernel_shape=3, mask=mask)
self.assertEqual(
cm.exception.message, "Invalid type for mask: {}".format(type(mask)))
str(cm.exception), "Invalid type for mask: {}".format(type(mask)))

def testMaskErrorIncompatibleRank2(self):
"""Errors are thrown for incompatible rank 2 mask."""
Expand All @@ -772,8 +772,8 @@ def testMaskErrorIncompatibleRank2(self):
x = tf.constant(0.0, shape=(2, 8, 8, 6))
with self.assertRaises(snt.Error) as cm:
snt.Conv2D(output_channels=4, kernel_shape=5, mask=mask)(x)
self.assertEqual(
cm.exception.message, "Invalid mask shape: {}".format(mask.shape))
self.assertTrue(str(cm.exception).startswith(
"Invalid mask shape: {}".format(mask.shape)))

def testMaskErrorIncompatibleRank4(self):
"""Errors are thrown for incompatible rank 4 mask."""
Expand All @@ -782,8 +782,8 @@ def testMaskErrorIncompatibleRank4(self):
x = tf.constant(0.0, shape=(2, 8, 8, 6))
with self.assertRaises(snt.Error) as cm:
snt.Conv2D(output_channels=4, kernel_shape=5, mask=mask)(x)
self.assertEqual(
cm.exception.message, "Invalid mask shape: {}".format(mask.shape))
self.assertTrue(str(cm.exception).startswith(
"Invalid mask shape: {}".format(mask.shape)))


class Conv2DTransposeTest(parameterized.ParameterizedTestCase,
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/modules/nets/convnet.py
Expand Up @@ -20,6 +20,7 @@

import collections

from six.moves import xrange # pylint: disable=redefined-builtin
from sonnet.python.modules import base
from sonnet.python.modules import batch_norm
from sonnet.python.modules import conv
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/modules/nets/mlp.py
Expand Up @@ -20,6 +20,7 @@

import collections

from six.moves import xrange # pylint: disable=redefined-builtin
from sonnet.python.modules import base
from sonnet.python.modules import basic
from sonnet.python.modules import util
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/modules/rnn_core.py
Expand Up @@ -29,6 +29,7 @@
# Dependency imports

import six
from six.moves import xrange # pylint: disable=redefined-builtin
from sonnet.python.modules import base
from sonnet.python.modules import basic
import tensorflow as tf
Expand Down
6 changes: 5 additions & 1 deletion sonnet/python/modules/sequential_test.py
Expand Up @@ -19,6 +19,7 @@
from __future__ import print_function

# Dependency imports
import six
import sonnet as snt
import tensorflow as tf

Expand Down Expand Up @@ -60,7 +61,10 @@ def module1(a, b):
def module2(a, b, c):
return a, b, c

err_str = r"module2\(\) takes exactly 3 arguments \(2 given\)"
if six.PY3:
err_str = r"module2\(\) missing 1 required positional argument: 'c'"
else:
err_str = r"module2\(\) takes exactly 3 arguments \(2 given\)"
with self.assertRaisesRegexp(TypeError, err_str):
_, _ = snt.Sequential([module1, module2], name="seq2")(1, 2)

Expand Down
3 changes: 2 additions & 1 deletion sonnet/python/modules/spatial_transformer.py
Expand Up @@ -23,6 +23,7 @@

# Dependency imports
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from sonnet.python.modules import base
from sonnet.python.modules import basic
import tensorflow as tf
Expand Down Expand Up @@ -419,7 +420,7 @@ def _affine_grid_warper_inverse(inputs):
index = iter(range(6))
def get_variable(constraint):
if constraint is None:
i = index.next()
i = next(index)
return inputs[:, i:i+1]
else:
return tf.fill(constant_shape, tf.constant(constraint,
Expand Down
5 changes: 3 additions & 2 deletions sonnet/python/modules/util.py
Expand Up @@ -21,6 +21,7 @@
import re

# Dependency imports
import six
import tensorflow as tf


Expand Down Expand Up @@ -82,7 +83,7 @@ def _check_nested_callables(dictionary, object_name):
TypeError: If the dictionary contains something that is not either a
dictionary or a callable.
"""
for key, entry in dictionary.iteritems():
for key, entry in six.iteritems(dictionary):
if isinstance(entry, dict):
_check_nested_callables(entry, object_name)
elif not callable(entry):
Expand Down Expand Up @@ -311,7 +312,7 @@ def get_saver(scope, collections=(tf.GraphKeys.GLOBAL_VARIABLES,),

def has_variable_scope(obj):
"""Determines whether the given object has a variable scope."""
return hasattr(obj, "variable_scope") or "variable_scope" in dir(obj)
return "variable_scope" in dir(obj)


def _format_table(rows):
Expand Down
4 changes: 2 additions & 2 deletions sonnet/python/ops/nest.py
Expand Up @@ -290,7 +290,7 @@ def map(fn_or_op, *inputs): # pylint: disable=redefined-builtin
def _sorted(dict_):
"""Returns a sorted list from the dict, with error if keys not sortable."""
try:
return sorted(dict_.iterkeys())
return sorted(six.iterkeys(dict_))
except TypeError:
raise TypeError("nest only supports dicts with sortable keys.")

Expand All @@ -307,7 +307,7 @@ def _iterable_like(instance, args):
`args` with the type of `instance`.
"""
if isinstance(instance, collections.OrderedDict):
return collections.OrderedDict(zip(instance.iterkeys(), args))
return collections.OrderedDict(zip(six.iterkeys(instance), args))
elif isinstance(instance, dict):
return dict(zip(_sorted(instance), args))
elif (isinstance(instance, tuple) and
Expand Down
35 changes: 19 additions & 16 deletions sonnet/python/ops/nest_test.py
Expand Up @@ -24,9 +24,12 @@
# Dependency imports

import numpy as np
import six
from sonnet.python.ops import nest
import tensorflow as tf

typekw = "class" if six.PY3 else "type"


class NestTest(tf.test.TestCase):

Expand All @@ -35,7 +38,7 @@ def testAssertShallowStructure(self):
inp_abc = ["a", "b", "c"]
with self.assertRaises(ValueError) as cm:
nest.assert_shallow_structure(inp_abc, inp_ab)
self.assertEqual(cm.exception.message,
self.assertEqual(str(cm.exception),
"The two structures don't have the same sequence length. "
"Input structure has length 2, while shallow structure "
"has length 3.")
Expand All @@ -44,10 +47,10 @@ def testAssertShallowStructure(self):
inp_ab2 = [[1, 1], [2, 2]]
with self.assertRaises(TypeError) as cm:
nest.assert_shallow_structure(inp_ab2, inp_ab1)
self.assertEqual(cm.exception.message,
self.assertEqual(str(cm.exception),
"The two structures don't have the same sequence type. "
"Input structure has type <type 'tuple'>, while shallow "
"structure has type <type 'list'>.")
"Input structure has type <{0} 'tuple'>, while shallow "
"structure has type <{0} 'list'>.".format(typekw))

def testFlattenUpTo(self):
# Normal application (Example 1).
Expand Down Expand Up @@ -123,19 +126,19 @@ def testFlattenUpTo(self):
with self.assertRaises(TypeError) as cm:
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(cm.exception.message,
"If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <type 'str'>.")
self.assertEqual(str(cm.exception),
"If shallow structure is a sequence, input must also be "
"a sequence. Input has type: <{} 'str'>.".format(typekw))
self.assertEqual(flattened_shallow_tree, shallow_tree)

input_tree = "input_tree"
shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
with self.assertRaises(TypeError) as cm:
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(cm.exception.message,
"If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <type 'str'>.")
self.assertEqual(str(cm.exception),
"If shallow structure is a sequence, input must also be "
"a sequence. Input has type: <{} 'str'>.".format(typekw))
self.assertEqual(flattened_shallow_tree, shallow_tree)

# Using non-iterable elements.
Expand All @@ -144,19 +147,19 @@ def testFlattenUpTo(self):
with self.assertRaises(TypeError) as cm:
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(cm.exception.message,
"If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <type 'int'>.")
self.assertEqual(str(cm.exception),
"If shallow structure is a sequence, input must also be "
"a sequence. Input has type: <{} 'int'>.".format(typekw))
self.assertEqual(flattened_shallow_tree, shallow_tree)

input_tree = 0
shallow_tree = [9, 8]
with self.assertRaises(TypeError) as cm:
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(cm.exception.message,
"If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <type 'int'>.")
self.assertEqual(str(cm.exception),
"If shallow structure is a sequence, input must also be "
"a sequence. Input has type: <{} 'int'>.".format(typekw))
self.assertEqual(flattened_shallow_tree, shallow_tree)

def testMapUpTo(self):
Expand Down
1 change: 1 addition & 0 deletions sonnet/python/ops/resampler_test.py
Expand Up @@ -22,6 +22,7 @@
# Dependency imports

import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import sonnet as snt
from sonnet.testing import parameterized

Expand Down

0 comments on commit 578e336

Please sign in to comment.