diff --git a/brian2/codegen/languages/cpp/cpp.py b/brian2/codegen/languages/cpp/cpp.py index 62b9ccb26..e2dad0eb2 100644 --- a/brian2/codegen/languages/cpp/cpp.py +++ b/brian2/codegen/languages/cpp/cpp.py @@ -12,7 +12,7 @@ from ..base import Language, CodeObject from ..templates import LanguageTemplater -from ...ast_parser import CPPNodeRenderer +from brian2.parsing.rendering import CPPNodeRenderer logger = get_logger(__name__) try: diff --git a/brian2/codegen/languages/python/python.py b/brian2/codegen/languages/python/python.py index ceebe807e..ecbd5b3c2 100644 --- a/brian2/codegen/languages/python/python.py +++ b/brian2/codegen/languages/python/python.py @@ -4,7 +4,7 @@ from ..base import Language, CodeObject from ..templates import LanguageTemplater -from ...ast_parser import NumpyNodeRenderer +from brian2.parsing.rendering import NumpyNodeRenderer __all__ = ['PythonLanguage', 'PythonCodeObject'] diff --git a/brian2/codegen/parsing.py b/brian2/codegen/parsing.py index da32c770d..12e8fcd5c 100644 --- a/brian2/codegen/parsing.py +++ b/brian2/codegen/parsing.py @@ -7,7 +7,7 @@ from sympy.printing.str import StrPrinter from .functions.numpyfunctions import DEFAULT_FUNCTIONS, log10 -from .ast_parser import SympyNodeRenderer +from brian2.parsing.rendering import SympyNodeRenderer def parse_statement(code): diff --git a/brian2/parsing/__init__.py b/brian2/parsing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/brian2/parsing/dependencies.py b/brian2/parsing/dependencies.py new file mode 100644 index 000000000..c00cbc63a --- /dev/null +++ b/brian2/parsing/dependencies.py @@ -0,0 +1,147 @@ +import ast + +from brian2.utils.stringtools import deindent +from collections import namedtuple + +__all__ = ['abstract_code_dependencies'] + +def get_read_write_funcs(parsed_code): + allids = set([]) + read = set([]) + write = set([]) + funcs = set([]) + for node in ast.walk(parsed_code): + if node.__class__ is ast.Name: + allids.add(node.id) + if node.ctx.__class__ is ast.Store: + write.add(node.id) + elif node.ctx.__class__ is ast.Load: + read.add(node.id) + else: + raise SyntaxError + elif node.__class__ is ast.Call: + funcs.add(node.func.id) + + read = read-funcs + + # check that there's no funky stuff going on with functions + if funcs.intersection(write): + raise SyntaxError("Cannot assign to functions in abstract code") + + return allids, read, write, funcs + + +def abstract_code_dependencies(code, known_vars=None, known_funcs=None): + ''' + Analyses identifiers used in abstract code blocks + + Parameters + ---------- + + code : str + The abstract code block. + known_vars : set + The set of known variable names. + known_funcs : set + The set of known function names. + + Returns + ------- + + results : namedtuple with the following fields + ``all`` + The set of all identifiers that appear in this code block, + including functions. + ``read`` + The set of values that are read, excluding functions. + ``write`` + The set of all values that are written to. + ``funcs`` + The set of all function names. + ``known_all`` + The set of all identifiers that appear in this code block and + are known. + ``known_read`` + The set of known values that are read, excluding functions. + ``known_write`` + The set of known values that are written to. + ``known_funcs`` + The set of known functions that are used. + ``unknown_read`` + The set of all unknown variables whose values are read. Equal + to ``read-known_vars``. + ``unknown_write`` + The set of all unknown variables written to. Equal to + ``write-known_vars``. + ``unknown_funcs`` + The set of all unknown function names, equal to + ``funcs-known_funcs``. + ``undefined_read`` + The set of all unknown variables whose values are read before they + are written to. If this set is nonempty it usually indicates an + error, since a variable that is read should either have been + defined in the code block (in which case it will appear in + ``newly_defined``) or already be known. + ``newly_defined`` + The set of all variable names which are newly defined in this + abstract code block. + ''' + if known_vars is None: + known_vars = set([]) + if known_funcs is None: + known_funcs = set([]) + if not isinstance(known_vars, set): + known_vars = set(known_vars) + if not isinstance(known_funcs, set): + known_funcs = set(known_funcs) + + code = deindent(code, docstring=True) + parsed_code = ast.parse(code, mode='exec') + + # Get the list of all variables that are read from and written to, + # ignoring the order + allids, read, write, funcs = get_read_write_funcs(parsed_code) + + # Now check if there are any values that are unknown and read before + # they are written to + defined = known_vars.copy() + newly_defined = set([]) + undefined_read = set([]) + for line in parsed_code.body: + _, cur_read, cur_write, _ = get_read_write_funcs(line) + undef = cur_read-defined + undefined_read |= undef + newly_defined |= (cur_write-defined)-undefined_read + defined |= cur_write + + # Return the results as a named tuple + results = dict( + all=allids, + read=read, + write=write, + funcs=funcs, + known_all=allids.intersection(known_vars.union(known_funcs)), + known_read=read.intersection(known_vars), + known_write=write.intersection(known_vars), + known_funcs=funcs.intersection(known_funcs), + unknown_read=read-known_vars, + unknown_write=write-known_vars, + unknown_funcs=funcs-known_funcs, + undefined_read=undefined_read, + newly_defined=newly_defined, + ) + return namedtuple('AbstractCodeDependencies', results.keys())(**results) + + +if __name__=='__main__': + code = ''' + x = y+z + a = f(b) + ''' + known_vars = set(['y', 'z']) + print deindent(code) + print 'known_vars:', known_vars + print + r = abstract_code_dependencies(code, known_vars) + for k, v in r.__dict__.items(): + print k+':', ', '.join(list(v)) diff --git a/brian2/parsing/expressions.py b/brian2/parsing/expressions.py new file mode 100644 index 000000000..9d1f7dddc --- /dev/null +++ b/brian2/parsing/expressions.py @@ -0,0 +1,235 @@ +''' +AST parsing based analysis of expressions +''' + +import ast + +from brian2.units.fundamentalunits import (get_unit_fast, + DimensionMismatchError, + have_same_dimensions, + ) +from brian2.units import allunits +from brian2.units import stdunits + +__all__ = ['is_boolean_expression', + 'parse_expression_unit',] + +def is_boolean_expression(expr, boolvars=None, boolfuncs=None): + ''' + Determines if an expression is of boolean type or not + + Parameters + ---------- + + expr : str + The expression to test + boolvars : set + The set of variables of boolean type. + boolfuncs : set + The set of functions which return booleans. + + Returns + ------- + + isbool : bool + Whether or not the expression is boolean. + + Raises + ------ + + SyntaxError + If the expression ought to be boolean but is not, + for example ``x0: + subid = subid[:p] + i = int(subid) + alli.append(i) + if len(alli)==0: + i = 0 + else: + i = max(alli)+1 + funcstarts[func.name] = i + + # Now we rewrite all the lines, replacing each line with a sequence of + # lines performing the inlining + newlines = [] + for line in lines: + for func in funcs.values(): + rw = FunctionRewriter(func, funcstarts[func.name]) + line = rw.visit(line) + newlines.extend(rw.pre) + funcstarts[func.name] = rw.numcalls + newlines.append(line) + + # Now we render to a code string + nr = NodeRenderer() + newcode = '\n'.join(nr.render_node(line) for line in newlines) + + # We recurse until no changes in the code to ensure that all functions + # are expanded if one function refers to another, etc. + if newcode==code: + return newcode + else: + return substitute_abstract_code_functions(newcode, funcs) + + +if __name__=='__main__': + if 1: + def f(x): + y = x*x + return y + def g(x): + return f(x)+1 + code = ''' + z = f(x) + z = f(x)+f(y) + w = f(z) + h = f(f(w)) + p = g(g(x)) + ''' + funcs = [abstract_code_from_function(f), + abstract_code_from_function(g), + ] + print substitute_abstract_code_functions(code, funcs) + if 0: + code = ''' + def f(x): + return x*x + def g(V): + V += 1 + ''' + funcs = extract_abstract_code_functions(code) + for k, v in funcs.items(): + print v + if 0: + def f(V, w): + V = w + V += x + x = y*z + return x+y + print abstract_code_from_function(f) + \ No newline at end of file diff --git a/brian2/codegen/ast_parser.py b/brian2/parsing/rendering.py similarity index 89% rename from brian2/codegen/ast_parser.py rename to brian2/parsing/rendering.py index f568e376c..44d876a3d 100644 --- a/brian2/codegen/ast_parser.py +++ b/brian2/parsing/rendering.py @@ -2,7 +2,7 @@ import sympy -from .functions.numpyfunctions import DEFAULT_FUNCTIONS +from brian2.codegen.functions.numpyfunctions import DEFAULT_FUNCTIONS __all__ = ['NodeRenderer', 'NumpyNodeRenderer', @@ -20,8 +20,6 @@ class NodeRenderer(object): 'Div': '/', 'Pow': '**', 'Mod': '%', - 'BitAnd': 'and', - 'BitOr': 'or', # Compare 'Lt': '<', 'LtE': '<=', @@ -31,25 +29,27 @@ class NodeRenderer(object): 'NotEq': '!=', # Unary ops 'Not': 'not', - 'Invert': '~', 'UAdd': '+', 'USub': '-', # Bool ops 'And': 'and', 'Or': 'or', + # Augmented assign + 'AugAdd': '+=', + 'AugSub': '-=', + 'AugMult': '*=', + 'AugDiv': '/=', + 'AugPow': '**=', + 'AugMod': '%=', } def render_expr(self, expr, strip=True): if strip: expr = expr.strip() - expr = expr.replace('&', ' and ') - expr = expr.replace('|', ' or ') node = ast.parse(expr, mode='eval') return self.render_node(node.body) def render_code(self, code): - code = code.replace('&', ' and ') - code = code.replace('|', ' or ') lines = [] for node in ast.parse(code).body: lines.append(self.render_node(node)) @@ -131,17 +131,19 @@ def render_Assign(self, node): raise SyntaxError("Only support syntax like a=b not a=b=c") return '%s = %s' % (self.render_node(node.targets[0]), self.render_node(node.value)) + + def render_AugAssign(self, node): + target = node.target.id + rhs = self.render_node(node.value) + op = self.expression_ops['Aug'+node.op.__class__.__name__] + return '%s %s %s' % (target, op, rhs) class NumpyNodeRenderer(NodeRenderer): expression_ops = NodeRenderer.expression_ops.copy() expression_ops.update({ - # BinOps - 'BitAnd': '*', - 'BitOr': '+', # Unary ops 'Not': 'logical_not', - 'Invert': 'logical_not', # Bool ops 'And': '*', 'Or': '+', @@ -151,15 +153,11 @@ class NumpyNodeRenderer(NodeRenderer): class SympyNodeRenderer(NodeRenderer): expression_ops = NodeRenderer.expression_ops.copy() expression_ops.update({ - # BinOps - 'BitAnd': '&', - 'BitOr': '|', # Compare 'Eq': 'Eq', 'NotEq': 'Ne', # Unary ops 'Not': '~', - 'Invert': '~', # Bool ops 'And': '&', 'Or': '|', @@ -194,12 +192,8 @@ def render_Num(self, node): class CPPNodeRenderer(NodeRenderer): expression_ops = NodeRenderer.expression_ops.copy() expression_ops.update({ - # BinOps - 'BitAnd': '&&', - 'BitOr': '||', # Unary ops 'Not': '!', - 'Invert': '!', # Bool ops 'And': '&&', 'Or': '||', diff --git a/brian2/tests/test_parsing.py b/brian2/tests/test_parsing.py new file mode 100644 index 000000000..b5134162e --- /dev/null +++ b/brian2/tests/test_parsing.py @@ -0,0 +1,276 @@ +''' +Tests the brian2.parsing package +''' +from brian2.utils.stringtools import get_identifiers, deindent +from brian2.parsing.rendering import (NodeRenderer, NumpyNodeRenderer, + CPPNodeRenderer, + ) +from brian2.parsing.dependencies import abstract_code_dependencies +from brian2.parsing.expressions import (is_boolean_expression, + parse_expression_unit) +from brian2.parsing.functions import (abstract_code_from_function, + extract_abstract_code_functions, + substitute_abstract_code_functions) +from brian2.units import volt, amp, DimensionMismatchError, have_same_dimensions + +from numpy.testing import assert_allclose, assert_raises + +import numpy as np +from brian2.codegen.parsing import str_to_sympy, sympy_to_str + +try: + from scipy import weave +except ImportError: + weave = None +import nose + +# TODO: add some tests with e.g. 1.0%2.0 etc. once this is implemented in C++ +TEST_EXPRESSIONS = ''' + a+b+c*d+e-f+g-(b+d)-(a-c) + a**b**2 + a**(b**2) + (a**b)**2 + a*(b+c*(a+b)*(a-(c*d))) + a/b/c-a/(b/c) + ab + a>=b + a==b + a!=b + a+1 + 1+a + 1+3 + a>0.5 and b>0.5 + a>0.5 and b>0.5 or c>0.5 + a>0.5 and b>0.5 or not c>0.5 + 2%4 + ''' + + +def parse_expressions(renderer, evaluator, numvalues=10): + exprs = [([m for m in get_identifiers(l) if len(m)==1], [], l.strip()) + for l in TEST_EXPRESSIONS.split('\n') if l.strip()] + i, imod = 1, 33 + for varids, funcids, expr in exprs: + pexpr = renderer.render_expr(expr) + n = 0 + for _ in xrange(numvalues): + # assign some random values + ns = {} + for v in varids: + ns[v] = float(i)/imod + i = i%imod+1 + r1 = eval(expr.replace('&', ' and ').replace('|', ' or '), ns) + n += 1 + r2 = evaluator(pexpr, ns) + try: + # Use all close because we can introduce small numerical + # difference through sympy's rearrangements + assert_allclose(r1, r2) + except AssertionError as e: + raise AssertionError("In expression " + str(expr) + + " translated to " + str(pexpr) + + " " + str(e)) + + +def numpy_evaluator(expr, userns): + ns = {} + #exec 'from numpy import logical_not' in ns + ns['logical_not'] = np.logical_not + ns.update(**userns) + for k in userns.keys(): + if not k.startswith('_'): + ns[k] = np.array([userns[k]]) + try: + x = eval(expr, ns) + except Exception as e: + raise ValueError("Could not evaluate numpy expression "+expr+" exception "+str(e)) + if isinstance(x, np.ndarray): + return x[0] + else: + return x + + +def cpp_evaluator(expr, ns): + if weave is not None: + return weave.inline('return_val = %s;' % expr, ns.keys(), local_dict=ns, + compiler='gcc') + else: + raise nose.SkipTest('No weave support.') + + +def test_parse_expressions_python(): + parse_expressions(NodeRenderer(), eval) + + +def test_parse_expressions_numpy(): + parse_expressions(NumpyNodeRenderer(), numpy_evaluator) + + +def test_parse_expressions_cpp(): + parse_expressions(CPPNodeRenderer(), cpp_evaluator) + + +def test_parse_expressions_sympy(): + # sympy is about symbolic calculation, the string returned by the renderer + # contains "Symbol('a')" etc. so we cannot simply evaluate it in a + # namespace. + # We therefore use a different approach: Convert the expression to a + # sympy expression via str_to_sympy (uses the SympyNodeRenderer internally), + # then convert it back to a string via sympy_to_str and evaluate it + + class SympyRenderer(object): + def render_expr(self, expr): + return str_to_sympy(expr) + + def evaluator(expr, ns): + expr = sympy_to_str(expr) + return eval(expr, ns) + + parse_expressions(SympyRenderer(), evaluator) + + +def test_abstract_code_dependencies(): + code = ''' + a = b+c + d = b+c + a = func_a() + a = func_b() + a = e+d + ''' + known_vars = set(['a', 'b', 'c']) + known_funcs = set(['func_a']) + res = abstract_code_dependencies(code, known_vars, known_funcs) + expected_res = dict( + all=['a', 'b', 'c', 'd', 'e', + 'func_a', 'func_b', + ], + read=['b', 'c', 'd', 'e'], + write=['a', 'd'], + funcs=['func_a', 'func_b'], + known_all=['a', 'b', 'c', 'func_a'], + known_read=['b', 'c'], + known_write=['a'], + known_funcs=['func_a'], + unknown_read=['d', 'e'], + unknown_write=['d'], + unknown_funcs=['func_b'], + undefined_read=['e'], + newly_defined=['d'], + ) + for k, v in expected_res.items(): + if not getattr(res, k)==set(v): + raise AssertionError("For '%s' result is %s expected %s" % ( + k, getattr(res, k), set(v))) + + +def test_is_boolean_expression(): + EVF = [ + (True, 'a or b', ['a', 'b'], []), + (True, 'True', [], []), + (True, 'ab - a>=b - a==b - a!=b - a+1 - 1+a - 1+3 - a>0.5 and b>0.5 - a>0.5&b>0.5&c>0.5 - (a>0.5) & (b>0.5) & (c>0.5) - a>0.5 and b>0.5 or c>0.5 - a>0.5 and b>0.5 or not c>0.5 - 2%4 - ''' - - -def parse_expressions(renderer, evaluator, numvalues=10): - exprs = [([m for m in get_identifiers(l) if len(m)==1], [], l.strip()) - for l in TEST_EXPRESSIONS.split('\n') if l.strip()] - i, imod = 1, 33 - for varids, funcids, expr in exprs: - pexpr = renderer.render_expr(expr) - n = 0 - for _ in xrange(numvalues): - # assign some random values - ns = {} - for v in varids: - ns[v] = float(i)/imod - i = i%imod+1 - r1 = eval(expr.replace('&', ' and ').replace('|', ' or '), ns) - n += 1 - r2 = evaluator(pexpr, ns) - try: - # Use all close because we can introduce small numerical - # difference through sympy's rearrangements - assert_allclose(r1, r2) - except AssertionError as e: - raise AssertionError("In expression " + str(expr) + - " translated to " + str(pexpr) + - " " + str(e)) - - -def numpy_evaluator(expr, userns): - ns = {} - #exec 'from numpy import logical_not' in ns - ns['logical_not'] = np.logical_not - ns.update(**userns) - for k in userns.keys(): - if not k.startswith('_'): - ns[k] = np.array([userns[k]]) - try: - x = eval(expr, ns) - except Exception as e: - raise ValueError("Could not evaluate numpy expression "+expr+" exception "+str(e)) - if isinstance(x, np.ndarray): - return x[0] - else: - return x - - -def cpp_evaluator(expr, ns): - if weave is not None: - return weave.inline('return_val = %s;' % expr, ns.keys(), local_dict=ns, - compiler='gcc') - else: - raise nose.SkipTest('No weave support.') - - -def test_parse_expressions_python(): - parse_expressions(NodeRenderer(), eval) - - -def test_parse_expressions_numpy(): - parse_expressions(NumpyNodeRenderer(), numpy_evaluator) - - -def test_parse_expressions_cpp(): - parse_expressions(CPPNodeRenderer(), cpp_evaluator) - - -def test_parse_expressions_sympy(): - # sympy is about symbolic calculation, the string returned by the renderer - # contains "Symbol('a')" etc. so we cannot simply evaluate it in a - # namespace. - # We therefore use a different approach: Convert the expression to a - # sympy expression via str_to_sympy (uses the SympyNodeRenderer internally), - # then convert it back to a string via sympy_to_str and evaluate it - - class SympyRenderer(object): - def render_expr(self, expr): - return str_to_sympy(expr) - - def evaluator(expr, ns): - expr = sympy_to_str(expr) - return eval(expr, ns) - - parse_expressions(SympyRenderer(), evaluator) - - -if __name__=='__main__': - test_parse_expressions_python() - test_parse_expressions_numpy() - test_parse_expressions_cpp() - test_parse_expressions_sympy() diff --git a/setup.py b/setup.py index 0e95f50dc..912ec4ef5 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ def run(self): 'brian2.groups', 'brian2.memory', 'brian2.monitors', + 'brian2.parsing', 'brian2.sphinxext', 'brian2.stateupdaters', 'brian2.tests',