Skip to content

Commit

Permalink
MAINT: Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jasmainak authored and rythorpe committed Aug 5, 2021
1 parent 071f846 commit 6430b2a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
38 changes: 21 additions & 17 deletions hnn_core/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,20 +294,27 @@ def __init__(self, net, trial_idx=0):
if len(self.net.rec_arrays) > 0:
self._expose_imem = True

self._rank = 0

self._build()

def _build(self, test_rank=None):
"""Building the network in NEURON."""
def _build(self, rank=None):
"""Building the network in NEURON.
Parameters
----------
rank : int | None
If not None, override the rank set
automatically using Neuron. Used for testing.
"""

global _CVODE, _PC
_create_parallel_context(expose_imem=self._expose_imem)

# used to set the rank while testing gid assignment in
# test_parallel_backend
if test_rank is None:
if rank is None:
self._rank = _get_rank()
else:
self._rank = test_rank
self._rank = rank

# load mechanisms needs ParallelContext for get_rank
load_custom_mechanisms()
Expand Down Expand Up @@ -349,17 +356,13 @@ def _build(self, test_rank=None):
if self._rank == 0:
print('[Done]')

# this happens on EACH node
# creates self._gid_list for THIS node
def _gid_assign(self):

"""Assign cell IDs to this node"""
self.net._update_cells() # updates net.n_cells

rank = self._rank
nhosts = _get_nhosts()

n_hosts = _get_nhosts()
# round robin assignment of gids
for gid in range(rank, self.net.n_cells, nhosts):
for gid in range(self._rank, self.net.n_cells, n_hosts):
# set the cell gid
self._gid_list.append(gid)

Expand All @@ -376,12 +379,9 @@ def _gid_assign(self):
self._gid_list.append(src_gid)
else:
src_gids = list(self.net.gid_ranges[drive['name']])
for gid_idx in range(rank, len(src_gids), nhosts):
for gid_idx in range(self._rank, len(src_gids), n_hosts):
self._gid_list.append(src_gids[gid_idx])

for gid in self._gid_list:
_PC.set_gid2node(gid, rank)

# extremely important to get the gids in the right order
self._gid_list.sort()

Expand All @@ -395,6 +395,10 @@ def _create_cells_and_drives(self, threshold, record_vsoma=False,
These drives are spike SOURCES but cells are also targets.
External inputs are not targets.
"""

for gid in self._gid_list:
_PC.set_gid2node(gid, self._rank)

# loop through ALL gids
# have to loop over self._gid_list, since this is what we got
# on this rank (MPI)
Expand Down
3 changes: 1 addition & 2 deletions hnn_core/tests/test_parallel_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_gid_assignment():
params = read_params(params_fname)
params.update({'N_pyr_x': 3,
'N_pyr_y': 3,
'tstop': 40,
't_evprox_1': 5,
't_evdist_1': 10,
't_evprox_2': 20,
Expand All @@ -52,7 +51,7 @@ def test_gid_assignment():
n_drive_cells_instantiated = dict()
for rank in range(n_ranks):
net_builder = NetworkBuilder(net)
net_builder._build(test_rank=rank)
net_builder._build(rank=rank)
for drive_cell in net_builder._drive_cells:
drive_name = net.gid_to_type(drive_cell.gid)
if drive_name in n_drive_cells_instantiated:
Expand Down

0 comments on commit 6430b2a

Please sign in to comment.