From 7dee01fc2b1cf8c27b3fb71ed7edf357563c9f43 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Mon, 8 Jul 2013 19:31:43 +0200 Subject: [PATCH 1/5] Start implementing support for boolean variables --- brian2/core/specifiers.py | 74 ++++++++++++++++++++++++++--------- brian2/equations/unitcheck.py | 6 +-- brian2/groups/group.py | 6 ++- brian2/groups/neurongroup.py | 3 +- brian2/parsing/expressions.py | 1 - 5 files changed, 66 insertions(+), 24 deletions(-) diff --git a/brian2/core/specifiers.py b/brian2/core/specifiers.py index ff006f4c1..1921ad5e9 100644 --- a/brian2/core/specifiers.py +++ b/brian2/core/specifiers.py @@ -3,11 +3,12 @@ TODO: have a single global dtype rather than specify for each variable? ''' +import numpy as np from brian2.units.allunits import second from brian2.utils.stringtools import get_identifiers -from brian2.units.fundamentalunits import is_scalar_type +from brian2.units.fundamentalunits import is_scalar_type, have_same_dimensions __all__ = ['Specifier', 'VariableSpecifier', @@ -63,16 +64,26 @@ class VariableSpecifier(Specifier): constant: bool, optional Whether the value of this variable can change during a run. Defaults to ``False``. + is_bool: bool, optional + Whether this is a boolean variable (also implies it is dimensionless). + Defaults to ``False`` See Also -------- Value ''' - def __init__(self, name, unit, scalar=True, constant=False): + def __init__(self, name, unit, scalar=True, constant=False, is_bool=False): Specifier.__init__(self, name) #: The variable's unit. self.unit = unit + #: Whether this is a boolean variable + self.is_bool = is_bool + + if is_bool: + if not have_same_dimensions(unit, 1): + raise ValueError('Boolean variables can only be dimensionless') + #: Whether the value is a scalar self.scalar = scalar @@ -111,10 +122,15 @@ class Value(VariableSpecifier): defined for every neuron (``False``). Defaults to ``True``. constant: bool, optional Whether the value of this variable can change during a run. Defaults - to ``False``. + to ``False``. + is_bool: bool, optional + Whether this is a boolean variable (also implies it is dimensionless). + Defaults to ``False`` + ''' - def __init__(self, name, unit, dtype, scalar=True, constant=False): - VariableSpecifier.__init__(self, name, unit, scalar, constant) + def __init__(self, name, unit, dtype, scalar=True, constant=False, + is_bool=False): + VariableSpecifier.__init__(self, name, unit, scalar, constant, is_bool) #: The dtype used for storing the variable. self.dtype = dtype @@ -146,6 +162,7 @@ def __repr__(self): scalar=repr(self.scalar), constant=repr(self.constant)) + ############################################################################### # Concrete classes that are used as specifiers in practice. ############################################################################### @@ -177,8 +194,11 @@ def __init__(self, name, unit, dtype, value): self.value = value scalar = is_scalar_type(value) - - Value.__init__(self, name, unit, dtype, scalar, constant=True) + + is_bool = value is True or value is False + + Value.__init__(self, name, unit, dtype, scalar, constant=True, + is_bool=is_bool) def get_value(self): return self.value @@ -235,28 +255,33 @@ class AttributeValue(ReadOnlyValue): to be an attribute of `obj`. constant : bool, optional Whether the attribute's value is constant during a run. - + is_bool: bool, optional + Whether this is a boolean variable (also implies it is dimensionless). + Defaults to ``False`` Raises ------ AttributeError If `obj` does not have an attribute `attribute`. ''' - def __init__(self, name, unit, dtype, obj, attribute, constant=False): + def __init__(self, name, unit, dtype, obj, attribute, constant=False, + is_bool=False): if not hasattr(obj, attribute): raise AttributeError(('Object %r does not have an attribute %r, ' 'providing the value for %r') % (obj, attribute, name)) + + value = getattr(obj, attribute) + scalar = is_scalar_type(value) + + is_bool = value is True or value is False - scalar = is_scalar_type(getattr(obj, attribute)) - - Value.__init__(self, name, unit, dtype, scalar, constant) + Value.__init__(self, name, unit, dtype, scalar, constant, is_bool) #: A reference to the object storing the variable's value self.obj = obj #: The name of the attribute storing the variable's value self.attribute = attribute - def get_value(self): return getattr(self.obj, self.attribute) @@ -303,9 +328,18 @@ class ArrayVariable(Value): variable. constant : bool, optional Whether the variable's value is constant during a run. + is_bool: bool, optional + Whether this is a boolean variable (also implies it is dimensionless). + Defaults to ``False`` ''' - def __init__(self, name, unit, dtype, array, index, constant=False): - Value.__init__(self, name, unit, dtype, scalar=False, constant=constant) + def __init__(self, name, unit, dtype, array, index, constant=False, + is_bool=False): + if is_bool: + if not dtype == np.bool: + raise ValueError(('Boolean variables have to be stored with ' + 'boolean dtype')) + Value.__init__(self, name, unit, dtype, scalar=False, + constant=constant, is_bool=is_bool) #: The reference to the array storing the data for the variable. self.array = array #: The name for the array used in generated code @@ -353,9 +387,13 @@ class Subexpression(Value): namespace : dict The namespace dictionary, containing identifiers for all the external variables/functions used in the expression + is_bool: bool, optional + Whether this is a boolean variable (also implies it is dimensionless). + Defaults to ``False`` ''' - def __init__(self, name, unit, dtype, expr, specifiers, namespace): - Value.__init__(self, name, unit, dtype, scalar=False) + def __init__(self, name, unit, dtype, expr, specifiers, namespace, + is_bool=False): + Value.__init__(self, name, unit, dtype, scalar=False, is_bool=is_bool) #: The expression defining the static equation. self.expr = expr.strip() #: The identifiers used in the expression @@ -424,4 +462,4 @@ def __init__(self, name, iterate_all=True): def __repr__(self): return '%s(name=%r, iterate_all=%r)' % (self.__class__.__name__, self.name, - self.iterate_all) + self.iterate_all) diff --git a/brian2/equations/unitcheck.py b/brian2/equations/unitcheck.py index 0c2881fd9..a566d3beb 100644 --- a/brian2/equations/unitcheck.py +++ b/brian2/equations/unitcheck.py @@ -26,7 +26,7 @@ def unit_from_string(unit_string): ''' Returns the unit that results from evaluating a string like "siemens / metre ** 2", allowing for the special string "1" to signify - dimensionless units. + dimensionless units and the string "bool" to mark a boolean variable. Parameters ---------- @@ -35,8 +35,8 @@ def unit_from_string(unit_string): Returns ------- - u : Unit - The resulting unit + u : Unit or bool + The resulting unit or ``True`` for a boolean parameter. Raises ------ diff --git a/brian2/groups/group.py b/brian2/groups/group.py index 3d3604764..1f39dfea2 100644 --- a/brian2/groups/group.py +++ b/brian2/groups/group.py @@ -78,7 +78,11 @@ def __setattr__(self, name, val): object.__setattr__(self, name, val) elif name in self.specifiers: spec = self.specifiers[name] - fail_for_dimension_mismatch(val, spec.unit, + if spec.is_bool: + unit = Unit(1) + else: + unit = spec.unit + fail_for_dimension_mismatch(val, unit, 'Incorrect units for setting %s' % name) spec.set_value(val) elif len(name) and name[-1]=='_' and name[:-1] in self.specifiers: diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index b9ef01000..66710a15c 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -68,7 +68,8 @@ def update_abstract_code(self, additional_namespace): namespace = dict(self.group.namespace) if additional_namespace is not None: namespace.update(additional_namespace[1]) - unit = parse_expression_unit(ref, namespace, self.group.specifiers) + unit = parse_expression_unit(str(ref), namespace, + self.group.specifiers) if have_same_dimensions(unit, second): self.abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref elif have_same_dimensions(unit, Unit(1)): diff --git a/brian2/parsing/expressions.py b/brian2/parsing/expressions.py index 03743e14f..a57cd7886 100644 --- a/brian2/parsing/expressions.py +++ b/brian2/parsing/expressions.py @@ -211,7 +211,6 @@ def parse_expression_unit(expr, namespace, specifiers): if isinstance(expr, basestring): mod = ast.parse(expr, mode='eval') expr = mod.body - if expr.__class__ is ast.Name: name = expr.id if name in specifiers: From f9af0ab9eec9de3eb488720a7fffe7dac10d6ff4 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 9 Jul 2013 12:32:32 +0200 Subject: [PATCH 2/5] Finalize support for boolean variables, remove the bool function. --- brian2/codegen/functions/numpyfunctions.py | 31 +--------------------- brian2/equations/equations.py | 20 ++++++++++---- brian2/equations/refractory.py | 5 ++-- brian2/groups/group.py | 6 +---- brian2/groups/neurongroup.py | 24 ++++++++++------- brian2/tests/test_equations.py | 4 +-- 6 files changed, 37 insertions(+), 53 deletions(-) diff --git a/brian2/codegen/functions/numpyfunctions.py b/brian2/codegen/functions/numpyfunctions.py index 77f495844..039ba0108 100644 --- a/brian2/codegen/functions/numpyfunctions.py +++ b/brian2/codegen/functions/numpyfunctions.py @@ -111,34 +111,6 @@ def on_compile_cpp(self, namespace, language, var): pass -class BoolFunction(Function): - ''' A specifier for the `bool` function. To make sure that they are - interpreted as boolean values, references to state variables that are - meant as boolean (e.g. ``not_refractory``) should be wrapped in this - function to make sure it is interpreted correctly. - ''' - def __init__(self): - Function.__init__(self, pyfunc=np.bool_, arg_units=[1], return_unit=1) - - def __call__(self, value): - return np.bool_(value) - - def code_cpp(self, language, var): - - support_code = ''' - double _bool(float value) - { - return value == 0 ? false : true; - } - ''' - - return {'support_code': support_code, - 'hashdefine_code': ''} - - def on_compile_cpp(self, namespace, language, var): - pass - - class FunctionWrapper(Function): ''' Simple wrapper for functions that have exist both in numpy and C++ @@ -245,8 +217,7 @@ def _get_default_functions(): 'mod': FunctionWrapper(np.mod, py_name='mod', cpp_name='fmod', sympy_func=sympy_mod.Mod, - arg_units=[None, None], return_unit=lambda u,v : u), - 'bool': BoolFunction() + arg_units=[None, None], return_unit=lambda u,v : u) } return functions diff --git a/brian2/equations/equations.py b/brian2/equations/equations.py index 7b339b6d2..dc71fbc89 100644 --- a/brian2/equations/equations.py +++ b/brian2/equations/equations.py @@ -9,10 +9,9 @@ import sympy from pyparsing import (Group, ZeroOrMore, OneOrMore, Optional, Word, CharsNotIn, Combine, Suppress, restOfLine, LineEnd, ParseException) -import sympy from brian2.codegen.parsing import sympy_to_str, str_to_sympy -from brian2.units.fundamentalunits import DimensionMismatchError +from brian2.units.fundamentalunits import Unit, have_same_dimensions from brian2.units.allunits import second from brian2.utils.logger import get_logger @@ -189,7 +188,9 @@ def parse_string_equations(eqns): # Convert unit string to Unit object unit = unit_from_string(eq_content['unit']) - + is_bool = unit is True + if is_bool: + unit = Unit(1) expression = eq_content.get('expression', None) if not expression is None: # Replace multiple whitespaces (arising from joining multiline @@ -198,7 +199,8 @@ def parse_string_equations(eqns): expression = Expression(p.sub(' ', expression)) flags = list(eq_content.get('flags', [])) - equation = SingleEquation(eq_type, identifier, unit, expression, flags) + equation = SingleEquation(eq_type, identifier, unit, is_bool=is_bool, + expr=expression, flags=flags) if identifier in equations: raise EquationError('Duplicate definition of variable "%s"' % @@ -225,6 +227,9 @@ class SingleEquation(object): The variable that is defined by this equation. unit : Unit The unit of the variable + is_bool : bool, optional + Whether this variable is a boolean variable (implies it is + dimensionless as well). Defaults to ``False``. expr : `Expression`, optional The expression defining the variable (or ``None`` for parameters). flags: list of str, optional @@ -233,10 +238,14 @@ class SingleEquation(object): context. ''' - def __init__(self, type, varname, unit, expr=None, flags=None): + def __init__(self, type, varname, unit, is_bool=False, expr=None, + flags=None): self.type = type self.varname = varname self.unit = unit + self.is_bool = is_bool + if is_bool and not have_same_dimensions(unit, 1): + raise ValueError('Boolean variables are necessarily dimensionless.') self.expr = expr if flags is None: self.flags = [] @@ -314,6 +323,7 @@ def _repr_pretty_(self, p, cycle): def _repr_latex_(self): return '$' + sympy.latex(self) + '$' + class Equations(collections.Mapping): """ Container that stores equations from which models can be created. diff --git a/brian2/equations/refractory.py b/brian2/equations/refractory.py index 796ae9923..4b372113a 100644 --- a/brian2/equations/refractory.py +++ b/brian2/equations/refractory.py @@ -60,13 +60,14 @@ def add_refractoriness(eqs): new_code = 'not_refractory*(' + eq.expr.code + ')' new_equations.append(SingleEquation(DIFFERENTIAL_EQUATION, eq.varname, eq.unit, - Expression(new_code), + expr=Expression(new_code), flags=eq.flags)) else: new_equations.append(eq) # add new parameters - new_equations.append(SingleEquation(PARAMETER, 'not_refractory', Unit(1))) + new_equations.append(SingleEquation(PARAMETER, 'not_refractory', Unit(1), + is_bool=True)) new_equations.append(SingleEquation(PARAMETER, 'lastspike', second)) return Equations(new_equations) diff --git a/brian2/groups/group.py b/brian2/groups/group.py index 1f39dfea2..3d3604764 100644 --- a/brian2/groups/group.py +++ b/brian2/groups/group.py @@ -78,11 +78,7 @@ def __setattr__(self, name, val): object.__setattr__(self, name, val) elif name in self.specifiers: spec = self.specifiers[name] - if spec.is_bool: - unit = Unit(1) - else: - unit = spec.unit - fail_for_dimension_mismatch(val, unit, + fail_for_dimension_mismatch(val, spec.unit, 'Incorrect units for setting %s' % name) spec.set_value(val) elif len(name) and name[-1]=='_' and name[:-1] in self.specifiers: diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index 66710a15c..db38fe68e 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -77,7 +77,7 @@ def update_abstract_code(self, additional_namespace): # we have to be a bit careful here, we can't just use the given # condition as it is, because we only want to *leave* # refractoriness, based on the condition - self.abstract_code = 'not_refractory = bool(not_refractory) or not (%s)\n' % ref + self.abstract_code = 'not_refractory = not_refractory or not (%s)\n' % ref else: raise TypeError(('Refractory expression has to evaluate to a #' 'timespan or a boolean value, expression' @@ -291,22 +291,26 @@ def __len__(self): def _allocate_memory(self, dtype=None): # Allocate memory (TODO: this should be refactored somewhere at some point) - arrayvarnames = set(eq.varname for eq in self.equations.itervalues() if - eq.type in (DIFFERENTIAL_EQUATION, - PARAMETER)) + arrays = {} - for name in arrayvarnames: + for eq in self.equations.itervalues(): + if eq.type == STATIC_EQUATION: + # nothing to do + continue + name = eq.varname if isinstance(dtype, dict): curdtype = dtype[name] else: curdtype = dtype if curdtype is None: curdtype = brian_prefs['core.default_scalar_dtype'] - arrays[name] = allocate_array(self.N, dtype=curdtype) + if eq.is_bool: + arrays[name] = allocate_array(self.N, dtype=np.bool) + else: + arrays[name] = allocate_array(self.N, dtype=curdtype) logger.debug("NeuronGroup memory allocated successfully.") return arrays - def runner(self, code, when=None, name=None): ''' Returns a `CodeRunner` that runs abstract code in the groups namespace @@ -358,14 +362,16 @@ def _create_specifiers(self): array.dtype, array, '_neuron_idx', - constant)}) + constant, + eq.is_bool)}) elif eq.type == STATIC_EQUATION: s.update({eq.varname: Subexpression(eq.varname, eq.unit, brian_prefs['core.default_scalar_dtype'], str(eq.expr), s, - self.namespace)}) + self.namespace, + eq.is_bool)}) else: raise AssertionError('Unknown type of equation: ' + eq.eq_type) diff --git a/brian2/tests/test_equations.py b/brian2/tests/test_equations.py index b625f3f5e..bd2a63555 100644 --- a/brian2/tests/test_equations.py +++ b/brian2/tests/test_equations.py @@ -205,9 +205,9 @@ def test_construction_errors(): v = 2 * t/second * volt : volt''')) eqs = [SingleEquation(DIFFERENTIAL_EQUATION, 'v', volt, - Expression('-v / tau')), + expr=Expression('-v / tau')), SingleEquation(STATIC_EQUATION, 'v', volt, - Expression('2 * t/second * volt')) + expr=Expression('2 * t/second * volt')) ] assert_raises(EquationError, lambda: Equations(eqs)) From 316d8c95941735c0eda4140683efc0b430d0070d Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 9 Jul 2013 14:19:12 +0200 Subject: [PATCH 3/5] Use the is_boolean_expression to determine whether the refractory period is a correct boolean expression. Also change the is_boolean_expression function to the common namespace/specififers signature --- brian2/groups/neurongroup.py | 10 ++++-- brian2/parsing/expressions.py | 40 +++++++++++++--------- brian2/tests/test_parsing.py | 59 ++++++++++++++++++++++++++------- brian2/tests/test_refractory.py | 29 +++++++++++++--- 4 files changed, 104 insertions(+), 34 deletions(-) diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index db38fe68e..a6c6d7786 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -18,7 +18,7 @@ StochasticVariable, Subexpression) from brian2.core.spikesource import SpikeSource from brian2.core.scheduler import Scheduler -from brian2.parsing.expressions import parse_expression_unit +from brian2.parsing.expressions import parse_expression_unit, is_boolean_expression from brian2.utils.logger import get_logger from brian2.units.allunits import second from brian2.units.fundamentalunits import (Quantity, Unit, have_same_dimensions, @@ -73,13 +73,19 @@ def update_abstract_code(self, additional_namespace): if have_same_dimensions(unit, second): self.abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref elif have_same_dimensions(unit, Unit(1)): + if not is_boolean_expression(str(ref), namespace, + self.group.specifiers): + raise TypeError(('Refractory expression is dimensionless ' + 'but not a boolean value. It needs to ' + 'either evaluate to a timespan or to a ' + 'boolean value.')) # boolean condition # we have to be a bit careful here, we can't just use the given # condition as it is, because we only want to *leave* # refractoriness, based on the condition self.abstract_code = 'not_refractory = not_refractory or not (%s)\n' % ref else: - raise TypeError(('Refractory expression has to evaluate to a #' + raise TypeError(('Refractory expression has to evaluate to a ' 'timespan or a boolean value, expression' '"%s" has units %s instead') % (ref, unit)) diff --git a/brian2/parsing/expressions.py b/brian2/parsing/expressions.py index a57cd7886..5d3beebdb 100644 --- a/brian2/parsing/expressions.py +++ b/brian2/parsing/expressions.py @@ -12,7 +12,8 @@ __all__ = ['is_boolean_expression', 'parse_expression_unit',] -def is_boolean_expression(expr, boolvars=None, boolfuncs=None): + +def is_boolean_expression(expr, namespace, specifiers): ''' Determines if an expression is of boolean type or not @@ -21,10 +22,10 @@ def is_boolean_expression(expr, boolvars=None, boolfuncs=None): expr : str The expression to test - boolvars : set - The set of variables of boolean type. - boolfuncs : set - The set of functions which return booleans. + namespace : dict-like + The namespace of external variables. + specifiers : dict of `Specifier` objects + The information about the internal variables Returns ------- @@ -56,28 +57,35 @@ def is_boolean_expression(expr, boolvars=None, boolfuncs=None): ``not``, otherwise ``False``. * Otherwise we return ``False``. ''' - if boolfuncs is None: - boolfuncs = set([]) - if boolvars is None: - boolvars = set([]) - - boolvars.add('True') - boolvars.add('False') - + # If we are working on a string, convert to the top level node if isinstance(expr, str): mod = ast.parse(expr, mode='eval') expr = mod.body if expr.__class__ is ast.BoolOp: - if all(is_boolean_expression(node, boolvars, boolfuncs) for node in expr.values): + if all(is_boolean_expression(node, namespace, specifiers) + for node in expr.values): return True else: raise SyntaxError("Expression ought to be boolean but is not (e.g. 'x=b)', [], []), - (False, 'a+b', [], []), - (True, 'f(x)', [], ['f']), - (False, 'f(x)', [], []), - (True, 'f(x) or a c'), + (True, 'c > 5'), + (True, 'True'), + (True, 'a=b)'), + (False, 'a+b'), + (True, 'f(c)'), + (False, 'g(c)'), + (True, 'f(c) or a Date: Tue, 9 Jul 2013 14:35:03 +0200 Subject: [PATCH 4/5] Allow "bool" as a unit in user-defined equations and make sure that it cannot be used for differential equations. --- brian2/equations/equations.py | 2 ++ brian2/equations/unitcheck.py | 6 +++++- brian2/tests/test_equations.py | 9 +++++++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/brian2/equations/equations.py b/brian2/equations/equations.py index dc71fbc89..6b5d94609 100644 --- a/brian2/equations/equations.py +++ b/brian2/equations/equations.py @@ -191,6 +191,8 @@ def parse_string_equations(eqns): is_bool = unit is True if is_bool: unit = Unit(1) + if eq_type == DIFFERENTIAL_EQUATION: + raise EquationError('Differential equations cannot be boolean') expression = eq_content.get('expression', None) if not expression is None: # Replace multiple whitespaces (arising from joining multiline diff --git a/brian2/equations/unitcheck.py b/brian2/equations/unitcheck.py index a566d3beb..c40b9d096 100644 --- a/brian2/equations/unitcheck.py +++ b/brian2/equations/unitcheck.py @@ -59,7 +59,11 @@ def unit_from_string(unit_string): # Special case: dimensionless unit if unit_string == '1': return Unit(1, dim=DIMENSIONLESS) - + + # Another special case: boolean variable + if unit_string == 'bool': + return True + # Check first whether the expression evaluates at all, using only base units try: evaluated_unit = eval(unit_string, namespace) diff --git a/brian2/tests/test_equations.py b/brian2/tests/test_equations.py index bd2a63555..580cf6f90 100644 --- a/brian2/tests/test_equations.py +++ b/brian2/tests/test_equations.py @@ -115,12 +115,15 @@ def test_parse_equations(): dge/dt = -ge / tau_ge : volt I = sin(2 * pi * f * t) : volt f : Hz (constant) + b : bool ''') - assert len(eqs.keys()) == 4 + assert len(eqs.keys()) == 5 assert 'v' in eqs and eqs['v'].type == DIFFERENTIAL_EQUATION assert 'ge' in eqs and eqs['ge'].type == DIFFERENTIAL_EQUATION assert 'I' in eqs and eqs['I'].type == STATIC_EQUATION assert 'f' in eqs and eqs['f'].type == PARAMETER + assert 'b' in eqs and eqs['b'].type == PARAMETER + assert not eqs['f'].is_bool and eqs['b'].is_bool assert get_dimensions(eqs['v'].unit) == volt.dim assert get_dimensions(eqs['ge'].unit) == volt.dim assert get_dimensions(eqs['I'].unit) == volt.dim @@ -140,12 +143,14 @@ def test_parse_equations(): x = 2 * t : 1''', '''dv/dt = -v / tau : 1 : volt x = 2 * t : 1''', - ''' dv/dt = -v / tau : 2 * volt'''] + ''' dv/dt = -v / tau : 2 * volt''', + 'dv/dt = v / second : bool'] for error_eqs in parse_error_eqs: assert_raises((ValueError, EquationError), lambda: parse_string_equations(error_eqs)) + def test_correct_replacements(): ''' Test replacing variables via keyword arguments ''' From b61cca6c4fc9f9e5044c91c505c9df9eb28f29b2 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 9 Jul 2013 16:12:40 +0200 Subject: [PATCH 5/5] Handle subexpressions better. --- brian2/codegen/translation.py | 68 +++++++++++++++++++++++---------- brian2/equations/unitcheck.py | 5 ++- brian2/groups/group.py | 20 ++++++---- brian2/tests/test_codegen.py | 20 +++++++++- brian2/tests/test_refractory.py | 7 +++- 5 files changed, 88 insertions(+), 32 deletions(-) diff --git a/brian2/codegen/translation.py b/brian2/codegen/translation.py index 6f7d236c6..14320bc67 100644 --- a/brian2/codegen/translation.py +++ b/brian2/codegen/translation.py @@ -15,6 +15,7 @@ * The language to translate to ''' import re +import collections from numpy import float64 @@ -25,10 +26,12 @@ from .statements import Statement from .parsing import parse_statement -__all__ = ['translate', 'make_statements'] +__all__ = ['translate', 'make_statements', 'analyse_identifiers', + 'get_identifiers_recursively'] DEBUG = False + class LineInfo(object): ''' A helper class, just used to store attributes. @@ -37,8 +40,11 @@ def __init__(self, **kwds): for k, v in kwds.iteritems(): setattr(self, k, v) + # TODO: This information should go somewhere else, I guess +STANDARD_IDENTIFIERS = set(['and', 'or', 'not', 'True', 'False']) + -def analyse_identifiers(code, known=None): +def analyse_identifiers(code, specifiers, recursive=False): ''' Analyses a code string (sequence of statements) to find all identifiers by type. @@ -52,38 +58,59 @@ def analyse_identifiers(code, known=None): Parameters ---------- - code : str The code string, a sequence of statements one per line. - known : list, set, None - A list or set of known (already created) variables. + specifiers : dict of `Specifier`, set of names + Specifiers for the model variables or a set of known names + recursive : bool, optional + Whether to recurse down into subexpressions (defaults to ``False``). Returns ------- - newly_defined : set A set of variables that are created by the code block. used_known : set - A set of variables that are used and already known, a subset of the ``known`` parameter. - dependent : set - A set of variables which are used by the code block but not defined by it and not - previously known. If this set is nonempty it may indicate an error, for example. + A set of variables that are used and already known, a subset of the + ``known`` parameter. + unknown : set + A set of variables which are used by the code block but not defined by + it and not previously known. Should correspond to variables in the + external namespace. ''' - if known is None: - known = set() - known = set(known) - # TODO: This information should go somewhere else, I guess - standard_identifiers = set(['and', 'or', 'not', 'True', 'False']) - known |= standard_identifiers - specifiers = dict((k, Value(k, 1, float64)) for k in known) + if isinstance(specifiers, collections.Mapping): + known = set(specifiers.keys()) + else: + known = set(specifiers) + specifiers = dict((k, Value(k, 1, float64)) for k in known) + + known |= STANDARD_IDENTIFIERS stmts = make_statements(code, specifiers, float64) defined = set(stmt.var for stmt in stmts if stmt.op==':=') - allids = set(get_identifiers(code)) + if recursive: + if not isinstance(specifiers, collections.Mapping): + raise TypeError('Have to specify a specifiers dictionary.') + allids = get_identifiers_recursively(code, specifiers) + else: + allids = get_identifiers(code) dependent = allids.difference(defined, known) - used_known = allids.intersection(known) - standard_identifiers + used_known = allids.intersection(known) - STANDARD_IDENTIFIERS + return defined, used_known, dependent +def get_identifiers_recursively(expr, specifiers): + ''' + Gets all the identifiers in a code, recursing down into subexpressions. + ''' + identifiers = get_identifiers(expr) + for name in set(identifiers): + if name in specifiers and isinstance(specifiers[name], Subexpression): + s_identifiers = get_identifiers_recursively(specifiers[name].expr, + specifiers) + identifiers |= s_identifiers + return identifiers + + def make_statements(code, specifiers, dtype): ''' Turn a series of abstract code statements into Statement objects, inferring @@ -116,7 +143,7 @@ def make_statements(code, specifiers, dtype): # for each line will give the variable being written to line.write = var # each line will give a set of variables which are read - line.read = set(get_identifiers(expr)) + line.read = get_identifiers_recursively(expr, specifiers) if DEBUG: print 'PARSED STATEMENTS:' @@ -153,6 +180,7 @@ def make_statements(code, specifiers, dtype): # of the variables appearing in it has changed). All subexpressions start # as invalid, and are invalidated whenever one of the variables appearing # in the RHS changes value. + #subexpressions = get_all_subexpressions() subexpressions = dict((name, val) for name, val in specifiers.items() if isinstance(val, Subexpression)) if DEBUG: print 'SUBEXPRESSIONS:', subexpressions.keys() diff --git a/brian2/equations/unitcheck.py b/brian2/equations/unitcheck.py index c40b9d096..75c789804 100644 --- a/brian2/equations/unitcheck.py +++ b/brian2/equations/unitcheck.py @@ -3,8 +3,8 @@ ''' import re -from brian2.units.fundamentalunits import Quantity, Unit,\ - fail_for_dimension_mismatch, DimensionMismatchError +from brian2.units.fundamentalunits import (Quantity, Unit, + fail_for_dimension_mismatch) from brian2.units.fundamentalunits import DIMENSIONLESS from brian2.units.allunits import (metre, meter, second, amp, kelvin, mole, candle, kilogram, radian, steradian, hertz, @@ -117,6 +117,7 @@ def check_unit(expression, unit, namespace, specifiers): 'have the expected units' % expression)) + def check_units_statements(code, namespace, specifiers): ''' Check the units for a series of statements. Setting a model variable has to diff --git a/brian2/groups/group.py b/brian2/groups/group.py index 3d3604764..c4488f565 100644 --- a/brian2/groups/group.py +++ b/brian2/groups/group.py @@ -98,19 +98,22 @@ def _create_codeobj(group, name, code, additional_namespace=None, if check_units: # Resolve the namespace, resulting in a dictionary containing only the # external variables that are needed by the code -- keep the units for - # the unit checks - _, _, unknown = analyse_identifiers(code, group.specifiers.keys()) + # the unit checks + # Note that here, in contrast to the namespace resolution below, we do + # not need to recursively descend into subexpressions. For unit + # checking, we only need to know the units of the subexpressions, + # not what variables they refer to + _, _, unknown = analyse_identifiers(code, group.specifiers) resolved_namespace = group.namespace.resolve_all(unknown, additional_namespace, strip_units=False) check_units_statements(code, resolved_namespace, group.specifiers) - # Get the namespace without units - _, used_known, unknown = analyse_identifiers(code, group.specifiers.keys()) - resolved_namespace = group.namespace.resolve_all(unknown, - additional_namespace) - + # Determine the identifiers that were used + _, used_known, unknown = analyse_identifiers(code, group.specifiers, + recursive=True) + # Only pass the specifiers that are actually used specifiers = {} for name in used_known: @@ -121,6 +124,9 @@ def _create_codeobj(group, name, code, additional_namespace=None, for spec in template.specifiers: specifiers[spec] = group.specifiers[spec] + resolved_namespace = group.namespace.resolve_all(unknown, + additional_namespace) + return group.language.create_codeobj(name, code, resolved_namespace, diff --git a/brian2/tests/test_codegen.py b/brian2/tests/test_codegen.py index 23bf3afdf..6a05bdfa5 100644 --- a/brian2/tests/test_codegen.py +++ b/brian2/tests/test_codegen.py @@ -1,7 +1,9 @@ import numpy as np -from numpy.testing import assert_raises, assert_equal -from brian2.codegen.translation import analyse_identifiers +from brian2.codegen.translation import analyse_identifiers, get_identifiers_recursively +from brian2.core.specifiers import Subexpression, Specifier +from brian2.units.fundamentalunits import Unit + def test_analyse_identifiers(): ''' @@ -20,5 +22,19 @@ def test_analyse_identifiers(): assert dependent==set(['e', 'f']) +def test_get_identifiers_recursively(): + ''' + Test finding identifiers including subexpressions. + ''' + specifiers = {} + specifiers['sub1'] = Subexpression('sub1', Unit(1), np.float32, 'sub2 * z', + specifiers, {}) + specifiers['sub2'] = Subexpression('sub2', Unit(1), np.float32, '5 + y', + specifiers, {}) + specifiers['x'] = Specifier('x') + identifiers = get_identifiers_recursively('_x = sub1 + x', specifiers) + assert identifiers == set(['x', '_x', 'y', 'z', 'sub1', 'sub2']) + if __name__ == '__main__': test_analyse_identifiers() + test_get_identifiers_recursively() diff --git a/brian2/tests/test_refractory.py b/brian2/tests/test_refractory.py index 57a9f3e24..1c459dc0e 100644 --- a/brian2/tests/test_refractory.py +++ b/brian2/tests/test_refractory.py @@ -22,13 +22,18 @@ def test_refractoriness_variables(): # Try a quantity, a string evaluating to a quantity an an explicit boolean # condition -- all should do the same thing for ref_time in [5*ms, '5*ms', '(t-lastspike) < 5*ms', + 'time_since_spike < 5*ms', 'ref_subexpression', '(t-lastspike) < ref', 'ref', 'ref_no_unit*ms']: G = NeuronGroup(1, ''' dv/dt = 100*Hz : 1 (unless-refractory) dw/dt = 100*Hz : 1 ref : second ref_no_unit : 1 - ''', threshold='v>1', reset='v=0;w=0', refractory=ref_time) + time_since_spike = t - lastspike : second + ref_subexpression = (t - lastspike) < ref : bool + ''', + threshold='v>1', reset='v=0;w=0', + refractory=ref_time) G.ref = 5*ms G.ref_no_unit = 5 # It should take 10ms to reach the threshold, then v should stay at 0