# NOTE: GitHub page chrome ("Skip to content", "Branch: master", line/sloc
# counts, etc.) was captured with this scrape; removed so the file parses.
"""Support code for YAML parsing of experiment descriptions."""
import yaml
from pylearn2.utils import serial
from pylearn2.utils.exc import reraise_as
from pylearn2.utils.string_utils import preprocess
from pylearn2.utils.call_check import checked_call
from pylearn2.utils.string_utils import match
from collections import namedtuple
import logging
import warnings
import re
from theano.compat import six
SCIENTIFIC_NOTATION_REGEXP = r'^[\-\+]?(\d+\.?\d*|\d*\.?\d+)?[eE][\-\+]?\d+$'
is_initialized = False
additional_environ = None
logger = logging.getLogger(__name__)
# Lightweight container for the initial YAML evaluation.
#
# Intended as a robust, forward-compatible intermediate representation for
# either internal consumption or external consumption by another tool,
# e.g. hyperopt.
#
# A slot for positional arguments is included just in case, though they
# are unsupported by the instantiation mechanism as yet.
BaseProxy = namedtuple('BaseProxy',
                       'callable positionals keywords yaml_src')
class Proxy(BaseProxy):
    """
    An intermediate representation between the initial YAML parse and
    object instantiation.

    Parameters
    ----------
    callable : callable
        The function/class to call to instantiate this node.
    positionals : iterable
        Placeholder for future support for positional arguments (`*args`).
    keywords : dict-like
        A mapping from keywords to arguments (`**kwargs`), which may be
        `Proxy`s or `Proxy`s nested inside `dict` or `list` instances.
        Keys must be strings that are valid Python variable names.
    yaml_src : str
        The YAML source that created this node, if available.

    Notes
    -----
    Intended as a robust, forward-compatible intermediate representation
    for either internal consumption or external consumption by another
    tool, e.g. hyperopt.

    This subclass exists mainly to override `BaseProxy.__hash__`, because
    the namedtuple's field-based hash would fail on unhashable elements
    (e.g. dict-valued `keywords`).
    """
    # No per-instance __dict__; all state lives in the namedtuple fields.
    __slots__ = []

    def __hash__(self):
        """
        Hash by object identity so unhashable namedtuple elements are
        never hashed.
        """
        return hash(id(self))
def do_not_recurse(value):
    """
    Identity function used as a marker symbol.

    Wrapping an unpickled object in a `Proxy` whose callable is this
    function tells the instantiation parser not to recursively expand
    the object. As an implementation it is a no-op.

    Parameters
    ----------
    value : object
        The value to be returned.

    Returns
    -------
    value : object
        The same object passed in as an argument.
    """
    return value
def _instantiate_proxy_tuple(proxy, bindings=None):
    """
    Helper function for `_instantiate` that handles objects of the `Proxy`
    class.

    Parameters
    ----------
    proxy : Proxy object
        A `Proxy` object to instantiate.
    bindings : dict, optional
        A dictionary mapping previously instantiated `Proxy` objects
        to their instantiated values. If `None`, a fresh dictionary is
        used. (Previously the `None` default was passed straight to the
        `proxy in bindings` membership test, raising `TypeError`.)

    Returns
    -------
    obj : object
        The result object from recursively instantiating the object DAG.
    """
    # BUGFIX: guard the None default; the original code did
    # `if proxy in bindings` with bindings=None, which raises TypeError.
    if bindings is None:
        bindings = {}
    if proxy in bindings:
        return bindings[proxy]
    # Respect do_not_recurse by just un-packing it (same as calling).
    if proxy.callable == do_not_recurse:
        obj = proxy.keywords['value']
    else:
        # TODO: add (requested) support for positionals (needs to be added
        # to checked_call also).
        if len(proxy.positionals) > 0:
            raise NotImplementedError('positional arguments not yet '
                                      'supported in proxy instantiation')
        kwargs = dict((k, _instantiate(v, bindings))
                      for k, v in six.iteritems(proxy.keywords))
        # checked_call gives a clearer error message on bad keywords.
        obj = checked_call(proxy.callable, kwargs)
    try:
        obj.yaml_src = proxy.yaml_src
    except AttributeError:  # Some classes won't allow this.
        pass
    bindings[proxy] = obj
    return bindings[proxy]
def _instantiate(proxy, bindings=None):
    """
    Instantiate a (hierarchy of) Proxy object(s).

    Parameters
    ----------
    proxy : object
        A `Proxy` object or list/dict/literal. Strings are run through
        `preprocess`.
    bindings : dict, optional
        A dictionary mapping previously instantiated `Proxy` objects
        to their instantiated values.

    Returns
    -------
    obj : object
        The result object from recursively instantiating the object DAG.

    Notes
    -----
    This should not be considered part of the stable, public API.
    """
    if bindings is None:
        bindings = {}
    if isinstance(proxy, Proxy):
        return _instantiate_proxy_tuple(proxy, bindings)
    if isinstance(proxy, dict):
        # Recurse on the keys too, for backward compatibility.
        # Is the key instantiation feature ever actually used, by anyone?
        return dict((_instantiate(key, bindings), _instantiate(val, bindings))
                    for key, val in six.iteritems(proxy))
    if isinstance(proxy, list):
        return [_instantiate(element, bindings) for element in proxy]
    # In the future it might be good to consider a dict argument that
    # provides a type->callable mapping for arbitrary transformations
    # like this.
    if isinstance(proxy, six.string_types):
        return preprocess(proxy)
    return proxy
def load(stream, environ=None, instantiate=True, **kwargs):
    """
    Loads a YAML configuration from a string or file-like object.

    Parameters
    ----------
    stream : str or object
        Either a string containing valid YAML or a file-like object
        supporting the .read() interface.
    environ : dict, optional
        A dictionary used for ${FOO} substitutions in addition to
        environment variables. If a key appears both in `os.environ`
        and this dictionary, the value in this dictionary is used.
    instantiate : bool, optional
        If `False`, do not actually instantiate the objects but instead
        produce a nested hierarchy of `Proxy` objects.

    Returns
    -------
    graph : dict or object
        The dictionary or object (if the top-level element specified
        a Python object to instantiate), or a nested hierarchy of
        `Proxy` objects.

    Notes
    -----
    Other keyword arguments are passed on to `yaml.load`.

    SECURITY NOTE: this ultimately calls `yaml.load`, and the custom
    `!obj:` / `!import:` / `!pkl:` tags import and execute arbitrary
    Python. Only load trusted configuration files.
    """
    global is_initialized
    global additional_environ
    # Install the custom YAML tag handlers once per process.
    if not is_initialized:
        initialize()
    # Stashed in a module-level global so the PyYAML constructor callbacks
    # (which receive no user arguments, e.g. multi_constructor_pkl) can
    # perform ${FOO} substitution with it.
    additional_environ = environ
    if isinstance(stream, six.string_types):
        string = stream
    else:
        string = stream.read()
    # NOTE(review): yaml.load without an explicit Loader relies on the
    # constructors registered in initialize() being attached to the
    # default loader class.
    proxy_graph = yaml.load(string, **kwargs)
    if instantiate:
        return _instantiate(proxy_graph)
    else:
        return proxy_graph
def load_path(path, environ=None, instantiate=True, **kwargs):
    """
    Convenience function for loading a YAML configuration from a file.

    Parameters
    ----------
    path : str
        The path to the file to load on disk.
    environ : dict, optional
        A dictionary used for ${FOO} substitutions in addition to
        environment variables. If a key appears both in `os.environ`
        and this dictionary, the value in this dictionary is used.
    instantiate : bool, optional
        If `False`, do not actually instantiate the objects but instead
        produce a nested hierarchy of `Proxy` objects.

    Returns
    -------
    graph : dict or object
        The dictionary or object (if the top-level element specified
        a Python object to instantiate), or a nested hierarchy of
        `Proxy` objects.

    Notes
    -----
    Other keyword arguments are passed on to `yaml.load`.
    """
    with open(path, 'r') as f:
        # IDIOM FIX: read the file in one call rather than joining
        # readlines() output (same result, no intermediate list).
        content = f.read()
    # This is apparently here to avoid the odd instance where a file gets
    # loaded as Unicode instead (see 03f238c6d). It's a rare instance where
    # basestring is not the right call.
    if not isinstance(content, str):
        raise AssertionError("Expected content to be of type str, got " +
                             str(type(content)))
    return load(content, instantiate=instantiate, environ=environ, **kwargs)
def try_to_import(tag_suffix):
    """
    Import and evaluate a dotted path, with helpful error reporting.

    Parameters
    ----------
    tag_suffix : str
        A dotted path such as ``'pylearn2.models.mlp.MLP'``. Everything
        up to the final component is treated as the module to import;
        the full path is then evaluated to obtain the object.

    Returns
    -------
    obj : object
        The object the dotted path refers to.
    """
    components = tag_suffix.split('.')
    modulename = '.'.join(components[:-1])
    try:
        exec('import %s' % modulename)
    except ImportError as e:
        # We know it's an ImportError, but is it an ImportError related to
        # this path, or did the module we're importing have an unrelated
        # ImportError? (Yes, this test can still have false positives;
        # feel free to improve it.)
        pieces = modulename.split('.')
        str_e = str(e)
        # BUGFIX: the original tested `piece.find(str(e)) != -1`, i.e.
        # whether the entire error message occurred *inside* each path
        # component — essentially never true. The intent is to check
        # whether the message mentions one of the path components.
        found = any(piece in str_e for piece in pieces)
        if found:
            # The yaml file is probably to blame.
            # Report the problem with the full module path from the YAML
            # file.
            reraise_as(ImportError("Could not import %s; ImportError was %s" %
                                   (modulename, str_e)))
        else:
            # Import prefix by prefix to pinpoint which component of the
            # path fails to import.
            pcomponents = components[:-1]
            assert len(pcomponents) >= 1
            j = 1
            while j <= len(pcomponents):
                modulename = '.'.join(pcomponents[:j])
                try:
                    exec('import %s' % modulename)
                except Exception:
                    base_msg = 'Could not import %s' % modulename
                    if j > 1:
                        modulename = '.'.join(pcomponents[:j - 1])
                        base_msg += ' but could import %s' % modulename
                    reraise_as(ImportError(base_msg + '. Original exception: '
                                           + str(e)))
                j += 1
    try:
        obj = eval(tag_suffix)
    except AttributeError as e:
        try:
            # Try to figure out what the wrong field name was.
            # If we fail to do it, just fall back to giving the usual
            # attribute error.
            pieces = tag_suffix.split('.')
            module = '.'.join(pieces[:-1])
            field = pieces[-1]
            candidates = dir(eval(module))
            msg = ('Could not evaluate %s. ' % tag_suffix +
                   'Did you mean ' + match(field, candidates) + '? ' +
                   'Original error was ' + str(e))
        except Exception:
            warnings.warn("Attempt to decipher AttributeError failed")
            reraise_as(AttributeError('Could not evaluate %s. ' % tag_suffix +
                                      'Original error was ' + str(e)))
        reraise_as(AttributeError(msg))
    return obj
def initialize():
    """
    Install the custom YAML tag handlers used by the configuration system.

    Done automatically on the first call to `load` in this file.
    """
    global is_initialized
    # Multi-constructors handle tags carrying a suffix (e.g. "!obj:a.B").
    yaml.add_multi_constructor('!obj:', multi_constructor_obj)
    yaml.add_multi_constructor('!pkl:', multi_constructor_pkl)
    yaml.add_multi_constructor('!import:', multi_constructor_import)
    # Plain constructors handle exact tags with a scalar argument.
    yaml.add_constructor('!import', constructor_import)
    yaml.add_constructor('!float', constructor_float)
    # Make unquoted scientific notation (e.g. 1e-3) resolve to !float,
    # since plain YAML would otherwise read it as a string.
    yaml.add_implicit_resolver('!float',
                               re.compile(SCIENTIFIC_NOTATION_REGEXP))
    is_initialized = True
###############################################################################
# Callbacks used by PyYAML
def multi_constructor_obj(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!obj:" tag is encountered.

    Builds a `Proxy` describing the object to construct rather than
    constructing it immediately. See PyYAML documentation for details
    on the call signature.
    """
    yaml_src = yaml.serialize(node)
    # Raises ConstructorError on duplicate keys; result is discarded.
    construct_mapping(node)
    mapping = loader.construct_mapping(node)
    assert hasattr(mapping, 'keys')
    assert hasattr(mapping, 'values')
    # All keyword names must be strings to be usable as **kwargs later.
    for key in mapping.keys():
        if not isinstance(key, six.string_types):
            raise TypeError("Received non string object (%s) as "
                            "key in mapping." % str(key))
    # Renamed from `callable` to avoid shadowing the builtin.
    if '.' not in tag_suffix:
        # TODO: I'm not sure how this was ever working without eval().
        target = eval(tag_suffix)
    else:
        target = try_to_import(tag_suffix)
    return Proxy(callable=target, yaml_src=yaml_src, positionals=(),
                 keywords=mapping)
def multi_constructor_pkl(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!pkl:" tag is encountered.

    Loads a pickled object from the path in the node (after ${FOO}
    preprocessing with `additional_environ`) and wraps it in a
    `do_not_recurse` Proxy so the instantiation pass leaves it alone.
    """
    global additional_environ
    # The u"" comparison is for Python 2 unicode tag suffixes.
    if tag_suffix != "" and tag_suffix != u"":
        raise AssertionError('Expected tag_suffix to be "" but it is "'
                             + tag_suffix +
                             '": Put space between !pkl: and the filename.')
    # The scalar node's value is the path to the pickle file.
    mapping = loader.construct_yaml_str(node)
    # SECURITY NOTE(review): unpickling executes arbitrary code; only
    # load trusted files.
    obj = serial.load(preprocess(mapping, additional_environ))
    proxy = Proxy(callable=do_not_recurse, positionals=(),
                  keywords={'value': obj}, yaml_src=yaml.serialize(node))
    return proxy
def multi_constructor_import(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!import:" tag is encountered.

    The tag suffix must be a dotted path, e.g. "!import:foo.bar.baz".
    """
    if '.' in tag_suffix:
        return try_to_import(tag_suffix)
    raise yaml.YAMLError("!import: tag suffix contains no '.'")
def constructor_import(loader, node):
    """
    Callback used by PyYAML when a "!import <str>" tag is encountered.

    This tag expects a (quoted) string as argument.
    """
    dotted_path = loader.construct_scalar(node)
    if '.' in dotted_path:
        return try_to_import(dotted_path)
    raise yaml.YAMLError("import tag suffix contains no '.'")
def constructor_float(loader, node):
    """
    Callback used by PyYAML when a "!float <str>" tag is encountered.

    This tag expects a (quoted) string as argument.
    """
    return float(loader.construct_scalar(node))
def construct_mapping(node, deep=False):
    """
    Construct a dict from a YAML mapping node, rejecting duplicate keys.

    A modified version of yaml.BaseConstructor.construct_mapping in
    which a repeated key raises a ConstructorError (plain PyYAML lets
    later keys silently overwrite earlier ones).

    Parameters
    ----------
    node : yaml.nodes.MappingNode
        The mapping node to convert.
    deep : bool, optional
        Passed through to `construct_object` for nested construction.
        (Previously accepted but ignored: `deep=False` was hard-coded.)

    Returns
    -------
    mapping : dict
        The constructed dictionary.
    """
    if not isinstance(node, yaml.nodes.MappingNode):
        const = yaml.constructor
        message = "expected a mapping node, but found"
        raise const.ConstructorError(None, None,
                                     "%s %s " % (message, node.id),
                                     node.start_mark)
    mapping = {}
    constructor = yaml.constructor.BaseConstructor()
    for key_node, value_node in node.value:
        key = constructor.construct_object(key_node, deep=deep)
        try:
            hash(key)
        except TypeError as exc:
            const = yaml.constructor
            # BUGFIX: the key node's mark is a separate positional
            # argument to ConstructorError, not a %-format operand. The
            # old code did `"... (%s)" % (exc, key_node.start_mark)`,
            # which itself raised TypeError ("not all arguments
            # converted") while trying to report the error.
            reraise_as(const.ConstructorError("while constructing a mapping",
                                              node.start_mark,
                                              "found unacceptable key (%s)"
                                              % exc,
                                              key_node.start_mark))
        if key in mapping:
            const = yaml.constructor
            raise const.ConstructorError("while constructing a mapping",
                                         node.start_mark,
                                         "found duplicate key (%s)" % key)
        value = constructor.construct_object(value_node, deep=deep)
        mapping[key] = value
    return mapping
if __name__ == "__main__":
    # Register the custom !obj:/!pkl:/!import: handlers before loading.
    initialize()
    # Demonstration of how to specify objects, reference them
    # later in the configuration, etc.
    yamlfile = """{
"corruptor" : !obj:pylearn2.corruption.GaussianCorruptor &corr {
"corruption_level" : 0.9
},
"dae" : !obj:pylearn2.models.autoencoder.DenoisingAutoencoder {
"nhid" : 20,
"nvis" : 30,
"act_enc" : null,
"act_dec" : null,
"tied_weights" : true,
# we could have also just put the corruptor definition here
"corruptor" : *corr
}
}"""
    # yaml.load can take a string or a file object
    loaded = yaml.load(yamlfile)
    logger.info(loaded)
    # These two things should be the same object: the &corr anchor and
    # the *corr alias resolve to a single node, so both entries share
    # one instantiated corruptor.
    assert loaded['corruptor'] is loaded['dae'].corruptor
# (GitHub footer text removed from scraped source.)