Skip to content

Commit

Permalink
update conftest function names
Browse files Browse the repository at this point in the history
  • Loading branch information
kohl-carmen authored and blakecaldwell committed Nov 30, 2020
1 parent 8ff6c1a commit fa9c666
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 28 deletions.
9 changes: 5 additions & 4 deletions hnn_core/tests/conftest.py
Expand Up @@ -71,9 +71,10 @@ def pytest_runtest_setup(item):


@pytest.fixture(scope='module')
def run_hnn_core():
def _run_hnn_core(backend=None, n_procs=None, n_jobs=1, reduced=False,
record_vsoma=False, record_isoma=False, postproc=True):
def run_hnn_core_fixture():
def _run_hnn_core_fixture(backend=None, n_procs=None, n_jobs=1,
reduced=False, record_vsoma=False,
record_isoma=False, postproc=True):
hnn_core_root = op.dirname(hnn_core.__file__)

# default params
Expand Down Expand Up @@ -109,4 +110,4 @@ def _run_hnn_core(backend=None, n_procs=None, n_jobs=1, reduced=False,
postproc=postproc)

return dpls, net
return _run_hnn_core
return _run_hnn_core_fixture
24 changes: 13 additions & 11 deletions hnn_core/tests/test_dipole.py
Expand Up @@ -14,7 +14,7 @@
matplotlib.use('agg')


def test_dipole(tmpdir, run_hnn_core):
def test_dipole(tmpdir, run_hnn_core_fixture):
"""Test dipole object."""
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
Expand Down Expand Up @@ -47,11 +47,12 @@ def test_dipole(tmpdir, run_hnn_core):
dipole_avg = average_dipoles([dipole_avg, dipole_read])

# test postproc
dpls_raw, net = run_hnn_core(backend='joblib', n_jobs=1, reduced=True,
record_isoma=True, record_vsoma=True,
postproc=False)
dpls, _ = run_hnn_core(backend='joblib', n_jobs=1, reduced=True,
record_isoma=True, record_vsoma=True, postproc=True)
dpls_raw, net = run_hnn_core_fixture(backend='joblib', n_jobs=1,
reduced=True, record_isoma=True,
record_vsoma=True, postproc=False)
dpls, _ = run_hnn_core_fixture(backend='joblib', n_jobs=1, reduced=True,
record_isoma=True, record_vsoma=True,
postproc=True)
with pytest.raises(AssertionError):
assert_allclose(dpls[0].data['agg'], dpls_raw[0].data['agg'])

Expand Down Expand Up @@ -83,15 +84,16 @@ def test_dipole_simulation():


@requires_mpi4py
def test_cell_response_backends(run_hnn_core):
def test_cell_response_backends(run_hnn_core_fixture):
"""Test cell_response outputs across backends."""

# reduced simulation has n_trials=2
trial_idx, n_trials, gid = 0, 2, 7
_, joblib_net = run_hnn_core(backend='joblib', n_jobs=1, reduced=True,
record_isoma=True, record_vsoma=True)
_, mpi_net = run_hnn_core(backend='mpi', n_jobs=1, reduced=True,
record_isoma=True, record_vsoma=True)
_, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1,
reduced=True, record_isoma=True,
record_vsoma=True)
_, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True,
record_isoma=True, record_vsoma=True)
n_times = len(joblib_net.cell_response.times)

assert len(joblib_net.cell_response.vsoma) == n_trials
Expand Down
27 changes: 14 additions & 13 deletions hnn_core/tests/test_parallel_backends.py
Expand Up @@ -26,21 +26,21 @@ class TestParallelBackends():
dpls_reduced_default = None
dpls_reduced_joblib = None

def test_run_default(self, run_hnn_core):
def test_run_default(self, run_hnn_core_fixture):
"""Test consistency between default backend simulation and master"""
global dpls_reduced_default
dpls_reduced_default, _ = run_hnn_core(None, reduced=True)
dpls_reduced_default, _ = run_hnn_core_fixture(None, reduced=True)
# test consistency across all parallel backends for multiple trials
assert_raises(AssertionError, assert_array_equal,
dpls_reduced_default[0].data['agg'],
dpls_reduced_default[1].data['agg'])

def test_run_joblibbackend(self, run_hnn_core):
def test_run_joblibbackend(self, run_hnn_core_fixture):
"""Test consistency between joblib backend simulation with master"""
global dpls_reduced_default, dpls_reduced_joblib

dpls_reduced_joblib, _ = run_hnn_core(backend='joblib',
n_jobs=2, reduced=True)
dpls_reduced_joblib, _ = run_hnn_core_fixture(backend='joblib',
n_jobs=2, reduced=True)

for trial_idx in range(len(dpls_reduced_default)):
assert_array_equal(dpls_reduced_default[trial_idx].data['agg'],
Expand All @@ -55,24 +55,25 @@ def test_mpi_nprocs(self):
assert backend.n_procs > 1

@requires_mpi4py
def test_run_mpibackend(self, run_hnn_core):
def test_run_mpibackend(self, run_hnn_core_fixture):
"""Test running a MPIBackend on reduced model"""
global dpls_reduced_default, dpls_reduced_mpi
dpls_reduced_mpi, _ = run_hnn_core(backend='mpi', reduced=True)
dpls_reduced_mpi, _ = run_hnn_core_fixture(backend='mpi', reduced=True)
for trial_idx in range(len(dpls_reduced_default)):
# account for rounding error incured during MPI parallelization
assert_allclose(dpls_reduced_default[trial_idx].data['agg'],
dpls_reduced_mpi[trial_idx].data['agg'], rtol=0,
atol=1e-14)

@requires_mpi4py
def test_run_mpibackend_oversubscribed(self, run_hnn_core):
def test_run_mpibackend_oversubscribed(self, run_hnn_core_fixture):
"""Test running MPIBackend with oversubscribed number of procs"""
oversubscribed = round(cpu_count() * 1.5)
run_hnn_core(backend='mpi', n_procs=oversubscribed, reduced=True)
run_hnn_core_fixture(backend='mpi', n_procs=oversubscribed,
reduced=True)

@pytest.mark.parametrize("backend", ['mpi', 'joblib'])
def test_compare_hnn_core(self, run_hnn_core, backend, n_jobs=1):
def test_compare_hnn_core(self, run_hnn_core_fixture, backend, n_jobs=1):
"""Test hnn-core does not break."""
# small snippet of data on data branch for now. To be deleted
# later. Data branch should have only commit so it does not
Expand All @@ -89,7 +90,7 @@ def test_compare_hnn_core(self, run_hnn_core, backend, n_jobs=1):
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)

dpls, net = run_hnn_core(params, backend)
dpls, net = run_hnn_core_fixture(params, backend)
dpl = dpls[0]

# write the dipole to a file and compare
Expand Down Expand Up @@ -122,15 +123,15 @@ def test_compare_hnn_core(self, run_hnn_core, backend, n_jobs=1):
# there are no dependencies if this unit tests fails; no need to be in
# class marked incremental
@requires_mpi4py
def test_mpi_failure(run_hnn_core):
def test_mpi_failure(run_hnn_core_fixture):
"""Test that an MPI failure is handled and messages are printed"""
# this MPI paramter will cause a MPI job to fail
environ["OMPI_MCA_btl"] = "self"

with pytest.warns(UserWarning) as record:
with io.StringIO() as buf, redirect_stdout(buf):
with pytest.raises(RuntimeError, match="MPI simulation failed"):
run_hnn_core(backend='mpi', reduced=True)
run_hnn_core_fixture(backend='mpi', reduced=True)
stdout = buf.getvalue()

assert "MPI processes are unable to reach each other" in stdout
Expand Down

0 comments on commit fa9c666

Please sign in to comment.