Skip to content

Commit

Permalink
Merge pull request #401 from brian-team/fix_lio
Browse files Browse the repository at this point in the history
Fix loop-invariant optimisation (don't pull out integer and boolean expressions). Closes #400
  • Loading branch information
mstimberg committed Feb 10, 2015
2 parents fda17e6 + 8b9063a commit 1fdb52c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 11 deletions.
47 changes: 43 additions & 4 deletions brian2/codegen/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def analyse_identifiers(code, variables, recursive=False):
return defined, used_known, dependent


def get_identifiers_recursively(expressions, variables):
def get_identifiers_recursively(expressions, variables, include_numbers=False):
'''
Gets all the identifiers in a list of expressions, recursing down into
subexpressions.
Expand All @@ -127,15 +127,19 @@ def get_identifiers_recursively(expressions, variables):
List of expressions to check.
variables : dict-like
Dictionary of `Variable` objects
include_numbers : bool, optional
Whether to include number literals in the output. Defaults to ``False``.
'''
if len(expressions):
identifiers = set.union(*[get_identifiers(expr) for expr in expressions])
identifiers = set.union(*[get_identifiers(expr, include_numbers=include_numbers)
for expr in expressions])
else:
identifiers = set()
for name in set(identifiers):
if name in variables and isinstance(variables[name], Subexpression):
s_identifiers = get_identifiers_recursively([variables[name].expr],
variables)
variables,
include_numbers=include_numbers)
identifiers |= s_identifiers
return identifiers

Expand Down Expand Up @@ -167,6 +171,40 @@ def is_scalar_expression(expr, variables):
for name in identifiers)


def has_non_float(expr, variables):
'''
Whether the given expression has an integer or boolean variable in it.
Parameters
----------
expr : str
The expression to check
variables : dict-like
`Variable` and `Function` object for all the identifiers used in `expr`
Returns
-------
has_non_float : bool
Whether `expr` has an integer or boolean in it
'''
identifiers = get_identifiers_recursively([expr], variables,
include_numbers=True)
# Check whether there is an integer literal in the expression:
for name in identifiers:
if name not in variables:
try:
int(name)
# if this worked, this was an integer literal
return True
except (TypeError, ValueError):
pass # not an integer literal
non_float_var = any((name in variables and isinstance(name, Variable) and
(np.issubdtype(variables[name].dtype, np.integer) or
np.issubdtype(variables[name].dtype, np.bool_)))
for name in identifiers)
return non_float_var


class LIONodeRenderer(NodeRenderer):
'''
Renders expressions, pulling out scalar expressions and remembering them
Expand All @@ -185,7 +223,8 @@ def render_node(self, node):
if node.__class__.__name__ in ['Name', 'Num', 'NameConstant']:
return expr

if is_scalar_expression(expr, self.variables):
if is_scalar_expression(expr, self.variables) and not has_non_float(expr,
self.variables):
if expr in self.optimisations:
name = self.optimisations[expr]
else:
Expand Down
13 changes: 12 additions & 1 deletion brian2/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,19 @@ def test_apply_loop_invariant_optimisation():
assert all('_lio_const_1' in stmt.expr for stmt in vector)


@attr('codegen-independent')
def test_apply_loop_invariant_optimisation_integer():
variables = {'v': Variable('v', Unit(1), scalar=False),
'N': Constant('N', Unit(1), 10)}
statements = [Statement('v', '=', 'v % (2*3*N)', '', np.float32)]
scalar, vector = apply_loop_invariant_optimisations(statements, variables,
np.float64)
# The optimisation should not pull out 2*N
assert len(scalar) == 0

if __name__ == '__main__':
test_analyse_identifiers()
test_get_identifiers_recursively()
test_nested_subexpressions()
test_apply_loop_invariant_optimisation()
test_apply_loop_invariant_optimisation()
test_apply_loop_invariant_optimisation_integer()
34 changes: 28 additions & 6 deletions brian2/utils/stringtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,44 @@ def replace(s, substitutions):

KEYWORDS = set(['and', 'or', 'not', 'True', 'False'])

def get_identifiers(expr):

def get_identifiers(expr, include_numbers=False):
'''
Return all the identifiers in a given string ``expr``, that is everything
that matches a programming language variable like expression, which is
here implemented as the regexp ``\\b[A-Za-z_][A-Za-z0-9_]*\\b``.
Parameters
----------
expr : str
The string to analyze
include_numbers : bool, optional
Whether to include number literals in the output. Defaults to ``False``.
Returns
-------
identifiers : set
A set of all the identifiers (and, optionally, numbers) in `expr`.
Examples
--------
>>> expr = 'a*_b+c5+8+f(A)'
>>> expr = '3-a*_b+c5+8+f(A - .3e-10, tau_2)*17'
>>> ids = get_identifiers(expr)
>>> print(sorted(list(ids)))
['A', '_b', 'a', 'c5', 'f']
['A', '_b', 'a', 'c5', 'f', 'tau_2']
>>> ids = get_identifiers(expr, include_numbers=True)
>>> print(sorted(list(ids)))
['.3e-10', '17', '3', '8', 'A', '_b', 'a', 'c5', 'f', 'tau_2']
'''
identifiers = set(re.findall(r'\b[A-Za-z_][A-Za-z0-9_]*\b', expr))
return identifiers - KEYWORDS
if include_numbers:
# only the number, not a + or -
numbers = set(re.findall(r'(?<=[^A-Za-z_])[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?|^[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?',
expr))
else:
numbers = set()
return (identifiers - KEYWORDS) | numbers


def strip_empty_lines(s):
'''
Expand Down

0 comments on commit 1fdb52c

Please sign in to comment.