Skip to content

Commit

Permalink
Check units for summed variable expressions
Browse files Browse the repository at this point in the history
Fixes #1346
  • Loading branch information
mstimberg committed Oct 11, 2021
1 parent 7e19957 commit 9692d85
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
16 changes: 16 additions & 0 deletions brian2/synapses/synapses.py
Expand Up @@ -130,6 +130,22 @@ def __init__(self, expression, target_varname, synapses, target,
name=synapses.name + '_summed_variable_' + target_varname,
template_kwds=template_kwds)

def before_run(self, run_namespace):
variables = self.group.resolve_all(self.expression.identifiers,
run_namespace)
rhs_unit = parse_expression_dimensions(self.expression.code, variables)
fail_for_dimension_mismatch(self.target_var,
# Using a quantity instead of dimensions
# here makes fail_for_dimension_mismatch
# state the dimensions as part of the error
# message
Quantity(1, dim=rhs_unit),
f"The target variable "
f"'{self.target_varname}' does not have "
f"the same dimensions as the right-hand "
f"side expression '{self.expression}'.")
super(SummedVariableUpdater, self).before_run(run_namespace)


class SynapticPathway(CodeRunner, Group):
'''
Expand Down
25 changes: 16 additions & 9 deletions brian2/tests/test_synapses.py
Expand Up @@ -1290,7 +1290,7 @@ def test_summed_variable_differing_group_size():
def test_summed_variable_errors():
G = NeuronGroup(10, '''dv/dt = -v / (10*ms) : volt
sub = 2*v : volt
p : volt''')
p : volt''', threshold='False', reset='')

# Using the (summed) flag for a differential equation or a parameter
with pytest.raises(ValueError):
Expand Down Expand Up @@ -1322,19 +1322,26 @@ def test_summed_variable_errors():

# Summed variable referring to an event-driven variable
with pytest.raises(EquationError) as ex:
Synapses(G, G, '''ds/dt = -s/(3*ms) : 1 (event-driven)
a = s : 1 (summed)''', on_pre='s += 1')
assert "'a'" in str(ex.value) and "'s'" in str(ex.value)
Synapses(G, G, '''ds/dt = -s/(3*ms) : volt (event-driven)
p_post = s : volt (summed)''', on_pre='s += 1*mV')
assert "'p_post'" in str(ex.value) and "'s'" in str(ex.value)

# Indirect dependency
with pytest.raises(EquationError) as ex:
Synapses(G, G, '''ds/dt = -s/(3*ms) : 1 (event-driven)
x = s : 1
y = x : 1
a = y : 1 (summed)''', on_pre='s += 1')
assert "'a'" in str(ex.value) and "'s'" in str(ex.value)
Synapses(G, G, '''ds/dt = -s/(3*ms) : volt (event-driven)
x = s : volt
y = x : volt
p_post = y : 1 (summed)''', on_pre='s += 1*mV')
assert "'p_post'" in str(ex.value) and "'s'" in str(ex.value)
assert "'x'" in str(ex.value) and "'y'" in str(ex.value)

with pytest.raises(BrianObjectException) as ex:
S = Synapses(G, G, '''y : siemens
p_post = y : volt (summed)''')
run(0 * ms)

assert isinstance(ex.value.__cause__, DimensionMismatchError)


@pytest.mark.codegen_independent
def test_multiple_summed_variables():
Expand Down

0 comments on commit 9692d85

Please sign in to comment.