Skip to content

Commit

Permalink
Merge pull request #1097 from rizar/get_hierarchical_path
Browse files Browse the repository at this point in the history
Unique path for all parameters
  • Loading branch information
dmitriy-serdyuk committed Jun 1, 2016
2 parents 57458e3 + 3a2f1b8 commit 67f35d2
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 15 deletions.
2 changes: 1 addition & 1 deletion blocks/bricks/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def __init__(self, match_dim, state_transformer=None,
children = [self.state_transformers, attended_transformer,
energy_computer]
kwargs.setdefault('children', []).extend(children)
super(SequenceContentAttention, self).__init__(**kwargs)
super(SequenceContentAttention, self).__init__(**kwargs)

def _push_allocation_config(self):
self.state_transformers.input_dims = self.state_dims
Expand Down
20 changes: 20 additions & 0 deletions blocks/bricks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from blocks.utils import dict_union, pack, repr_attrs, reraise_as, unpack
from blocks.utils.containers import AnnotatingList

BRICK_DELIMITER = '/'


def create_unbound_method(func, cls):
"""Create an unbounded method from a function and a class.
Expand Down Expand Up @@ -773,6 +775,24 @@ def get_unique_path(self):
else:
return [self]

def get_hierarchical_name(self, parameter, delimiter=BRICK_DELIMITER):
"""Return hierarhical name for a parameter.
Returns a path of the form _brick1/brick2/brick3.parameter1_. The
delimiter is configurable.
Parameters
----------
delimiter : str
The delimiter used to separate brick names in the path.
"""
return '{}.{}'.format(
delimiter.join(
[""] + [brick.name for brick in
self.get_unique_path()]),
parameter.name)


def args_to_kwargs(args, f):
arg_names, vararg_names, _, _ = inspect.getargspec(f)
Expand Down
11 changes: 4 additions & 7 deletions blocks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from itertools import chain

from blocks.graph import ComputationGraph
from blocks.select import Selector
from blocks.filter import get_brick

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,14 +74,12 @@ def __init__(self, *args, **kwargs):
if repeated_names:
raise ValueError("top bricks with the same name:"
" {}".format(', '.join(repeated_names)))
brick_parameter_names = {
v: k for k, v in Selector(
self.top_bricks).get_parameters().items()}
parameter_list = []
for parameter in self.parameters:
if parameter in brick_parameter_names:
parameter_list.append((brick_parameter_names[parameter],
parameter))
if get_brick(parameter):
parameter_list.append(
(get_brick(parameter).get_hierarchical_name(parameter),
parameter))
else:
parameter_list.append((parameter.name, parameter))
self._parameter_dict = OrderedDict(parameter_list)
Expand Down
13 changes: 6 additions & 7 deletions blocks/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,12 @@
from blocks.config import config
from blocks.filter import get_brick
from blocks.utils import change_recursion_limit
from blocks.bricks.base import BRICK_DELIMITER


logger = logging.getLogger(__name__)

BRICK_DELIMITER = '|'
SERIALIZATION_BRICK_DELIMITER = '|'
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.
Expand Down Expand Up @@ -293,7 +294,8 @@ def load_parameters(file_):
"""
with closing(_load_parameters_npzfile(file_)) as npz_file:
return {name.replace(BRICK_DELIMITER, '/'): value
return {name.replace(SERIALIZATION_BRICK_DELIMITER,
BRICK_DELIMITER): value
for name, value in npz_file.items()}


Expand Down Expand Up @@ -527,11 +529,8 @@ def __init__(self):
def __call__(self, parameter):
# Standard Blocks parameter
if get_brick(parameter) is not None:
name = '{}.{}'.format(
BRICK_DELIMITER.join(
[""] + [brick.name for brick in
get_brick(parameter).get_unique_path()]),
parameter.name)
name = get_brick(parameter).get_hierarchical_name(
parameter, SERIALIZATION_BRICK_DELIMITER)
# Shared variables with tag.name
elif hasattr(parameter.tag, 'name'):
name = parameter.tag.name
Expand Down

0 comments on commit 67f35d2

Please sign in to comment.