Skip to content

Commit

Permalink
renaming, refactoring and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cjayb committed Nov 12, 2020
1 parent a4d2ffc commit 5c8a800
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 81 deletions.
2 changes: 1 addition & 1 deletion hnn_core/dipole.py
Expand Up @@ -42,7 +42,7 @@ def simulate_dipole(net, n_trials=None):

if n_trials is not None:
net.params['N_trials'] = n_trials
net.instantiate_feeds() # need to redo these if already calc'd
net._instantiate_feeds() # need to redo these if n_trials changed!
else:
n_trials = net.params['N_trials']

Expand Down
5 changes: 1 addition & 4 deletions hnn_core/mpi_child.py
Expand Up @@ -100,13 +100,10 @@ def run(self, params):
from hnn_core import Network
from hnn_core.parallel_backends import _clone_and_simulate

prng_seedcore_initial = params['prng_*']

net = Network(params)
sim_data = []
for trial_idx in range(params['N_trials']):
single_sim_data = _clone_and_simulate(net, trial_idx,
prng_seedcore_initial)
single_sim_data = _clone_and_simulate(net, trial_idx)

# go ahead and append trial data for each rank, though
# only rank 0 has data that should be sent back to MPIBackend
Expand Down
52 changes: 29 additions & 23 deletions hnn_core/network.py
Expand Up @@ -171,11 +171,6 @@ def __init__(self, params):
# place to keep this information
self.gid_ranges = dict()

# dict with keys 'common', 'evdist1', ...
# XXX needs better semantics, here to get first pass working
# list (len == n_trials) of dicts (keys: src_type) of event times
self._tmp_feed_event_times = []

# Create empty spikes object
self.spikes = Spikes()

Expand All @@ -187,23 +182,21 @@ def __init__(self, params):
'L5_pyramidal',
]
self.feedname_list = [] # no feeds defined yet
self.trial_event_times = [] # list of len == n_trials

# cell position lists, also will give counts: must be known
# by ALL nodes
# XXX structure of pos_dict determines all downstream inferences of
# contents of pos_dict determines all downstream inferences of
# cell counts, real and artificial
self.pos_dict = _create_cell_coords(n_pyr_x=self.params['N_pyr_x'],
n_pyr_y=self.params['N_pyr_y'],
zdiff=1307.4)
# XXX not strictly necessary here (_add_external_feeds calls it blw.)
# but perhaps should be encouraged on adding new cells?
# Every time pos_dict is updated, gid_ranges must be updated too
self._update_gid_ranges()

# set n_cells, EXCLUDING Artificial ones
self.n_cells = sum(len(self.pos_dict[src]) for src in
self.cellname_list)

# XXX The legacy code in HNN-GUI _always_ defines 2 'common' and 5
# The legacy code in HNN-GUI _always_ defines 2 'common' and 5
# 'unique' feeds. They are added here for backwards compatibility
# until a new handling of external NetworkDrives's is completed.
self._add_external_feeds()
Expand All @@ -226,37 +219,47 @@ def _add_external_feeds(self):
# Global number of external inputs ... automatic counting
# makes more sense
# p_unique represent ext inputs that are going to go to each cell
self.p_common, self.p_unique = create_pext(self.params,
self.params['tstop'])
self.n_common_feeds = len(self.p_common)
self._p_common, self._p_unique = create_pext(self.params,
self.params['tstop'])
self._n_common_feeds = len(self._p_common)

# 'position' the artificial cells arbitrarily in the origin of the
# network grid. Non-biophysical cell placement is irrelevant
# However, they must be present to be included in gid_ranges!
origin = self.pos_dict['origin']

# COMMON FEEDS
self.pos_dict['common'] = [origin for i in range(self.n_common_feeds)]
self.pos_dict['common'] = [origin for i in range(self._n_common_feeds)]

# UNIQUE FEEDS
for key in self.p_unique.keys():
for key in self._p_unique.keys():
# create the pos_dict for all the sources
self.pos_dict[key] = [origin for i in range(self.n_cells)]

# Now add the names of the feeds to a list
self.feedname_list.append('common')
# grab the keys for the unique set of inputs and sort the names
# append them to the src list along with the number of cells
unique_keys = sorted(self.p_unique.keys())
unique_keys = sorted(self._p_unique.keys())
self.feedname_list += unique_keys

self._update_gid_ranges() # add feeds -> update the gid_ranges
self.instantiate_feeds()
# Every time pos_dict is updated, gid_ranges must be updated too
self._update_gid_ranges()

# Create the feed dynamics (event_times)
self._instantiate_feeds()

def instantiate_feeds(self):
def _instantiate_feeds(self):
'''Creates event_time vectors for all feeds and all trials
NB this must be a separate method because dipole.py:simulate_dipole
accepts an n_trials-argument, which overrides the N_trials-parameter
used at intialisation time. The good news is that only the event_times
need to be recalculated, all the GIDs etc remain the same.
'''
# each trial needs unique event time vectors

self._tmp_feed_event_times = [] # reset if called again
self.trial_event_times = [] # reset if called again from dipole.py
n_trials = self.params['N_trials']

cur_params = self.params.copy() # these get mangled below!
Expand All @@ -267,6 +270,7 @@ def instantiate_feeds(self):
cur_params[param_key] =\
prng_seedcore_initial[param_key] + trial_idx
# needs to be re-run to create the dicts going into ExtFeed
# the only thing changing is the initial seed
p_common, p_unique = create_pext(cur_params,
cur_params['tstop'])

Expand All @@ -293,10 +297,12 @@ def instantiate_feeds(self):
gid=gid)
event_times.append(feed.event_times)
event_times_per_source.update({src_type: event_times})
self._tmp_feed_event_times.append(event_times_per_source.copy())

# list of dict of list of list
self.trial_event_times.append(event_times_per_source.copy())

def _update_gid_ranges(self):
"""Creates gid dict from scratch every time called.
"""Creates gid ranges from scratch every time called.
Any method that adds real or artificial cells to the network must
call this to update the list of GIDs. Note that it's based on the
Expand Down
52 changes: 18 additions & 34 deletions hnn_core/network_builder.py
Expand Up @@ -266,11 +266,6 @@ def __init__(self, net, trial_idx=0):
self.net = net
self.trial_idx = trial_idx

# Create external feed param dictionaries to reflect possible
# param['prng_*'] value updates (different trial, different seed)
self._p_common, self._p_unique = create_pext(net.params,
net.params['tstop'])

# When computing the network dynamics in parallel, the nodes of the
# network (real and artificial cells) potentially get distributed
# on different host machines/threads. NetworkBuilder._gid_assign
Expand Down Expand Up @@ -325,7 +320,6 @@ def _build(self):

self._gid_assign()
self._create_cells_and_feeds(threshold=self.net.params['threshold'])
self._gids_to_parallel_context(threshold=self.net.params['threshold'])

self.state_init()
self._parnet_connect()
Expand Down Expand Up @@ -370,15 +364,15 @@ def _gid_assign(self):
# now to do the cell-specific external input gids on the same proc
# these are guaranteed to exist because all of
# these inputs were created for each cell
# XXX get list of all NetworkDrives that contact this cell, and
# get list of all NetworkDrives that contact this cell, and
# make sure the corresponding _ArtificialCell gids are associated
# with the current node/rank
for key in self._p_unique.keys():
for key in self.net._p_unique.keys():
gid_input = gid + self.net.gid_ranges[key][0]
_PC.set_gid2node(gid_input, rank)
self._gid_list.append(gid_input)

for gid_base in range(rank, self.net.n_common_feeds, nhosts):
for gid_base in range(rank, self.net._n_common_feeds, nhosts):
# shift the gid_base to the common gid
gid = gid_base + self.net.gid_ranges['common'][0]
# set as usual
Expand All @@ -391,7 +385,7 @@ def _create_cells_and_feeds(self, threshold):
"""Parallel create cells AND external inputs (feeds)
NB: _Cell.__init__ calls h.Section -> non-picklable!
NB: _ArtificialCell.__init__ calls h.XXX -> non-picklable!
NB: _ArtificialCell.__init__ calls h.*** -> non-picklable!
These feeds are spike SOURCES but cells are also targets.
External inputs are not targets.
Expand All @@ -403,9 +397,8 @@ def _create_cells_and_feeds(self, threshold):
# have to loop over self._gid_list, since this is what we got
# on this rank (MPI)

# XXX need mechanism for Builder to keep track of which trial it's on
this_trial_event_times =\
self.net._tmp_feed_event_times[self.trial_idx]
# mechanism for Builder to keep track of which trial it's on
this_trial_event_times = self.net.trial_event_times[self.trial_idx]

for gid in self._gid_list:
src_type, src_pos, is_cell = self.net._get_src_type_and_pos(gid)
Expand All @@ -415,40 +408,30 @@ def _create_cells_and_feeds(self, threshold):
# create cells based on loc property
if src_type in ('L2_pyramidal', 'L5_pyramidal'):
PyramidalCell = type2class[src_type]
# XXX Why doesn't a _Cell have a .threshold? Would make a
# lot of sense to include it, as _ArtificialCells do.
cell = PyramidalCell(src_pos, override_params=None,
gid=gid)
else:
BasketCell = type2class[src_type]
cell = BasketCell(src_pos, gid=gid)

# this calls seems to belong in init of a _Cell (w/threshold)?
nrn_netcon = cell.setup_source_netcon(threshold)
_PC.cell(cell.gid, nrn_netcon)
self.cells.append(cell)

# external inputs are special types of artificial-cells
# 'common': all cells impacted with identical TIMING of spike
# events. NB: cell types can still have different weights for
# how such 'common' spikes influence them
else:
# XXX figure out how to index the Nth spike in each list
gid_idx = gid - self.net.gid_ranges[src_type][0]
et = this_trial_event_times[src_type][gid_idx]
feed_cell = _ArtificialCell(et, threshold, gid=gid)
_PC.cell(feed_cell.gid, feed_cell.nrn_netcon)
self._feed_cells.append(feed_cell)

def _gids_to_parallel_context(self, threshold):

for cell in self.cells:
src_type = self.net._get_src_type_and_pos(cell.gid)[0]
# check existence of gid with Neuron
if not _PC.gid_exists(cell.gid):
msg = ('Source of type %s with ID %d does not exists in '
'Network' % (src_type, cell.gid))
raise RuntimeError(msg)

nrn_netcon = cell.setup_source_netcon(threshold)
_PC.cell(cell.gid, nrn_netcon)
# Then loop over _ArtificialCell's
for feed_cell in self._feed_cells:
_PC.cell(feed_cell.gid, feed_cell.nrn_netcon)

def _connect_celltypes(self, src_type, target_type, loc,
receptor, nc_dict, unique=False,
allow_autapses=True):
Expand Down Expand Up @@ -480,8 +463,9 @@ def _connect_celltypes(self, src_type, target_type, loc,
connection_name = f'{src_type}_{target_type}_{receptor}'
if connection_name not in self.ncs:
self.ncs[connection_name] = list()
# XXX: self.cells and _gid_list are not same length
# ideally self.cells should be a dict of list

assert len(self.cells) == len(self._gid_list) - len(self._feed_cells)
# NB this assumes that REAL cells are first in the _gid_list
for gid_target, target_cell in zip(self._gid_list, self.cells):
is_target_gid = (gid_target in
self.net.gid_ranges[_long_name(target_type)])
Expand Down Expand Up @@ -592,7 +576,7 @@ def _parnet_connect(self):
nc_dict)

# common feed -> xx
for p_common in self._p_common:
for p_common in self.net._p_common:
for target_cell_type in ['L2Basket', 'L5Basket', 'L5Pyr', 'L2Pyr']:
if (target_cell_type == 'L5Basket' and
p_common['loc'] == 'distal'):
Expand All @@ -609,7 +593,7 @@ def _parnet_connect(self):
nc_dict)

# unique feed -> xx
p_unique = self._p_unique
p_unique = self.net._p_unique
for src_cell_type in p_unique:

p_src = p_unique[src_cell_type]
Expand Down
11 changes: 2 additions & 9 deletions hnn_core/parallel_backends.py
Expand Up @@ -18,7 +18,7 @@
_BACKEND = None


def _clone_and_simulate(net, trial_idx, prng_seedcore_initial):
def _clone_and_simulate(net, trial_idx):
"""Run a simulation including building the network
This is used by both backends. MPIBackend calls this in mpi_child.py, once
Expand All @@ -30,11 +30,6 @@ def _clone_and_simulate(net, trial_idx, prng_seedcore_initial):
from hnn_core.network_builder import NetworkBuilder
from hnn_core.network_builder import _simulate_single_trial

# XXX this should be built into NetworkBuilder
# update prng_seedcore params to provide jitter between trials
for param_key in prng_seedcore_initial.keys():
net.params[param_key] = prng_seedcore_initial[param_key] + trial_idx

neuron_net = NetworkBuilder(net, trial_idx=trial_idx)
dpl = _simulate_single_trial(neuron_net, trial_idx)

Expand Down Expand Up @@ -138,10 +133,8 @@ def simulate(self, net):
n_trials = net.params['N_trials']
dpls = []

prng_seedcore_initial = net.params['prng_*'].copy()
parallel, myfunc = self._parallel_func(_clone_and_simulate)
sim_data = parallel(myfunc(net, idx, prng_seedcore_initial)
for idx in range(n_trials))
sim_data = parallel(myfunc(net, idx) for idx in range(n_trials))

dpls = _gather_trial_data(sim_data, net, n_trials)
return dpls
Expand Down
4 changes: 2 additions & 2 deletions hnn_core/tests/test_feed.py
Expand Up @@ -22,15 +22,15 @@ def test_extfeed():
pytest.raises(ValueError, ExtFeed,
'ev', None, p_bogus, 0) # ambiguous

# XXX 'unique' external feeds are always created; why?
# 'unique' external feeds are always created
for feed_type in ['extpois', 'extgauss']:
feed = ExtFeed(feed_type=feed_type,
target_cell_type='L2_basket',
params=p_unique[feed_type],
gid=0)
print(feed) # test repr

# XXX but 'common' (rhythmic) feeds are not
# but 'common' (rhythmic) feeds are not
for ii in range(len(p_common)): # len == 0 for def. params
feed = ExtFeed(feed_type='common',
target_cell_type=None,
Expand Down
14 changes: 7 additions & 7 deletions hnn_core/tests/test_network.py
Expand Up @@ -53,13 +53,13 @@ def test_network():

# test that expected number of external driving events are created, and
# make sure the PRNGs are consistent.
assert isinstance(net._tmp_feed_event_times, list)
assert len(net._tmp_feed_event_times) == 1 # single trial simulated
assert len(net._tmp_feed_event_times[0]['common']) == n_common_sources
assert len(net._tmp_feed_event_times[0]['common'][0]) == 40 # 40 spikes
assert isinstance(net._tmp_feed_event_times[0]['evprox1'][0], list)
assert len(net._tmp_feed_event_times[0]['evprox1']) == net.n_cells
assert_allclose(net._tmp_feed_event_times[0]['evprox1'][0],
assert isinstance(net.trial_event_times, list)
assert len(net.trial_event_times) == 1 # single trial simulated
assert len(net.trial_event_times[0]['common']) == n_common_sources
assert len(net.trial_event_times[0]['common'][0]) == 40 # 40 spikes
assert isinstance(net.trial_event_times[0]['evprox1'][0], list)
assert len(net.trial_event_times[0]['evprox1']) == net.n_cells
assert_allclose(net.trial_event_times[0]['evprox1'][0],
[23.80641637082997], rtol=1e-12)

assert len(network_builder._feed_cells) == (n_evoked_sources +
Expand Down
2 changes: 1 addition & 1 deletion hnn_core/tests/test_parallel_backends.py
Expand Up @@ -32,7 +32,7 @@ def run_hnn_core(backend=None, n_procs=None, n_jobs=1, reduced=False):
net = Network(params)

# two trials simulated
assert len(net_reduced._tmp_feed_event_times) == params_reduced['N_trials']
assert len(net_reduced.trial_event_times) == params_reduced['N_trials']

if backend == 'mpi':
with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'):
Expand Down

0 comments on commit 5c8a800

Please sign in to comment.