Skip to content

Commit

Permalink
Merge pull request #72 from brian-team/ast_improvements
Browse files Browse the repository at this point in the history
Improvements to the AST-based parsing, now a parsing-based approach to unit checking (instead of an eval-based) is used.
  • Loading branch information
Marcel Stimberg committed Jul 8, 2013
2 parents d3d51bd + 1255eaa commit a2fab76
Show file tree
Hide file tree
Showing 18 changed files with 1,221 additions and 248 deletions.
14 changes: 13 additions & 1 deletion brian2/codegen/functions/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
__all__ = ['Function', 'SimpleFunction', 'make_function']


class Function(object):
def __init__(self, pyfunc, sympy_func=None):
def __init__(self, pyfunc, sympy_func=None, arg_units=None,
return_unit=None):
self.pyfunc = pyfunc
self.sympy_func = sympy_func
if hasattr(pyfunc, '_arg_units'):
self._arg_units = pyfunc._arg_units
self._return_unit = pyfunc._return_unit
else:
if arg_units is None or return_unit is None:
raise ValueError(('The given Python function does not specify '
'how it deals with units, need to specify '
'"arg_units" and "return_unit"'))
self._arg_units = arg_units
self._return_unit = return_unit

'''
User-defined function to work with code generation
Expand Down
46 changes: 31 additions & 15 deletions brian2/codegen/functions/numpyfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class RandnFunction(Function):
The number of random numbers generated at a time.
'''
def __init__(self, N):
Function.__init__(self, pyfunc=randn)
Function.__init__(self, pyfunc=randn, arg_units=[], return_unit=1)
self.N = int(N)

def __call__(self):
Expand Down Expand Up @@ -89,7 +89,7 @@ class RandFunction(Function):
The number of random numbers generated at a time.
'''
def __init__(self, N):
Function.__init__(self, pyfunc=rand)
Function.__init__(self, pyfunc=rand, arg_units=[], return_unit=1)
self.N = int(N)

def __call__(self):
Expand Down Expand Up @@ -118,7 +118,7 @@ class BoolFunction(Function):
function to make sure it is interpreted correctly.
'''
def __init__(self):
Function.__init__(self, pyfunc=np.bool_)
Function.__init__(self, pyfunc=np.bool_, arg_units=[1], return_unit=1)

def __call__(self, value):
return np.bool_(value)
Expand All @@ -138,6 +138,7 @@ def code_cpp(self, language, var):
def on_compile_cpp(self, namespace, language, var):
pass


class FunctionWrapper(Function):
'''
Simple wrapper for functions that have exist both in numpy and C++
Expand All @@ -154,11 +155,24 @@ class FunctionWrapper(Function):
cpp_name : str, optional
The name of the corresponding function in C++, in case it is different.
sympy_func : sympy function, optional
The corresponding sympy function, if it exists.
The corresponding sympy function, if it exists.
arg_units : list of `Unit`, optional
The expected units of the arguments, ``None`` for arguments that can
have arbitrary units. Needs only to be specified if the `pyfunc`
function does not specify this already (e.g. via a `check_units`
decorator)
return_unit : `Unit` or callable, optional
The unit of the return value of this function. Either a fixed `Unit`,
or a function of the units of the arguments, e.g.
``lambda u : u **0.5`` for a square root function. Needs only to be
specified if the `pyfunc` function does not specify this already (e.g.
via a `check_units` decorator)
'''
# TODO: How to make this easily extendable for other languages?
def __init__(self, pyfunc, py_name=None, cpp_name=None, sympy_func=None):
Function.__init__(self, pyfunc, sympy_func)
def __init__(self, pyfunc, py_name=None, cpp_name=None, sympy_func=None,
arg_units=None, return_unit=None):
Function.__init__(self, pyfunc, sympy_func, arg_units=arg_units,
return_unit=return_unit)
if py_name is None:
py_name = pyfunc.__name__
self.py_name = py_name
Expand Down Expand Up @@ -206,11 +220,14 @@ def _get_default_functions():
'log10': FunctionWrapper(unitsafe.log10,
sympy_func=log10),
'sqrt': FunctionWrapper(np.sqrt,
sympy_func=sympy.functions.elementary.miscellaneous.sqrt),
sympy_func=sympy.functions.elementary.miscellaneous.sqrt,
arg_units=[None], return_unit=lambda u: u**0.5),
'ceil': FunctionWrapper(np.ceil,
sympy_func=sympy.functions.elementary.integers.ceiling),
sympy_func=sympy.functions.elementary.integers.ceiling,
arg_units=[None], return_unit=lambda u: u),
'floor': FunctionWrapper(np.floor,
sympy_func=sympy.functions.elementary.integers.floor),
sympy_func=sympy.functions.elementary.integers.floor,
arg_units=[None], return_unit=lambda u: u),
# numpy functions that have a different name in numpy and math.h
'arccos': FunctionWrapper(unitsafe.arccos,
cpp_name='acos',
Expand All @@ -220,16 +237,15 @@ def _get_default_functions():
sympy_func=sympy.functions.elementary.trigonometric.asin),
'arctan': FunctionWrapper(unitsafe.arctan,
cpp_name='atan',
sympy_func=sympy.functions.elementary.trigonometric.atan),
'power': FunctionWrapper(np.power,
cpp_name='pow',
sympy_func=sympy_power.Pow),
sympy_func=sympy.functions.elementary.trigonometric.atan),
'abs': FunctionWrapper(np.abs, py_name='abs',
cpp_name='fabs',
sympy_func=sympy.functions.elementary.complexes.Abs),
sympy_func=sympy.functions.elementary.complexes.Abs,
arg_units=[None], return_unit=lambda u: u),
'mod': FunctionWrapper(np.mod, py_name='mod',
cpp_name='fmod',
sympy_func=sympy_mod.Mod),
sympy_func=sympy_mod.Mod,
arg_units=[None, None], return_unit=lambda u,v : u),
'bool': BoolFunction()
}

Expand Down
2 changes: 1 addition & 1 deletion brian2/codegen/languages/cpp/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ..base import Language, CodeObject
from ..templates import LanguageTemplater
from ...ast_parser import CPPNodeRenderer
from brian2.parsing.rendering import CPPNodeRenderer

logger = get_logger(__name__)
try:
Expand Down
2 changes: 1 addition & 1 deletion brian2/codegen/languages/python/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ..base import Language, CodeObject
from ..templates import LanguageTemplater
from ...ast_parser import NumpyNodeRenderer
from brian2.parsing.rendering import NumpyNodeRenderer

__all__ = ['PythonLanguage', 'PythonCodeObject']

Expand Down
2 changes: 1 addition & 1 deletion brian2/codegen/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sympy.printing.str import StrPrinter

from .functions.numpyfunctions import DEFAULT_FUNCTIONS, log10
from .ast_parser import SympyNodeRenderer
from brian2.parsing.rendering import SympyNodeRenderer


def parse_statement(code):
Expand Down
4 changes: 2 additions & 2 deletions brian2/equations/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,10 +723,10 @@ def check_units(self, namespace, specifiers, additional_namespace=None):
continue

if eq.type == DIFFERENTIAL_EQUATION:
check_unit(eq.expr, self.units[var] / second,
check_unit(str(eq.expr), self.units[var] / second,
resolved_namespace, specifiers)
elif eq.type == STATIC_EQUATION:
check_unit(eq.expr, self.units[var],
check_unit(str(eq.expr), self.units[var],
resolved_namespace, specifiers)
else:
raise AssertionError('Unknown equation type: "%s"' % eq.type)
Expand Down
67 changes: 4 additions & 63 deletions brian2/equations/unitcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
sievert, katal, kgram, kgramme)

from brian2.codegen.translation import analyse_identifiers
from brian2.utils.stringtools import get_identifiers
from brian2.parsing.expressions import parse_expression_unit
from brian2.codegen.parsing import parse_statement
from brian2.core.specifiers import VariableSpecifier

Expand Down Expand Up @@ -84,64 +84,6 @@ def unit_from_string(unit_string):
return evaluated_unit


def unit_from_expression(expression, namespace, specifiers):
'''
Evaluates the unit for an expression in a given namespace.
Parameters
----------
expression : str or `Expression`
The expression to evaluate.
namespace : dict-like
The namespace of external variables.
specifiers : dict of `Specifier` objects
The information about the internal variables
Returns
-------
q : Quantity
The quantity or unit for the expression
Raises
------
KeyError
In case on of the identifiers cannot be resolved.
DimensionMismatchError
If an unit mismatch occurs during the evaluation.
'''

# Make sure we have a string expression
expression = str(expression)

# Create a mapping with all identifier names to either their actual
# value (for external identifiers) or their unit (for specifiers)
unit_namespace = {}
identifiers = get_identifiers(expression)
for name in identifiers:
if name in specifiers:
unit_namespace[name] = specifiers[name].unit
else:
# TODO: Should we add special support for user-defined functions
# or should we just assume that the Python version takes
# care of units?
# This raises an error if the identifier cannot be resolved
unit_namespace[name] = namespace[name]

try:
unit = eval(expression, unit_namespace)
except DimensionMismatchError as dim_ex:
raise DimensionMismatchError(('A unit mismatch occured while '
'evaluating the expression "%s": %s ' %
(expression, dim_ex.desc)),
*dim_ex.dims)

if not isinstance(unit, Quantity):
# It might be a dimensionless array
unit = Quantity(unit)

return unit


def check_unit(expression, unit, namespace, specifiers):
'''
Evaluates the unit for an expression in a given namespace.
Expand All @@ -166,8 +108,7 @@ def check_unit(expression, unit, namespace, specifiers):
--------
unit_from_expression
'''
expr_unit = unit_from_expression(expression, namespace, specifiers)

expr_unit = parse_expression_unit(expression, namespace, specifiers)
fail_for_dimension_mismatch(expr_unit, unit, ('Expression %s does not '
'have the expected units' %
expression))
Expand Down Expand Up @@ -225,8 +166,8 @@ def check_units_statements(code, namespace, specifiers):
else:
raise AssertionError('Unknown operator "%s"' % op)

expr_unit = unit_from_expression(expr, namespace, specs)
expr_unit = parse_expression_unit(expr, namespace, specs)

if var in specifiers:
fail_for_dimension_mismatch(specifiers[var].unit,
expr_unit,
Expand Down
4 changes: 2 additions & 2 deletions brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from brian2.equations.equations import (Equations, DIFFERENTIAL_EQUATION,
STATIC_EQUATION, PARAMETER)
from brian2.equations.refractory import add_refractoriness
from brian2.equations.unitcheck import unit_from_expression
from brian2.stateupdaters.base import StateUpdateMethod
from brian2.codegen.languages import PythonLanguage
from brian2.memory import allocate_array
Expand All @@ -19,6 +18,7 @@
StochasticVariable, Subexpression)
from brian2.core.spikesource import SpikeSource
from brian2.core.scheduler import Scheduler
from brian2.parsing.expressions import parse_expression_unit
from brian2.utils.logger import get_logger
from brian2.units.allunits import second
from brian2.units.fundamentalunits import (Quantity, Unit, have_same_dimensions,
Expand Down Expand Up @@ -68,7 +68,7 @@ def update_abstract_code(self, additional_namespace):
namespace = dict(self.group.namespace)
if additional_namespace is not None:
namespace.update(additional_namespace[1])
unit = unit_from_expression(ref, namespace, self.group.specifiers)
unit = parse_expression_unit(ref, namespace, self.group.specifiers)
if have_same_dimensions(unit, second):
self.abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref
elif have_same_dimensions(unit, Unit(1)):
Expand Down
Empty file added brian2/parsing/__init__.py
Empty file.
Loading

0 comments on commit a2fab76

Please sign in to comment.