Skip to content

Commit

Permalink
Accept string values as variable scope in snt.get_variables_in_scope …
Browse files Browse the repository at this point in the history
…and snt. get_normalized_variable_map.

PiperOrigin-RevId: 161183031
  • Loading branch information
Deepmind authored and diegolascasas committed Jul 10, 2017
1 parent 5e58225 commit 013b000
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
52 changes: 36 additions & 16 deletions sonnet/python/modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@
from tensorflow.python.util import deprecation


def get_variable_scope_name(value):
"""Returns the name of the variable scope indicated by the given value.
Args:
value: String, variable scope, or object with `variable_scope` attribute
(e.g., Sonnet module).
Returns:
The name (a string) of the corresponding variable scope.
Raises:
ValueError: If `value` does not identify a variable scope.
"""
# If the object has a "variable_scope" property, use it.
value = getattr(value, "variable_scope", value)
if isinstance(value, tf.VariableScope):
return value.name
elif isinstance(value, six.string_types):
return value
else:
raise ValueError("Not a variable scope: {}".format(value))


def get_variables_in_scope(scope, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
"""Returns a tuple `tf.Variable`s in a scope for a given collection.
Expand All @@ -44,12 +67,11 @@ def get_variables_in_scope(scope, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
Returns:
A tuple of `tf.Variable` objects.
"""
if isinstance(scope, tf.VariableScope):
scope = scope.name
scope_name = get_variable_scope_name(scope)

# Escape the name in case it contains any "." characters. Add a closing slash
# so we will not search any scopes that have this scope name as a prefix.
scope_name = re.escape(scope) + "/"
scope_name = re.escape(scope_name) + "/"

return tuple(tf.get_collection(collection, scope_name))

Expand Down Expand Up @@ -239,6 +261,9 @@ def check_regularizers(regularizers, keys):
def _is_scope_prefix(scope_name, prefix_name):
"""Checks that `prefix_name` is a proper scope prefix of `scope_name`."""

if not prefix_name:
return True

if not scope_name.endswith("/"):
scope_name += "/"

Expand Down Expand Up @@ -368,24 +393,19 @@ def get_normalized_variable_map(scope_or_module,
Raises:
ValueError: If `context` is given but is not a proper prefix of `scope`.
"""
scope = getattr(scope_or_module, "variable_scope", scope_or_module)
scope_name = get_variable_scope_name(scope_or_module)

if context is None:
context = scope
context_scope = getattr(context, "variable_scope", context)
context = scope_or_module

scope_name = scope.name
prefix = context_scope.name
if prefix:
if not _is_scope_prefix(scope_name, prefix):
raise ValueError("Scope '{}' is not prefixed by '{}'.".format(
scope_name, prefix))
prefix = get_variable_scope_name(context)
prefix_length = len(prefix) + 1 if prefix else 0

prefix_length = len(prefix) + 1
else:
prefix_length = 0
if not _is_scope_prefix(scope_name, prefix):
raise ValueError("Scope '{}' is not prefixed by '{}'.".format(
scope_name, prefix))

variables = get_variables_in_scope(scope, collection)
variables = get_variables_in_scope(scope_name, collection)

if not group_sliced_variables:
single_vars = variables
Expand Down
14 changes: 12 additions & 2 deletions sonnet/python/modules/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def testScopeQuery(self):
self.assertEqual(set(snt.get_variables_in_scope(s2.name)), {v2, v3})

def testIsScopePrefix(self):
self.assertTrue(util._is_scope_prefix("a/b/c", ""))
self.assertTrue(util._is_scope_prefix("a/b/c", "a/b/c"))
self.assertTrue(util._is_scope_prefix("a/b/c", "a/b"))
self.assertTrue(util._is_scope_prefix("a/b/c", "a"))
Expand Down Expand Up @@ -107,6 +108,10 @@ def testGetNormalizedVariableMapScopeContext(self):
variable_map = snt.get_normalized_variable_map(s2, context=s3)

variable_map = snt.get_normalized_variable_map(s2, context=s1)
self.assertEqual(snt.get_normalized_variable_map(s2.name, context=s1),
variable_map)
self.assertEqual(snt.get_normalized_variable_map(s2.name, context=s1.name),
variable_map)

self.assertEqual(len(variable_map), 2)
self.assertIn("prefix2/a:0", variable_map)
Expand All @@ -116,7 +121,12 @@ def testGetNormalizedVariableMapScopeContext(self):

with tf.variable_scope("") as s4:
self.assertEqual(s4.name, "")
variable_map = snt.get_normalized_variable_map(s2, context=s4)

variable_map = snt.get_normalized_variable_map(s2, context=s4)
self.assertEqual(snt.get_normalized_variable_map(s2.name, context=s4),
variable_map)
self.assertEqual(snt.get_normalized_variable_map(s2.name, context=s4.name),
variable_map)

self.assertEqual(len(variable_map), 2)
self.assertIn("prefix1/prefix2/a:0", variable_map)
Expand All @@ -137,7 +147,7 @@ def testGetNormalizedVariableMapModule(self):
self.assertIs(variable_map["w:0"], conv.w)
self.assertIs(variable_map["b:0"], conv.b)

def testGetNormalizedVariableMapWithPartionedVariable(self):
def testGetNormalizedVariableMapWithPartitionedVariable(self):
hidden = tf.ones(shape=(1, 16, 16, 3))
partitioner = tf.variable_axis_size_partitioner(4)
conv = snt.Conv2D(output_channels=3,
Expand Down

0 comments on commit 013b000

Please sign in to comment.