Skip to content

Commit

Permalink
Merge pull request #675 from brian-team/fix_#674
Browse files Browse the repository at this point in the history
Provide an actual `Unit` (and not a `Quantity`) for the `Variable` object in `check_units_statements`
  • Loading branch information
mstimberg committed Apr 15, 2016
2 parents 5e52644 + fdd7438 commit a854416
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
6 changes: 4 additions & 2 deletions brian2/core/variables.py
Expand Up @@ -12,7 +12,7 @@
from brian2.utils.stringtools import get_identifiers, word_substitute
from brian2.units.fundamentalunits import (Quantity, Unit, DIMENSIONLESS,
fail_for_dimension_mismatch,
have_same_dimensions)
have_same_dimensions, get_unit)
from brian2.utils.logger import get_logger

from .base import weakproxy_with_fallback, device_override
Expand Down Expand Up @@ -130,7 +130,9 @@ class Variable(object):
def __init__(self, name, unit, owner=None, dtype=None, scalar=False,
constant=False, read_only=False, dynamic=False, array=False):
if not isinstance(unit, Unit):
if unit == 1:
if isinstance(unit, Quantity):
unit = get_unit(unit)
elif unit == 1:
unit = Unit(1)
else:
raise TypeError(('unit argument has to be a Unit object, was '
Expand Down
5 changes: 3 additions & 2 deletions brian2/equations/unitcheck.py
Expand Up @@ -3,7 +3,7 @@
'''
import re

from brian2.units.fundamentalunits import (Quantity, Unit,
from brian2.units.fundamentalunits import (get_unit, Unit,
fail_for_dimension_mismatch)

from brian2.parsing.expressions import parse_expression_unit
Expand Down Expand Up @@ -105,7 +105,8 @@ def check_units_statements(code, variables):
expected_unit))
elif varname in newly_defined:
# note the unit for later
variables[varname] = Variable(name=varname, unit=expr_unit,
variables[varname] = Variable(name=varname,
unit=get_unit(expr_unit),
scalar=False)
else:
raise AssertionError(('Variable "%s" is neither in the variables '
Expand Down
18 changes: 9 additions & 9 deletions brian2/tests/test_synapses.py
Expand Up @@ -884,21 +884,21 @@ def test_no_synapses():
@attr('standalone-compatible')
@with_setup(teardown=reinit_devices)
def test_summed_variable():
source = NeuronGroup(2, 'v : 1', threshold='v>1', reset='v=0')
source.v = 1.1 # will spike immediately
target = NeuronGroup(2, 'v : 1')
S = Synapses(source, target, '''w : 1
x : 1
v_post = x : 1 (summed)''', on_pre='x+=w',
source = NeuronGroup(2, 'v : volt', threshold='v>1*volt', reset='v=0*volt')
source.v = 1.1*volt # will spike immediately
target = NeuronGroup(2, 'v : volt')
S = Synapses(source, target, '''w : volt
x : volt
v_post = 2*x : volt (summed)''', on_pre='x+=w',
multisynaptic_index='k')
S.connect('i==j', n=2)
S.w['k == 0'] = 'i'
S.w['k == 1'] = 'i + 0.5'
S.w['k == 0'] = 'i*volt'
S.w['k == 1'] = '(i + 0.5)*volt'
net = Network(source, target, S)
net.run(1*ms)

# v of the target should be the sum of the two weights
assert_equal(target.v, np.array([0.5, 2.5]))
assert_equal(target.v, np.array([1.0, 5.0])*volt)


def test_summed_variable_errors():
Expand Down

0 comments on commit a854416

Please sign in to comment.