Skip to content

Commit

Permalink
Merge pull request #1194 from rizar/improve_variable_filter
Browse files Browse the repository at this point in the history
Filter by unbound application and call_id
  • Loading branch information
rizar committed Jul 13, 2017
2 parents 2babb42 + 6f934d0 commit 979d0f9
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
2 changes: 2 additions & 0 deletions blocks/bricks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def __call__(self, brick, *args, **kwargs):
def apply(self, bound_application, *args, **kwargs):
as_dict = kwargs.pop('as_dict', False)
as_list = kwargs.pop('as_list', False)
call_id = kwargs.pop('call_id', None)
if as_list and as_dict:
raise ValueError

Expand All @@ -259,6 +260,7 @@ def apply(self, bound_application, *args, **kwargs):

# Construct the ApplicationCall, used to store data in for this call
call = ApplicationCall(bound_application)
call.metadata['call_id'] = call_id
args = list(args)
if 'application' in args_names:
args.insert(args_names.index('application'), bound_application)
Expand Down
35 changes: 27 additions & 8 deletions blocks/filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from inspect import isclass
import re

from blocks.bricks.base import ApplicationCall, BoundApplication, Brick
from blocks.bricks.base import (
Application, ApplicationCall, BoundApplication, Brick)
from blocks.roles import has_roles


Expand Down Expand Up @@ -64,7 +65,11 @@ class VariableFilter(object):
theano_name_regex : str, optional
A regular expression for the variable name. The Theano name (i.e.
`x.name`) is used.
applications : list of :class:`.Application`, optional
call_id : str, optional
The call identifier as written in :class:`.ApplicationCall` metadata
attribute.
applications : list of :class:`.Application`
or :class:`.BoundApplication`, optional
Matches a variable that was produced by any of the applications
given.
Expand Down Expand Up @@ -101,13 +106,14 @@ class VariableFilter(object):
"""
def __init__(self, roles=None, bricks=None, each_role=False, name=None,
name_regex=None, theano_name=None, theano_name_regex=None,
applications=None):
call_id=None, applications=None):
if bricks is not None and not all(
isinstance(brick, Brick) or issubclass(brick, Brick)
for brick in bricks):
raise ValueError('`bricks` should be a list of Bricks')
if applications is not None and not all(
isinstance(application, BoundApplication)
isinstance(application, BoundApplication) or
isinstance(application, Application)
for application in applications):
raise ValueError('`applications` should be a list of '
'BoundApplications')
Expand All @@ -118,6 +124,7 @@ def __init__(self, roles=None, bricks=None, each_role=False, name=None,
self.name_regex = name_regex
self.theano_name = theano_name
self.theano_name_regex = theano_name_regex
self.call_id = call_id
self.applications = applications

def __call__(self, variables):
Expand Down Expand Up @@ -162,8 +169,20 @@ def __call__(self, variables):
if (var.name is not None) and
re.match(self.theano_name_regex, var.name)]
if self.applications:
variables = [var for var in variables
if get_application_call(var) and
get_application_call(var).application in
self.applications]
filtered_variables = []
for var in variables:
var_application = get_application_call(var)
if var_application is None:
continue
if (var_application.application in
self.applications or
var_application.application.application in
self.applications):
filtered_variables.append(var)
variables = filtered_variables
if self.call_id:
variables = [
var for var in variables
if get_application_call(var) and
get_application_call(var).metadata['call_id'] == self.call_id]
return variables
18 changes: 11 additions & 7 deletions tests/test_variable_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def test_variable_filter():
activation = Logistic(name='sigm')

x = tensor.vector()
h1 = brick1.apply(x)
h2 = activation.apply(h1)
h1 = brick1.apply(x, call_id='brick1_call_id')
h2 = activation.apply(h1, call_id='act')
h2.name = "h2act"
y = brick2.apply(h2)
cg = ComputationGraph(y)
Expand Down Expand Up @@ -67,14 +67,18 @@ def test_variable_filter():
theano_name_filter_regex = VariableFilter(theano_name_regex='h2a.?t')
assert [cg.variables[11]] == theano_name_filter_regex(cg.variables)

brick1_apply_variables = [cg.variables[1], cg.variables[8]]
# Testing filtering by application
appli_filter = VariableFilter(applications=[brick1.apply])
variables = [cg.variables[1], cg.variables[8]]
assert variables == appli_filter(cg.variables)
assert brick1_apply_variables == appli_filter(cg.variables)

# Testing filtering by application
appli_filter_list = VariableFilter(applications=[brick1.apply])
assert variables == appli_filter_list(cg.variables)
# Testing filtering by unbound application
unbound_appli_filter = VariableFilter(applications=[Linear.apply])
assert brick1_apply_variables == unbound_appli_filter(cg.variables)

# Testing filtering by call identifier
call_id_filter = VariableFilter(call_id='brick1_call_id')
assert brick1_apply_variables == call_id_filter(cg.variables)

input1 = tensor.matrix('input1')
input2 = tensor.matrix('input2')
Expand Down

0 comments on commit 979d0f9

Please sign in to comment.