Skip to content

Commit

Permalink
Use the is_boolean_expression to determine whether the refractory per…
Browse files Browse the repository at this point in the history
…iod is a correct boolean expression. Also change the is_boolean_expression function to the common namespace/specififers signature
  • Loading branch information
Marcel Stimberg committed Jul 9, 2013
1 parent f9af0ab commit 316d8c9
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 34 deletions.
10 changes: 8 additions & 2 deletions brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
StochasticVariable, Subexpression)
from brian2.core.spikesource import SpikeSource
from brian2.core.scheduler import Scheduler
from brian2.parsing.expressions import parse_expression_unit
from brian2.parsing.expressions import parse_expression_unit, is_boolean_expression
from brian2.utils.logger import get_logger
from brian2.units.allunits import second
from brian2.units.fundamentalunits import (Quantity, Unit, have_same_dimensions,
Expand Down Expand Up @@ -73,13 +73,19 @@ def update_abstract_code(self, additional_namespace):
if have_same_dimensions(unit, second):
self.abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref
elif have_same_dimensions(unit, Unit(1)):
if not is_boolean_expression(str(ref), namespace,
self.group.specifiers):
raise TypeError(('Refractory expression is dimensionless '
'but not a boolean value. It needs to '
'either evaluate to a timespan or to a '
'boolean value.'))
# boolean condition
# we have to be a bit careful here, we can't just use the given
# condition as it is, because we only want to *leave*
# refractoriness, based on the condition
self.abstract_code = 'not_refractory = not_refractory or not (%s)\n' % ref
else:
raise TypeError(('Refractory expression has to evaluate to a #'
raise TypeError(('Refractory expression has to evaluate to a '
'timespan or a boolean value, expression'
'"%s" has units %s instead') % (ref, unit))

Expand Down
40 changes: 24 additions & 16 deletions brian2/parsing/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
__all__ = ['is_boolean_expression',
'parse_expression_unit',]

def is_boolean_expression(expr, boolvars=None, boolfuncs=None):

def is_boolean_expression(expr, namespace, specifiers):
'''
Determines if an expression is of boolean type or not
Expand All @@ -21,10 +22,10 @@ def is_boolean_expression(expr, boolvars=None, boolfuncs=None):
expr : str
The expression to test
boolvars : set
The set of variables of boolean type.
boolfuncs : set
The set of functions which return booleans.
namespace : dict-like
The namespace of external variables.
specifiers : dict of `Specifier` objects
The information about the internal variables
Returns
-------
Expand Down Expand Up @@ -56,28 +57,35 @@ def is_boolean_expression(expr, boolvars=None, boolfuncs=None):
``not``, otherwise ``False``.
* Otherwise we return ``False``.
'''
if boolfuncs is None:
boolfuncs = set([])
if boolvars is None:
boolvars = set([])

boolvars.add('True')
boolvars.add('False')


# If we are working on a string, convert to the top level node
if isinstance(expr, str):
mod = ast.parse(expr, mode='eval')
expr = mod.body

if expr.__class__ is ast.BoolOp:
if all(is_boolean_expression(node, boolvars, boolfuncs) for node in expr.values):
if all(is_boolean_expression(node, namespace, specifiers)
for node in expr.values):
return True
else:
raise SyntaxError("Expression ought to be boolean but is not (e.g. 'x<y and 3')")
elif expr.__class__ is ast.Name:
return expr.id in boolvars
name = expr.id
if name in namespace:
value = namespace[name]
return value is True or value is False
elif name in specifiers:
return getattr(specifiers[name], 'is_bool', False)
else:
return name == 'True' or name == 'False'
elif expr.__class__ is ast.Call:
return expr.func.id in boolfuncs
name = expr.func.id
if name in namespace:
return getattr(namespace[name], '_returns_bool', False)
elif name in specifiers:
return getattr(specifiers[name], '_returns_bool', False)
else:
raise SyntaxError('Unknown function %s' % name)
elif expr.__class__ is ast.Compare:
return True
elif expr.__class__ is ast.UnaryOp:
Expand Down
59 changes: 47 additions & 12 deletions brian2/tests/test_parsing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
'''
Tests the brian2.parsing package
'''
from collections import namedtuple

from brian2.utils.stringtools import get_identifiers, deindent
from brian2.parsing.rendering import (NodeRenderer, NumpyNodeRenderer,
CPPNodeRenderer,
Expand Down Expand Up @@ -168,25 +170,58 @@ def test_abstract_code_dependencies():


def test_is_boolean_expression():
# dummy "specifier" class
Spec = namedtuple("Spec", ['is_bool'])

# dummy function object
class Func(object):
def __init__(self, returns_bool=False):
self._returns_bool = returns_bool

# namespace values / functions
a = True
b = False
c = 5
f = Func(returns_bool=True)
g = Func(returns_bool=False)

# specifier
s1 = Spec(is_bool=True)
s2 = Spec(is_bool=False)

namespace = {'a': a, 'b': b, 'c': c, 'f': f, 'g': g}
specifiers = {'s1': s1, 's2': s2}

EVF = [
(True, 'a or b', ['a', 'b'], []),
(True, 'True', [], []),
(True, 'a<b', [], []),
(True, 'not (a>=b)', [], []),
(False, 'a+b', [], []),
(True, 'f(x)', [], ['f']),
(False, 'f(x)', [], []),
(True, 'f(x) or a<b and c', ['c'], ['f']),
(True, 'a or b'),
(False, 'c'),
(False, 's2'),
(False, 'g(s1)'),
(True, 's2 > c'),
(True, 'c > 5'),
(True, 'True'),
(True, 'a<b'),
(True, 'not (a>=b)'),
(False, 'a+b'),
(True, 'f(c)'),
(False, 'g(c)'),
(True, 'f(c) or a<b and s1', ),
]
for expect, expr, vars, funcs in EVF:
ret_val = is_boolean_expression(expr, set(vars), set(funcs))
for expect, expr in EVF:
ret_val = is_boolean_expression(expr, namespace, specifiers)
if expect != ret_val:
raise AssertionError(('is_boolean_expression(%r) returned %s, '
'but was supposed to return %s') % (expr,
ret_val,
expect))
assert_raises(SyntaxError, is_boolean_expression, 'x<y and z')
assert_raises(SyntaxError, is_boolean_expression, 'a or b')
assert_raises(SyntaxError, is_boolean_expression, 'a<b and c',
namespace, specifiers)
assert_raises(SyntaxError, is_boolean_expression, 'a or foo',
namespace, specifiers)
assert_raises(SyntaxError, is_boolean_expression, 'ot a', # typo
namespace, specifiers)
assert_raises(SyntaxError, is_boolean_expression, 'g(c) and f(a)',
namespace, specifiers)


def test_parse_expression_unit():
Expand Down
29 changes: 25 additions & 4 deletions brian2/tests/test_refractory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numpy.testing.utils import assert_equal, assert_allclose
from numpy.testing.utils import assert_equal, assert_allclose, assert_raises

from brian2 import *
from brian2.equations.refractory import add_refractoriness
Expand All @@ -21,11 +21,16 @@ def test_add_refractoriness():
def test_refractoriness_variables():
# Try a quantity, a string evaluating to a quantity an an explicit boolean
# condition -- all should do the same thing
for ref_time in [5*ms, '5*ms', '(t-lastspike) < 5*ms']:
for ref_time in [5*ms, '5*ms', '(t-lastspike) < 5*ms',
'(t-lastspike) < ref', 'ref', 'ref_no_unit*ms']:
G = NeuronGroup(1, '''
dv/dt = 100*Hz : 1 (unless-refractory)
dw/dt = 100*Hz : 1
ref : second
ref_no_unit : 1
''', threshold='v>1', reset='v=0;w=0', refractory=ref_time)
G.ref = 5*ms
G.ref_no_unit = 5
# It should take 10ms to reach the threshold, then v should stay at 0
# for 5ms, while w continues to increase
mon = StateMonitor(G, ['v', 'w'], record=True)
Expand All @@ -46,11 +51,16 @@ def test_refractoriness_variables():
def test_refractoriness_threshold():
# Try a quantity, a string evaluating to a quantity an an explicit boolean
# condition -- all should do the same thing
for ref_time in [10*ms, '10*ms', '(t-lastspike) <= 10*ms']:
for ref_time in [10*ms, '10*ms', '(t-lastspike) <= 10*ms',
'(t-lastspike) <= ref', 'ref', 'ref_no_unit*ms']:
G = NeuronGroup(1, '''
dv/dt = 200*Hz : 1
ref : second
ref_no_unit : 1
''', threshold='not_refractory and (v > 1)',
reset='v=0', refractory=ref_time)
G.ref = 10*ms
G.ref_no_unit = 10
# The neuron should spike after 5ms but then not spike for the next
# 10ms. The state variable should continue to integrate so there should
# be a spike after 15ms
Expand All @@ -60,7 +70,18 @@ def test_refractoriness_threshold():
assert_allclose(spike_mon.t, [4.9, 15] * ms)


def test_refractoriness_types():
# make sure that using a wrong type of refractoriness does not work
assert_raises(TypeError, lambda: NeuronGroup(1, '', refractory='3*Hz'))
assert_raises(TypeError, lambda: NeuronGroup(1, 'ref: Hz',
refractory='ref'))
assert_raises(TypeError, lambda: NeuronGroup(1, '', refractory='3'))
assert_raises(TypeError, lambda: NeuronGroup(1, 'ref: 1',
refractory='ref'))


if __name__ == '__main__':
test_add_refractoriness()
test_refractoriness_variables()
test_refractoriness_threshold()
test_refractoriness_threshold()
test_refractoriness_types()

0 comments on commit 316d8c9

Please sign in to comment.