Implementation of get_all_variables() for sonnet modules.
We introduce a module call stack, which tracks the order in which modules are called. When a module enters __call__ (or _enter_variable_scope) it adds itself to the top of the stack. Variables created inside of the custom_getter are added to a collection specific to the module on the top of the stack. Before exiting __call__ (or _enter_variable_scope) the module moves all variables added to this graph collection into `_all_variables`, removes itself from the top of the stack, and adds all of the variables from `self._all_variables` to the collection for the module that is now at the top of the module stack.

PiperOrigin-RevId: 185664981
fastturtle authored and diegolascasas committed Feb 21, 2018
1 parent 06b8e4f commit 62f4399
Showing 5 changed files with 388 additions and 37 deletions.
32 changes: 19 additions & 13 deletions docs/README.md
@@ -757,19 +757,25 @@ without regard for what library was used to put that graph together.

### Q: How do I list all the variables which are used _in any way_ in a Module?

A: Currently, not easily possible. Although there is a `get_variables()` method,
it only searches the `VariableScope` defined inside a module, which
will contain any internally constructed variables or modules. However, the
actual _computation_ done by a module could use other modules - for
example, the `snt.Sequential` module in the example section above. The modules
passed into the constructor have by definition been constructed before the
`Sequential`, and so they have different variable scopes. Currently, once
the Sequential is connected into the graph, querying it with
`get_variables()` will return an empty tuple.

The DeepMind Research Engineering team is considering future additions to the
`Module` API which remedy this, without requiring extra effort from module
implementors.
A: You can use `get_all_variables()` to find all the variables that a module or
any of its submodules have created with `tf.get_variable()`.

Like `get_variables()`, this returns all variables that are inside of the module's
(variable) scope. However, `get_all_variables()` also returns all of the
variables from any submodules with disjoint (variable) scopes. These submodules
have either been passed into the module's constructor, or have been constructed
by the module but outside of `_build()` or `_enter_variable_scope()`.

Note that by definition this will not return variables that have not been
created by `tf.get_variable()`. This is relevant for modules that use
`@snt.reuse_variables`. If a method decorated with `@snt.reuse_variables` is
not called, then `get_all_variables()` will not return any variables used inside
of it.

Note that by definition this returns _all_ of a module's variables. This means
that a module will return _all_ of a submodule's variables, even if it only uses
a subset of them (i.e. it does not call a method decorated by
`@snt.reuse_variables` on the submodule).
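
The difference between the two lookups can be illustrated with a toy model in plain Python. This is only a sketch of the idea and does not use Sonnet or TensorFlow: `ToyLinear`, `ToySequential`, and `GRAPH_VARIABLES` are hypothetical stand-ins, and variables are represented by their names.

```python
# Names of every variable in a pretend graph, in creation order.
GRAPH_VARIABLES = []


class ToyLinear(object):
  """Creates two named variables in its own scope when called."""

  def __init__(self, scope):
    self.scope = scope

  def __call__(self):
    names = [self.scope + "/w", self.scope + "/b"]
    for name in names:
      if name not in GRAPH_VARIABLES:
        GRAPH_VARIABLES.append(name)
    return names


class ToySequential(object):
  """Calls layers that were constructed elsewhere, so their scopes differ."""

  def __init__(self, scope, layers):
    self.scope = scope
    self.layers = layers
    self._tracked = set()

  def __call__(self):
    for layer in self.layers:
      self._tracked.update(layer())

  def get_variables(self):
    # Scope-based lookup: only names under this module's own scope.
    prefix = self.scope + "/"
    return tuple(v for v in GRAPH_VARIABLES if v.startswith(prefix))

  def get_all_variables(self):
    # Tracking-based lookup: everything created while this module ran.
    return tuple(sorted(self._tracked))


seq = ToySequential("sequential", [ToyLinear("linear_1"), ToyLinear("linear_2")])
seq()
```

Here `seq.get_variables()` is an empty tuple because nothing lives under the `sequential/` scope, whereas `seq.get_all_variables()` lists the four variables created by the two layers.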

### Q: How do I serialize Sonnet module instances?

161 changes: 155 additions & 6 deletions sonnet/python/modules/base.py
@@ -26,15 +26,18 @@

import abc
import collections
import contextlib
import functools
import inspect
import weakref

# Dependency imports
import six
from sonnet.python.modules import base_info
from sonnet.python.modules import util
import tensorflow as tf

# Import error class from base_errors for backward compability.
# Import error class from base_errors for backward compatibility.

from sonnet.python.modules.base_errors import Error
from sonnet.python.modules.base_errors import NotConnectedError
@@ -49,6 +52,89 @@
# pylint: enable=unused-import


# Maps `tf.Graph` objects to a module call stack.
_MODULE_STACKS = weakref.WeakKeyDictionary()
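
A short sketch of why a `weakref.WeakKeyDictionary` suits a per-graph registry: an entry disappears once its key is no longer referenced elsewhere, so the registry does not keep graphs alive. `ToyGraph` is a hypothetical stand-in for `tf.Graph`, and the immediate cleanup shown here relies on CPython's reference counting (the explicit `gc.collect()` is a precaution).

```python
import gc
import weakref


class ToyGraph(object):
  """Hypothetical stand-in for tf.Graph; instances can be weakly referenced."""


stacks = weakref.WeakKeyDictionary()

graph = ToyGraph()
# Create the stack for this graph on first use, as `setdefault` does.
stacks.setdefault(graph, []).append("module_a")
size_while_alive = len(stacks)

# Dropping the last strong reference to the key removes its entry.
del graph
gc.collect()
size_after_delete = len(stacks)
```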



def _maybe_wrap_custom_getter(custom_getter, old_getter):
  """Wrap a call to a custom_getter to use the old_getter internally.

  Copied from [variable_scope._maybe_wrap_custom_getter](
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/
  ops/variable_scope.py#L1565)

  Args:
    custom_getter: The wrapping custom getter.
    old_getter: The wrapped custom getter.

  Returns:
    A new custom getter that calls `old_getter` and then `custom_getter`.
  """
  if old_getter is None:
    return custom_getter

  # The new custom_getter should call the old one
  def wrapped_custom_getter(getter, *args, **kwargs):
    # Call:
    #  custom_getter(
    #    lambda: old_getter(true_getter, ...), *args, **kwargs)
    # which means custom_getter will call old_getter, which
    # will call the true_getter, perform any intermediate
    # processing, and return the results to the current
    # getter, which will also perform additional processing.
    return custom_getter(
        functools.partial(old_getter, getter), *args, **kwargs)

  return wrapped_custom_getter
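
The chaining performed by the wrapper can be traced with plain functions. This is a minimal sketch, assuming string-returning getters in place of real variable getters: `true_getter`, `old_getter`, `tracking_getter`, and `wrap` are hypothetical names, with `wrap` mirroring the structure of `_maybe_wrap_custom_getter`.

```python
import functools


def true_getter(name):
  # Stand-in for the getter that actually creates the variable.
  return "var:" + name


def old_getter(getter, name):
  # Inner custom getter: rewrites the name, then defers to `getter`.
  return getter("old/" + name)


def tracking_getter(getter, name):
  # Outer custom getter: runs the inner chain, then post-processes the result.
  result = getter(name)
  return result + " (tracked)"


def wrap(custom_getter, old):
  if old is None:
    return custom_getter

  def wrapped(getter, *args, **kwargs):
    # Bind `getter` into `old` so that `custom_getter` sees a single callable.
    return custom_getter(functools.partial(old, getter), *args, **kwargs)

  return wrapped


combined = wrap(tracking_getter, old_getter)
result = combined(true_getter, "w")
```

The outer getter is entered first, but it only finishes after the inner getter and the true getter have run, which is what lets a tracking getter observe everything created further down the chain.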


def _variable_tracking_custom_getter(getter, *args, **kwargs):
  """Custom getter that tracks variables created.

  This custom getter places any variables that `getter` creates into the
  `_all_variables` attribute of the `AbstractModule` that is on top of the
  module call stack. The module call stack is a graph-dependent stack that
  keeps track of the Sonnet module call order.

  Note that this assumes that variables are appended to `tf.Graph`
  collections. This is a safe assumption to make because
  `tf.add_to_collection()` appends objects to collections, and `tf.Variable`
  uses `tf.add_to_collections()` to add itself to `tf.Graph` collections.

  Note that this assumes that all variables are added to either the
  `tf.GraphKeys.GLOBAL_VARIABLES` or `tf.GraphKeys.LOCAL_VARIABLES` collection.

  Args:
    getter: The true getter or another custom getter.
    *args: See positional arguments for `tf.get_variable()`.
    **kwargs: See keyword arguments for `tf.get_variable()`.

  Returns:
    See docstring for `tf.get_variable()`.
  """
  # Get the module that is calling `tf.get_variable()`
  module_stack = _MODULE_STACKS[tf.get_default_graph()]
  module = module_stack[-1]

  # Get lists of local and global variables. We use `tf.get_collection_ref()`
  # instead of `tf.get_collection()` to avoid copying the collections.
  local_variables = tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)
  global_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)

  num_local_vars_before = len(local_variables)
  num_global_vars_before = len(global_variables)

  out = getter(*args, **kwargs)

  # Add any local or global variables that have been created to `module`
  # pylint: disable=protected-access
  module._all_variables.update(local_variables[num_local_vars_before:])
  module._all_variables.update(global_variables[num_global_vars_before:])
  # pylint: enable=protected-access

  return out
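
The "measure the collection before and after" trick can be shown without TensorFlow. In this sketch a plain list stands in for an append-only graph collection; `GLOBAL_VARIABLES`, `create_variable`, and the extra `tracked` parameter are assumptions made for illustration only.

```python
# Stand-in for an append-only graph collection.
GLOBAL_VARIABLES = []


def create_variable(name):
  # Stand-in for a getter that creates a new variable.
  GLOBAL_VARIABLES.append(name)
  return name


def tracking_getter(getter, tracked, *args, **kwargs):
  # Remember how long the collection was before delegating.
  num_before = len(GLOBAL_VARIABLES)
  out = getter(*args, **kwargs)
  # Anything past the old length was appended during the call.
  tracked.update(GLOBAL_VARIABLES[num_before:])
  return out


create_variable("pre_existing")
tracked = set()
tracking_getter(create_variable, tracked, "w")
# A getter that returns an existing variable appends nothing to the list.
tracking_getter(lambda name: name, tracked, "pre_existing")
```

Only `"w"` ends up in `tracked`: the pre-existing entry sits before the recorded length, and the second call appends nothing. The approach depends on the collection being append-only, which is the assumption the docstring above spells out.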


@six.add_metaclass(abc.ABCMeta)
class AbstractModule(object):
"""Superclass for Sonnet Modules.
@@ -117,6 +203,8 @@ def __init__(self, _sentinel=None, custom_getter=None,
if not (custom_getter is None or callable(custom_getter)):
raise TypeError("Given custom_getter is not callable.")
self._custom_getter = custom_getter
self._custom_getter = _maybe_wrap_custom_getter(
_variable_tracking_custom_getter, self._custom_getter)

self._template = tf.make_template(name,
self._build_wrapper,
@@ -126,7 +214,7 @@ def __init__(self, _sentinel=None, custom_getter=None,
self._original_name = name
self._unique_name = self._template.variable_scope.name.split("/")[-1]

# Update __call__ and the object docstrings to enable better introspection
# Update __call__ and the object docstrings to enable better introspection.
self.__doc__ = self._build.__doc__
self.__call__.__func__.__doc__ = self._build.__doc__

@@ -135,6 +223,9 @@ def __init__(self, _sentinel=None, custom_getter=None,
# sharing is impossible in that case.
self._graph = None

# Container for all variables created in this module and its sub-modules.
self._all_variables = set([])

def _build_wrapper(self, *args, **kwargs):
"""Function which will be wrapped in a Template to do variable sharing.
@@ -218,6 +309,36 @@ def _build(self, *args, **kwargs):
output Tensor(s).
"""

  @contextlib.contextmanager
  def _capture_variables(self):
    """Adds variables used by this module to self._all_variables.

    Upon entering this context manager the module adds itself onto the top
    of the module call stack. Any variables created with `tf.get_variable()`
    inside `_build()` or `_enter_variable_scope()` while this module is on top
    of the call stack will be added to `self._all_variables`.

    Before exiting the context the module removes itself from the top of the
    call stack, and adds all of the variables in `self._all_variables` to its
    parent module (the new top) of the call stack.

    Yields:
      Nothing, the yield just transfers focus back to the inner context.
    """
    module_stack = _MODULE_STACKS.setdefault(self._graph, [])
    module_stack.append(self)
    try:
      yield
    finally:
      # Remove `self` from `module_stack`, this happens as part of cleanup
      # even if an error is raised.
      module_stack.pop()

    if module_stack:
      # Peek into the stack to add created variables to the parent
      parent_module = module_stack[-1]
      parent_module._all_variables.update(self._all_variables)  # pylint: disable=protected-access

def _add_connected_subgraph(self, call_method, outputs, subgraph_name_scope,
*inputs_args, **inputs_kwargs):
"""Adds a newly connected subgraph.
@@ -261,7 +382,8 @@ def __call__(self, *args, **kwargs):
"""
self._check_init_called()
self._check_same_graph()
outputs, subgraph_name_scope = self._template(*args, **kwargs)
with self._capture_variables():
outputs, subgraph_name_scope = self._template(*args, **kwargs)
self._add_connected_subgraph(self._build, outputs, subgraph_name_scope,
*args, **kwargs)
return outputs
@@ -361,6 +483,8 @@ def _ensure_is_connected(self):
"Variables in {} not instantiated yet, __call__ the module "
"first.".format(self.scope_name))

# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def _enter_variable_scope(self, reuse=None):
"""Returns a contextlib.contextmanager to enter the internal variable scope.
@@ -396,12 +520,15 @@ def _build(self, input):
Args:
reuse: Boolean passed to `tf.variable_scope`.
Returns:
`contextlib.contextmanager` of the variable_scope inside the template.
Yields:
The variable_scope inside the template.
"""
self._check_init_called()
self._check_same_graph()
return tf.variable_scope(self._template.variable_scope, reuse=reuse)
with self._capture_variables():
with tf.variable_scope(self._template.variable_scope, reuse=reuse) as vs:
yield vs
# pylint: enable=g-doc-return-or-yield

def get_variables(self, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
"""Returns tuple of `tf.Variable`s declared inside this module.
@@ -433,6 +560,28 @@ def get_variables(self, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
return util.get_variables_in_scope(
self.variable_scope, collection=collection)

  def get_all_variables(self, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
    """Returns all `tf.Variable`s used when the module is connected.

    See the documentation for `AbstractModule._capture_variables()` for more
    information.

    Args:
      collection: Collection to restrict query to. By default this is
        `tf.GraphKeys.TRAINABLE_VARIABLES`, which doesn't include non-trainable
        variables such as moving averages.

    Returns:
      A tuple of `tf.Variable` objects.

    Raises:
      NotConnectedError: If the module is not connected to the Graph.
    """
    self._ensure_is_connected()
    collection_variables = set(tf.get_collection(collection))
    # Return variables in self._all_variables that are in `collection`
    return tuple(self._all_variables & collection_variables)
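
The final filtering step is a set intersection, sketched here with variable names in place of `tf.Variable` objects. `filter_by_collection` and the sample names are hypothetical. Because the tuple is built from a set, its element order should not be relied upon in this sketch.

```python
def filter_by_collection(tracked, collection):
  # Keep only the tracked variables that also appear in the collection.
  return tuple(tracked & set(collection))


# Everything the module and its submodules created, trainable or not.
all_variables = {"linear/w", "linear/b", "batch_norm/moving_mean"}
# Pretend trainable collection; it also holds another module's variable.
trainable_collection = ["linear/w", "linear/b", "other_module/w"]

result = filter_by_collection(all_variables, trainable_collection)
```

The moving average is dropped because it is not in the trainable collection, and `other_module/w` is dropped because this module never tracked it.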

def __getstate__(self):
raise NotSupportedError(
"Sonnet AbstractModule instances cannot be serialized. You should "