Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
rythorpe committed Jul 30, 2021
1 parent 5723ccf commit 02b6e60
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
16 changes: 12 additions & 4 deletions hnn_core/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,23 @@ def __init__(self, net, trial_idx=0):

self._build()

def _build(self):
def _build(self, test_rank=None):
"""Building the network in NEURON."""

global _CVODE, _PC
_create_parallel_context(expose_imem=self._expose_imem)

# used to set the rank while testing gid assignment in
# test_parallel_backend
if test_rank is None:
self._rank = _get_rank()
else:
self._rank = test_rank

# load mechanisms needs ParallelContext for get_rank
load_custom_mechanisms()

if _get_rank() == 0:
if self._rank == 0:
print('Building the NEURON model')

self._clear_last_network_objects()
Expand Down Expand Up @@ -340,7 +347,7 @@ def _build(self):
if len(self.net.rec_arrays) > 0:
self._record_extracellular()

if _get_rank() == 0:
if self._rank == 0:
print('[Done]')

def __enter__(self):
Expand All @@ -360,7 +367,7 @@ def _gid_assign(self):

self.net._update_cells() # updates net.n_cells

rank = _get_rank()
rank = self._rank
nhosts = _get_nhosts()

# round robin assignment of gids
Expand Down Expand Up @@ -575,6 +582,7 @@ def _clear_neuron_objects(self):

self._gid_list = list()
self._cells = list()
self._drive_cells = list()

def get_data_from_neuron(self):
"""Get copies of spike data that are pickleable"""
Expand Down
30 changes: 30 additions & 0 deletions hnn_core/tests/test_parallel_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from hnn_core import MPIBackend, jones_2009_model, read_params
from hnn_core.dipole import simulate_dipole
from hnn_core.parallel_backends import requires_mpi4py, requires_psutil
from hnn_core.network_builder import NetworkBuilder


def _terminate_mpibackend(event, backend):
Expand All @@ -32,6 +33,35 @@ def _terminate_mpibackend(event, backend):
sleep(0.01)


def test_gid_assignment_across_ranks():
"""Test that gids are assigned without overlap across ranks"""
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
params.update({'N_pyr_x': 3,
'N_pyr_y': 3,
'tstop': 40,
't_evprox_1': 5,
't_evdist_1': 10,
't_evprox_2': 20,
'N_trials': 2})
net = jones_2009_model(params, add_drives_from_params=True)
n_drive_cells = {name: drive['n_drive_cells'] for name, drive in
net.external_drives.items()}
n_ranks = 3
n_drive_cells_instantiated = dict()
for rank in range(n_ranks):
net_builder = NetworkBuilder(net)
net_builder._build(test_rank=rank)
for drive_cell in net_builder._drive_cells:
drive_name = net.gid_to_type(drive_cell.gid)
if drive_name in n_drive_cells_instantiated:
n_drive_cells_instantiated[drive_name] += 1
else:
n_drive_cells_instantiated[drive_name] = 1
assert n_drive_cells == n_drive_cells_instantiated


# The purpose of this incremental mark is to avoid running the full length
# simulation when there are failures in previous (faster) tests. When a test
# in the sequence fails, all subsequent tests will be marked "xfailed" rather
Expand Down

0 comments on commit 02b6e60

Please sign in to comment.