Skip to content

Commit

Permalink
Merge pull request #1139 from dwf/move_copy_and_tag
Browse files Browse the repository at this point in the history
Move copy-and-tag to bricks.base namespace.
  • Loading branch information
dwf committed Aug 18, 2016
2 parents 4e266ad + 97c8082 commit 508344c
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions blocks/bricks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,28 +269,19 @@ def apply(self, bound_application, *args, **kwargs):
brick.allocate()

# Annotate all the input variables which are Theano variables
def copy_and_tag(variable, role, name):
"""Helper method to copy a variable and annotate it."""
copy = variable.copy()
# Theano name
copy.name = _variable_name(brick.name, self.name, name)
add_annotation(copy, brick)
add_annotation(copy, call)
# Blocks name
copy.tag.name = name
add_role(copy, role)
return copy

for i, input_ in enumerate(args):
if isinstance(input_, tensor.Variable):
if i < len(args_names):
name = args_names[i]
else:
name = "{}_{}".format(varargs_name, i - len(args_names))
args[i] = copy_and_tag(input_, INPUT, name)
args[i] = copy_and_tag(input_, brick, call, INPUT,
self.name, name)
for name, input_ in kwargs.items():
if isinstance(input_, tensor.Variable):
kwargs[name] = copy_and_tag(input_, INPUT, name)
kwargs[name] = copy_and_tag(input_, brick, call, INPUT,
self.name, name)

# Run the application method on the annotated variables
last_brick = self.call_stack[-1] if self.call_stack else None
Expand Down Expand Up @@ -318,8 +309,8 @@ def copy_and_tag(variable, role, name):
except IndexError:
reraise_as(ValueError("Unexpected outputs"))
# TODO Tag with dimensions, axes, etc. for error-checking
outputs[i] = copy_and_tag(outputs[i],
OUTPUT, name)
outputs[i] = copy_and_tag(outputs[i], brick, call,
OUTPUT, self.name, name)

# Return values
if as_list:
Expand Down Expand Up @@ -973,3 +964,16 @@ def wrap_application(application_function):

def _variable_name(brick_name, application_name, name):
return "{}_{}_{}".format(brick_name, application_name, name)


def copy_and_tag(variable, brick, call, role, application_name, name):
"""Helper method to copy a variable and annotate it."""
copy = variable.copy()
# Theano name
copy.name = _variable_name(brick.name, application_name, name)
add_annotation(copy, brick)
add_annotation(copy, call)
# Blocks name
copy.tag.name = name
add_role(copy, role)
return copy

0 comments on commit 508344c

Please sign in to comment.