Skip to content

Commit

Permalink
Merge pull request #1507 from brian-team/fix_external_function_test
Browse files Browse the repository at this point in the history
Do not assume that the test directory is writeable
  • Loading branch information
mstimberg committed Mar 7, 2024
2 parents c5b8719 + ec15c2f commit deb498f
Showing 1 changed file with 32 additions and 25 deletions.
57 changes: 32 additions & 25 deletions brian2/tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import shutil
import tempfile

import pytest
from numpy.testing import assert_equal
Expand Down Expand Up @@ -698,32 +700,37 @@ def foo(x, y):
def test_external_function_cpp_standalone():
set_device("cpp_standalone", directory=None)
this_dir = os.path.abspath(os.path.dirname(__file__))
with tempfile.TemporaryDirectory(prefix="brian_testsuite_") as tmpdir:
# copy the test function to the temporary directory
# this avoids issues with the file being in a directory that is not writable
shutil.copy(os.path.join(this_dir, "func_def_cpp.h"), tmpdir)
shutil.copy(os.path.join(this_dir, "func_def_cpp.cpp"), tmpdir)

@implementation(
"cpp",
"//all code in func_def_cpp.cpp",
headers=['"func_def_cpp.h"'],
include_dirs=[tmpdir],
sources=[os.path.join(tmpdir, "func_def_cpp.cpp")],
)
@check_units(x=volt, y=volt, result=volt)
def foo(x, y):
return x + y + 3 * volt

@implementation(
"cpp",
"//all code in func_def_cpp.cpp",
headers=['"func_def_cpp.h"'],
include_dirs=[this_dir],
sources=[os.path.join(this_dir, "func_def_cpp.cpp")],
)
@check_units(x=volt, y=volt, result=volt)
def foo(x, y):
return x + y + 3 * volt

G = NeuronGroup(
1,
"""
func = foo(x, y) : volt
x : volt
y : volt
""",
)
G.x = 1 * volt
G.y = 2 * volt
mon = StateMonitor(G, "func", record=True)
net = Network(G, mon)
net.run(defaultclock.dt)
assert mon[0].func == [6] * volt
G = NeuronGroup(
1,
"""
func = foo(x, y) : volt
x : volt
y : volt
""",
)
G.x = 1 * volt
G.y = 2 * volt
mon = StateMonitor(G, "func", record=True)
net = Network(G, mon)
net.run(defaultclock.dt)
assert mon[0].func == [6] * volt


@pytest.mark.codegen_independent
Expand Down

0 comments on commit deb498f

Please sign in to comment.