Skip to content

Commit

Permalink
Merge pull request #798 from brian-team/fix_#796
Browse files Browse the repository at this point in the history
Simplify `group_variable_set` template in C++ standalone
  • Loading branch information
mstimberg committed Jan 10, 2017
2 parents 9d12d8f + a3e15a9 commit 191c2d7
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 46 deletions.
4 changes: 1 addition & 3 deletions brian2/core/variables.py
Expand Up @@ -981,14 +981,12 @@ def set_with_expression(self, item, code, run_namespace, check_units=True):
# TODO: Have an additional argument to avoid going through the index
# array for situations where iterate_all could be used
from brian2.codegen.codeobject import create_runner_codeobj
from brian2.devices.device import get_default_codeobject_class, get_device
from brian2.devices.device import get_default_codeobject_class

group_index_var = get_device().get_array_name(variables['_group_idx'])
codeobj = create_runner_codeobj(self.group,
abstract_code,
'group_variable_set',
additional_variables=variables,
template_kwds={'_group_index_var': group_index_var},
check_units=check_units,
run_namespace=run_namespace,
codeobj_class=get_default_codeobject_class('codegen.string_expression_target'))
Expand Down
Expand Up @@ -11,7 +11,7 @@
for(int _idx_group_idx=0; _idx_group_idx<_num_group_idx; _idx_group_idx++)
{
// vector code
const int _idx = {{_group_index_var}}[_idx_group_idx];
const int _idx = {{_group_idx}}[_idx_group_idx];
const int _vectorisation_idx = _idx;
{{vector_code|autoindent}}
}
Expand Down
115 changes: 73 additions & 42 deletions brian2/tests/test_neurongroup.py
Expand Up @@ -855,6 +855,7 @@ def test_state_variables():
assert_raises(TypeError, lambda: G.v.__isub__('string'))


@attr('codegen-independent')
def test_state_variable_access():
G = NeuronGroup(10, 'v:volt')
G.v = np.arange(10) * volt
Expand Down Expand Up @@ -897,71 +898,100 @@ def test_state_variable_access_strings():
assert_raises(DimensionMismatchError, lambda: G.v['v >= 3'])
assert_raises(DimensionMismatchError, lambda: G.v['v >= 3*second'])


@attr('standalone-compatible')
@with_setup(teardown=reinit_devices)
def test_state_variable_set_strings():
# Instead of overwriting the same variable over and over, we have one
# variable for each assignment so that we can test everything in the end
# for standalone.
G = NeuronGroup(10, '''v1 : volt
v2 : volt
v3 : volt
v4 : volt
v5 : volt
v6 : volt
v7 : volt
v8 : volt
v9 : volt
v10 : volt
v11 : volt
dv_ref/dt = -v_ref/(10*ms) : 1 (unless refractory)''',
threshold='v_ref>1', reset='v_ref=1', refractory=1 * ms)
# Setting with strings
# --------------------
# String value referring to i
G.v = '2*i*volt'
assert_equal(G.v[:], 2*np.arange(10)*volt)
G.v1 = '2*i*volt'
# String value referring to i
G.v[:5] = '3*i*volt'
assert_equal(G.v[:],
np.array([0, 3, 6, 9, 12, 10, 12, 14, 16, 18])*volt)

G.v = np.arange(10) * volt
G.v1[:5] = '3*i*volt'

# Conditional write variable
G.v_ref = '2*i'
assert_equal(G.v_ref[:], 2*np.arange(10))

G.v2 = np.arange(10)*volt
# String value referring to a state variable
G.v = '2*v'
assert_equal(G.v[:], 2*np.arange(10)*volt)
G.v[:5] = '2*v'
assert_equal(G.v[:],
np.array([0, 4, 8, 12, 16, 10, 12, 14, 16, 18])*volt)
G.v2 = '2*v2'
G.v2[:5] = '2*v2'

G.v = np.arange(10) * volt
G.v3 = np.arange(10) * volt
# String value referring to state variables, i, and an external variable
ext = 5*volt
G.v = 'v + ext + (N + i)*volt'
assert_equal(G.v[:], 2*np.arange(10)*volt + 15*volt)
G.v3 = 'v3 + ext + (N + i)*volt'

G.v = np.arange(10) * volt
G.v[:5] = 'v + ext + (N + i)*volt'
assert_equal(G.v[:],
np.array([15, 17, 19, 21, 23, 5, 6, 7, 8, 9])*volt)
G.v4 = np.arange(10) * volt
G.v4[:5] = 'v4 + ext + (N + i)*volt'

G.v = 'v + randn()*volt' # only check that it doesn't raise an error
G.v[:5] = 'v + randn()*volt' # only check that it doesn't raise an error
G.v5 = 'v5 + randn()*volt' # only check that it doesn't raise an error
G.v5[:5] = 'v5 + randn()*volt' # only check that it doesn't raise an error

G.v = np.arange(10) * volt
G.v6 = np.arange(10) * volt
# String index using a random number
G.v['rand() <= 1'] = 0*mV
assert_equal(G.v[:], np.zeros(10)*volt)
G.v6['rand() <= 1'] = 0*mV

G.v = np.arange(10) * volt
G.v7 = np.arange(10) * volt
# String index referring to i and setting to a scalar value
G.v['i>=5'] = 0*mV
assert_equal(G.v[:], np.array([0, 1, 2, 3, 4, 0, 0, 0, 0, 0])*volt)
G.v7['i>=5'] = 0*mV

G.v8[:5] = np.arange(5) * volt
# String index referring to a state variable
G.v['v<3*volt'] = 0*mV
assert_equal(G.v[:], np.array([0, 0, 0, 3, 4, 0, 0, 0, 0, 0])*volt)
G.v8['v8<3*volt'] = 0*mV
# String index referring to state variables, i, and an external variable
ext = 2*volt
G.v['v>=ext and i==(N-6)'] = 0*mV
assert_equal(G.v[:], np.array([0, 0, 0, 3, 0, 0, 0, 0, 0, 0])*volt)
G.v8['v8>=ext and i==(N-6)'] = 0*mV

G.v = np.arange(10) * volt
G.v9 = np.arange(10) * volt
# Strings for both condition and values
G.v['i>=5'] = 'v*2'
assert_equal(G.v[:], np.array([0, 1, 2, 3, 4, 10, 12, 14, 16, 18])*volt)
G.v['v>=5*volt'] = 'i*volt'
assert_equal(G.v[:], np.arange(10)*volt)
G.v['i<=5'] = '(100 + rand())*volt'
assert_equal(G.v[6:], np.arange(4)*volt + 6*volt) # unchanged
assert all(G.v[:6] >= 100*volt)
assert all(G.v[:6] <= 101*volt)
assert np.var(G.v_[:6]) > 0
G.v9['i>=5'] = 'v9*2'
G.v9['v9>=5*volt'] = 'i*volt'

G.v10 = np.arange(10)*volt
G.v10['i<=5'] = '(100 + rand())*volt'

# string assignment to scalars
G.v11[0] = '1*volt'
G.v11[1] = '(1 + i)*volt'
G.v11[2] = 'v11 + 3*volt'
G.v11[3] = 'inf*volt'
G.v11[4] = 'rand()*volt'
run(0*ms)
assert_equal(G.v1[:],
np.array([0, 3, 6, 9, 12, 10, 12, 14, 16, 18])*volt)
assert_equal(G.v_ref[:], 2 * np.arange(10))
assert_equal(G.v2[:],
np.array([0, 4, 8, 12, 16, 10, 12, 14, 16, 18])*volt)
assert_equal(G.v3[:], 2 * np.arange(10) * volt + 15 * volt)
assert_equal(G.v4[:],
np.array([15, 17, 19, 21, 23, 5, 6, 7, 8, 9])*volt)
assert_equal(G.v6[:], np.zeros(10) * volt)
assert_equal(G.v7[:], np.array([0, 1, 2, 3, 4, 0, 0, 0, 0, 0]) * volt)
assert_equal(G.v8[:], np.array([0, 0, 0, 3, 0, 0, 0, 0, 0, 0]) * volt)
assert_equal(G.v9[:], np.arange(10) * volt)
assert_equal(G.v10[6:], np.arange(4)*volt + 6*volt) # unchanged
assert all(G.v10[:6] >= 100*volt)
assert all(G.v10[:6] <= 101*volt)
assert np.var(G.v10_[:6]) > 0
assert_equal(G.v11[:3], [1, 2, 3]*volt)
assert np.isinf(G.v11_[3])

@attr('codegen-independent')
def test_unknown_state_variables():
Expand Down Expand Up @@ -1471,6 +1501,7 @@ def test_no_code():
test_state_variables()
test_state_variable_access()
test_state_variable_access_strings()
test_state_variable_set_strings()
test_unknown_state_variables()
test_subexpression()
test_subexpression_with_constant()
Expand Down

0 comments on commit 191c2d7

Please sign in to comment.