Skip to content

Commit

Permalink
Test more data types for run_args feature
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Aug 1, 2023
1 parent ec15a37 commit ec6f868
Showing 1 changed file with 49 additions and 42 deletions.
91 changes: 49 additions & 42 deletions brian2/tests/test_cpp_standalone.py
Expand Up @@ -661,41 +661,6 @@ def test_constant_replacement():
assert G.y[0] == 42.0


@pytest.mark.cpp_standalone
@pytest.mark.standalone_only
def test_change_parameter_without_recompile():
set_device("cpp_standalone", directory=None, with_output=False)
G = NeuronGroup(
10,
"""
x : 1
v : volt
""",
name="neurons",
)
G.x = np.arange(10)
G.v = np.arange(10) * volt

run(0 * ms)
assert array_equal(G.x, np.arange(10))
assert array_equal(G.v, np.arange(10) * volt)

device.run(run_args=["neurons.x=5", "neurons.v=3"])
assert array_equal(G.x, np.ones(10) * 5)
assert array_equal(G.v, np.ones(10) * 3 * volt)

ar = np.arange(10) * 2.0
ar.astype(G.x.dtype).tofile(os.path.join(device.project_dir, "init_values_x1.dat"))
ar.astype(G.v.dtype).tofile(os.path.join(device.project_dir, "init_values_v1.dat"))
device.run(
run_args=["neurons.v=init_values_v1.dat", "neurons.x=init_values_x1.dat"]
)
assert array_equal(G.x, ar)
assert array_equal(G.v, ar * volt)

reset_device()


@pytest.mark.cpp_standalone
@pytest.mark.standalone_only
def test_change_parameter_without_recompile():
Expand All @@ -708,17 +673,23 @@ def test_change_parameter_without_recompile():
G = NeuronGroup(
10,
"""
x : 1
v : volt
x : 1 (constant)
v : volt (constant)
n : integer (constant)
b : boolean (constant)
s = int(on_off(t))*stim(t, i) : amp
""",
name="neurons",
)
G.x = np.arange(10)
G.n = np.arange(10)
G.b = np.arange(10) % 2 == 0
G.v = np.arange(10) * volt
mon = StateMonitor(G, "s", record=True)
run(3 * defaultclock.dt)
assert array_equal(G.x, np.arange(10))
assert array_equal(G.n, np.arange(10))
assert array_equal(G.b, np.arange(10) % 2 == 0)
assert array_equal(G.v, np.arange(10) * volt)
assert_allclose(
mon.s.T / nA,
Expand All @@ -731,8 +702,18 @@ def test_change_parameter_without_recompile():
),
)

device.run(run_args=["neurons.x=5", "neurons.v=3", "on_off.values=True"])
device.run(
run_args=[
"neurons.x=5",
"neurons.v=3",
"neurons.n=17",
"neurons.b=True",
"on_off.values=True",
]
)
assert array_equal(G.x, np.ones(10) * 5)
assert array_equal(G.n, np.ones(10) * 17)
assert array_equal(G.b, np.ones(10, dtype=bool))
assert array_equal(G.v, np.ones(10) * 3 * volt)
assert_allclose(
mon.s.T / nA,
Expand All @@ -746,6 +727,10 @@ def test_change_parameter_without_recompile():
)
ar = np.arange(10) * 2.0
ar.astype(G.x.dtype).tofile(os.path.join(device.project_dir, "init_values_x1.dat"))
ar.astype(G.n.dtype).tofile(os.path.join(device.project_dir, "init_values_n1.dat"))
(np.arange(10) % 2 != 0).tofile(
os.path.join(device.project_dir, "init_values_b1.dat")
)
ar.astype(G.v.dtype).tofile(os.path.join(device.project_dir, "init_values_v1.dat"))
ar2 = 2 * np.arange(30).reshape(3, 10) * nA
ar2.astype(stim.values.dtype).tofile(
Expand All @@ -755,10 +740,14 @@ def test_change_parameter_without_recompile():
run_args=[
"neurons.v=init_values_v1.dat",
"neurons.x=init_values_x1.dat",
"neurons.b=init_values_b1.dat",
"neurons.n=init_values_n1.dat",
"stim.values=init_stim_values.dat",
]
)
assert array_equal(G.x, ar)
assert array_equal(G.n, ar)
assert array_equal(G.b, np.arange(10) % 2 != 0)
assert array_equal(G.v, ar * volt)
assert_allclose(
mon.s.T / nA,
Expand Down Expand Up @@ -803,17 +792,23 @@ def test_change_parameter_without_recompile_dict_syntax():
G = NeuronGroup(
10,
"""
x : 1
v : volt
x : 1 (constant)
n : integer (constant)
b : boolean (constant)
v : volt (constant)
s = int(on_off(t))*stim(t, i) : amp
""",
name="neurons",
)
G.x = np.arange(10)
G.n = np.arange(10)
G.b = np.arange(10) % 2 == 0
G.v = np.arange(10) * volt
mon = StateMonitor(G, "s", record=True)
run(3 * defaultclock.dt)
assert array_equal(G.x, np.arange(10))
assert array_equal(G.n, np.arange(10))
assert array_equal(G.b, np.arange(10) % 2 == 0)
assert array_equal(G.v, np.arange(10) * volt)
assert_allclose(
mon.s.T / nA,
Expand All @@ -825,8 +820,10 @@ def test_change_parameter_without_recompile_dict_syntax():
]
),
)
device.run(run_args={G.x: 5, G.v: 3 * volt, on_off: True})
device.run(run_args={G.x: 5, G.v: 3 * volt, G.n: 17, G.b: True, on_off: True})
assert array_equal(G.x, np.ones(10) * 5)
assert array_equal(G.n, np.ones(10) * 17)
assert array_equal(G.b, np.ones(10, dtype=bool))
assert array_equal(G.v, np.ones(10) * 3 * volt)
assert_allclose(
mon.s.T / nA,
Expand All @@ -840,8 +837,18 @@ def test_change_parameter_without_recompile_dict_syntax():
)
ar = np.arange(10) * 2.0
ar2 = 2 * np.arange(30).reshape(3, 10) * nA
device.run(run_args={G.x: ar, G.v: ar * volt, stim: ar2})
device.run(
run_args={
G.x: ar,
G.v: ar * volt,
G.n: ar,
G.b: np.arange(10) % 2 != 0,
stim: ar2,
}
)
assert array_equal(G.x, ar)
assert array_equal(G.n, ar)
assert array_equal(G.b, np.arange(10) % 2 != 0)
assert array_equal(G.v, ar * volt)
assert_allclose(
mon.s.T / nA,
Expand Down

0 comments on commit ec6f868

Please sign in to comment.