From 6f934d05a750813734be1bcc6ba33d9924967d7a Mon Sep 17 00:00:00 2001 From: rizar Date: Fri, 7 Jul 2017 16:33:54 -0400 Subject: [PATCH] Filter by unbound application and call_id --- blocks/bricks/base.py | 2 ++ blocks/filter.py | 35 +++++++++++++++++++++++++++-------- tests/test_variable_filter.py | 18 +++++++++++------- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/blocks/bricks/base.py b/blocks/bricks/base.py index 4f830627..eb470c96 100644 --- a/blocks/bricks/base.py +++ b/blocks/bricks/base.py @@ -247,6 +247,7 @@ def __call__(self, brick, *args, **kwargs): def apply(self, bound_application, *args, **kwargs): as_dict = kwargs.pop('as_dict', False) as_list = kwargs.pop('as_list', False) + call_id = kwargs.pop('call_id', None) if as_list and as_dict: raise ValueError @@ -259,6 +260,7 @@ def apply(self, bound_application, *args, **kwargs): # Construct the ApplicationCall, used to store data in for this call call = ApplicationCall(bound_application) + call.metadata['call_id'] = call_id args = list(args) if 'application' in args_names: args.insert(args_names.index('application'), bound_application) diff --git a/blocks/filter.py b/blocks/filter.py index 90765931..cfda8068 100644 --- a/blocks/filter.py +++ b/blocks/filter.py @@ -1,7 +1,8 @@ from inspect import isclass import re -from blocks.bricks.base import ApplicationCall, BoundApplication, Brick +from blocks.bricks.base import ( + Application, ApplicationCall, BoundApplication, Brick) from blocks.roles import has_roles @@ -64,7 +65,11 @@ class VariableFilter(object): theano_name_regex : str, optional A regular expression for the variable name. The Theano name (i.e. `x.name`) is used. - applications : list of :class:`.Application`, optional + call_id : str, optional + The call identifier as written in :class:`.ApplicationCall` metadata + attribute. + applications : list of :class:`.Application` + or :class:`.BoundApplication`, optional Matches a variable that was produced by any of the applications given. @@ -101,13 +106,14 @@ class VariableFilter(object): """ def __init__(self, roles=None, bricks=None, each_role=False, name=None, name_regex=None, theano_name=None, theano_name_regex=None, - applications=None): + call_id=None, applications=None): if bricks is not None and not all( isinstance(brick, Brick) or issubclass(brick, Brick) for brick in bricks): raise ValueError('`bricks` should be a list of Bricks') if applications is not None and not all( - isinstance(application, BoundApplication) + isinstance(application, BoundApplication) or + isinstance(application, Application) for application in applications): raise ValueError('`applications` should be a list of ' 'BoundApplications') @@ -118,6 +124,7 @@ def __init__(self, roles=None, bricks=None, each_role=False, name=None, self.name_regex = name_regex self.theano_name = theano_name self.theano_name_regex = theano_name_regex + self.call_id = call_id self.applications = applications def __call__(self, variables): @@ -162,8 +169,20 @@ def __call__(self, variables): if (var.name is not None) and re.match(self.theano_name_regex, var.name)] if self.applications: - variables = [var for var in variables - if get_application_call(var) and - get_application_call(var).application in - self.applications] + filtered_variables = [] + for var in variables: + var_application = get_application_call(var) + if var_application is None: + continue + if (var_application.application in + self.applications or + var_application.application.application in + self.applications): + filtered_variables.append(var) + variables = filtered_variables + if self.call_id: + variables = [ + var for var in variables + if get_application_call(var) and + get_application_call(var).metadata['call_id'] == self.call_id] return variables diff --git a/tests/test_variable_filter.py b/tests/test_variable_filter.py index 4df5b221..cb1045e8 100644 --- a/tests/test_variable_filter.py +++ b/tests/test_variable_filter.py @@ -16,8 +16,8 @@ def test_variable_filter(): activation = Logistic(name='sigm') x = tensor.vector() - h1 = brick1.apply(x) - h2 = activation.apply(h1) + h1 = brick1.apply(x, call_id='brick1_call_id') + h2 = activation.apply(h1, call_id='act') h2.name = "h2act" y = brick2.apply(h2) cg = ComputationGraph(y) @@ -67,14 +67,18 @@ def test_variable_filter(): theano_name_filter_regex = VariableFilter(theano_name_regex='h2a.?t') assert [cg.variables[11]] == theano_name_filter_regex(cg.variables) + brick1_apply_variables = [cg.variables[1], cg.variables[8]] # Testing filtering by application appli_filter = VariableFilter(applications=[brick1.apply]) - variables = [cg.variables[1], cg.variables[8]] - assert variables == appli_filter(cg.variables) + assert brick1_apply_variables == appli_filter(cg.variables) - # Testing filtering by application - appli_filter_list = VariableFilter(applications=[brick1.apply]) - assert variables == appli_filter_list(cg.variables) + # Testing filtering by unbound application + unbound_appli_filter = VariableFilter(applications=[Linear.apply]) + assert brick1_apply_variables == unbound_appli_filter(cg.variables) + + # Testing filtering by call identifier + call_id_filter = VariableFilter(call_id='brick1_call_id') + assert brick1_apply_variables == call_id_filter(cg.variables) input1 = tensor.matrix('input1') input2 = tensor.matrix('input2')