Skip to content

Commit

Permalink
ENH: enable NEURON parallelism
Browse files Browse the repository at this point in the history
Use ParallelContext to split computation for each simulation among
processors on a system. This mode is not compatible with joblibs
embarrasingly parallel execution and requires MPI.
  • Loading branch information
Blake Caldwell committed Sep 23, 2019
1 parent 0aa973e commit 9e16101
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 51 deletions.
34 changes: 34 additions & 0 deletions examples/parallel_simulate_evoked.py
@@ -0,0 +1,34 @@
"""
===============
Simulate dipole
===============
This example demonstrates how to simulate a dipole for evoked-like
waveforms using HNN-core.
"""

# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr>
# Sam Neymotin <samnemo@gmail.com>

import os.path as op

###############################################################################
# Let us import hnn_core

import hnn_core
from hnn_core import simulate_dipole, Params, Network, shutdown

hnn_core_root = op.join(op.dirname(hnn_core.__file__), '..')

###############################################################################
# Then we read the parameters file
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = Params(params_fname)

###############################################################################
# Now let's simulate the dipole
# You can simulate multiple trials in parallel by using n_jobs > 1
net = Network(params)
dpls = simulate_dipole(net, n_jobs=1, n_trials=4)

shutdown()
1 change: 1 addition & 0 deletions hnn_core/__init__.py
Expand Up @@ -4,3 +4,4 @@
from .network import Network
from .pyramidal import L2Pyr, L5Pyr
from .basket import L2Basket, L5Basket
from .parallel import shutdown
72 changes: 46 additions & 26 deletions hnn_core/dipole.py
Expand Up @@ -22,16 +22,18 @@ def _clone_and_simulate(params, trial_idx):
if trial_idx != 0:
params['prng_*'] = trial_idx

net = Network(params, n_jobs=1)
net = Network(params)
net.build_in_neuron()

return _simulate_single_trial(net)

from neuron import h
net._create_all_src()
net.state_init()
net._parnet_connect()
# set to record spikes
net.spiketimes = h.Vector()
net.spikegids = h.Vector()
net._record_spikes()

def _simulate_with_parallel_context(net, trial_idx):

if trial_idx != 0:
net.params['prng_*'] = trial_idx

net.build_in_neuron()

return _simulate_single_trial(net)

Expand All @@ -43,10 +45,12 @@ def _simulate_single_trial(net):
h.load_file("stdrun.hoc")

# Now let's simulate the dipole

pc.barrier() # sync for output to screen
if rank == 0:
print("running on %d cores" % nhosts)

# global variables, should be node-independent
# create or reinitialize scalars in NEURON (hoc) context
h("dp_total_L2 = 0.")
h("dp_total_L5 = 0.")

Expand All @@ -55,15 +59,11 @@ def _simulate_single_trial(net):
h.dt = net.params['dt'] # simulation duration and time-step
h.celsius = net.params['celsius'] # 37.0 - set temperature

# We define the arrays (Vector in numpy) for recording the signals
t_vec = h.Vector()
t_vec.record(h._ref_t) # time recording
dp_rec_L2 = h.Vector()
dp_rec_L2.record(h._ref_dp_total_L2) # L2 dipole recording
dp_rec_L5 = h.Vector()
dp_rec_L5.record(h._ref_dp_total_L5) # L5 dipole recording

net.move_cells_to_pos() # position cells in 2D grid
# Connect NEURON scalar references to python vectors
# TODO: initialize with Vector(size) to avoid dynamic resizing
t_vec = h.Vector().record(h._ref_t) # time recording
dp_rec_L2 = h.Vector().record(h._ref_dp_total_L2) # L2 dipole recording
dp_rec_L5 = h.Vector().record(h._ref_dp_total_L5) # L5 dipole recording

# sets the default max solver step in ms (purposefully large)
pc.set_maxstep(10)
Expand All @@ -81,8 +81,10 @@ def prsimtime():
cvode.event(tt, prsimtime) # print time callbacks

h.fcurrent()
# set state variables if they have been changed since h.finitialize
h.frecord_init()

# initialization complete, but wait for all procs to start the solver
pc.barrier()

# actual simulation - run the solver
pc.psolve(h.tstop)

Expand All @@ -105,8 +107,6 @@ def prsimtime():
np.array(dp_rec_L2.to_python()),
np.array(dp_rec_L5.to_python())]

pc.gid_clear()
pc.done()
dpl = Dipole(np.array(t_vec.to_python()), dpl_data)
if rank == 0:
if net.params['save_dpl']:
Expand Down Expand Up @@ -137,9 +137,29 @@ def simulate_dipole(net, n_trials=1, n_jobs=1):
dpl: list | instance of Dipole
The dipole object or list of dipole objects if n_trials > 1
"""
parallel, myfunc = _parallel_func(_clone_and_simulate, n_jobs=n_jobs)
out = parallel(myfunc(net.params, idx) for idx in range(n_trials))
dpl, spiketimes, spikegids = zip(*out)

from .parallel import create_parallel_context, get_nhosts

if n_jobs > 1:
# check whether NEURON is using parallel nrniv processes
create_parallel_context(n_cores=1)
if get_nhosts() > 1:
print("Nested parallelism is not currently supported!\n" +
"Please choose embarassinly parallel jobs (n_jobs > 1)\n" +
"or multiple cores per simulation (with MPI)\n")
return None

parallel, myfunc = _parallel_func(_clone_and_simulate, n_jobs=n_jobs)
out = parallel(myfunc(net.params, idx) for idx in range(n_trials))
dpl, spiketimes, spikegids = zip(*out)
else:
create_parallel_context()

out = []
for idx in range(n_trials):
out.append(_simulate_with_parallel_context(net, idx))
dpl, spiketimes, spikegids = zip(*out)

net.spiketimes = spiketimes
net.spikegids = spikegids
return dpl
Expand Down
83 changes: 69 additions & 14 deletions hnn_core/network.py
Expand Up @@ -2,12 +2,11 @@

# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr>
# Sam Neymotin <samnemo@gmail.com>
# Blake Caldwell <blake_caldwell@brown.edu>

import itertools as it
import numpy as np

from neuron import h

from .feed import ExtFeed
from .pyramidal import L2Pyr, L5Pyr
from .basket import L2Basket, L5Basket
Expand All @@ -21,8 +20,6 @@ class Network(object):
----------
params : dict
The parameters
n_jobs : int
The number of jobs to run in parallel
Attributes
----------
Expand All @@ -44,10 +41,7 @@ class Network(object):
The list contains the cell IDs of neurons that spiked.
"""

def __init__(self, params, n_jobs=1):
from .parallel import create_parallel_context
# setup simulation (ParallelContext)
create_parallel_context(n_jobs=n_jobs)
def __init__(self, params):

# set the params internally for this net
# better than passing it around like ...
Expand All @@ -59,11 +53,7 @@ def __init__(self, params, n_jobs=1):

self.N_t = np.arange(0., self.params['tstop'],
self.params['dt']).size + 1
# Create a h.Vector() with size 1xself.N_t, zero'd
self.current = {
'L5Pyr_soma': h.Vector(self.N_t, 0),
'L2Pyr_soma': h.Vector(self.N_t, 0),
}

# int variables for grid of pyramidal cells (for now in both L2 and L5)
self.gridpyr = {
'x': self.params['N_pyr_x'],
Expand Down Expand Up @@ -106,7 +96,6 @@ def __init__(self, params, n_jobs=1):
# assign gid to hosts, creates list of gids for this node in _gid_list
# _gid_list length is number of cells assigned to this id()
self._gid_list = []
self._gid_assign()
# create cells (and create self.origin in create_cells_pyr())
self.cells = []
self.extinput_list = []
Expand Down Expand Up @@ -388,6 +377,8 @@ def aggregate_currents(self):

def state_init(self):
"""Initializes the state closer to baseline."""
from neuron import h

for cell in self.cells:
seclist = h.SectionList()
seclist.wholetree(sec=cell.soma)
Expand Down Expand Up @@ -488,3 +479,67 @@ def plot_spikes(self, ax=None, show=True):
if show:
plt.show()
return ax.get_figure()

def build_in_neuron(self):
"""
This function must be called before Network can be used for sims
"""

from neuron import h
from .parallel import create_parallel_context, set_current_net

# make sure ParallelContext has been created (needed for joblibs)
create_parallel_context()

set_current_net(self)

self._gid_assign()
# Create a h.Vector() with size 1xself.N_t, zero'd
self.current = {
'L5Pyr_soma': h.Vector(self.N_t, 0),
'L2Pyr_soma': h.Vector(self.N_t, 0),
}
self._create_all_src()
self.state_init()
self._parnet_connect()
# set to record spikes
self.spiketimes = h.Vector()
self.spikegids = h.Vector()
self._record_spikes()
# position cells in 2D grid
self.move_cells_to_pos()

def clear_neuron_objects(self):
"""
Clear up NEURON internal gid information.
Note: This function must be called from the context of the
Network instance that ran build_in_neuron. This is a bug or
peculiarity of NEURON. If this function is called from a different
context, then the next simulation will run very slow because nrniv
workers are still going for the old simulation. If pc.gid_clear is
called from the right context, then those workers can exit.
"""
from .parallel import pc

pc.gid_clear()

# dereference cell and NetConn objects
for gid, cell in zip(self._gid_list, self.cells):
# only work on cells on this node
if pc.gid_exists(gid):
for name_src in ['L2Pyr', 'L2Basket', 'L5Pyr', 'L5Basket',
'extinput', 'extgauss', 'extpois', 'ev']:
for nc in getattr(cell, 'ncfrom_%s' % name_src):
if nc.valid():
# delete NEURON cell object
cell_obj1 = nc.precell(gid)
if cell_obj1 is not None:
del cell_obj1
cell_obj2 = nc.postcell(gid)
if cell_obj2 is not None:
del cell_obj2
del nc

self._gid_list = []
self.cells = []
62 changes: 51 additions & 11 deletions hnn_core/parallel.py
Expand Up @@ -7,20 +7,60 @@

from neuron import h

rank = 0
nhosts = 1
pc = h.ParallelContext(nhosts)
pc.done()
rank = int(pc.id())
cvode = h.CVode()
pc = None
last_network = None


def create_parallel_context(n_jobs=1):
"""Create parallel context."""
rank = int(pc.id()) # rank or node number (0 will be the master)
def shutdown():
pc.done()
h.quit()

if rank == 0:
pc.gid_clear()

def get_nhosts():
return nhosts


def create_parallel_context(n_cores=None):
"""Create parallel context.
Parameters
----------
n_cores: int | None
Number of processors to use for a simulation. A value of None will
allow NEURON to use all available processors.
"""

global rank, nhosts, cvode, pc, last_network

if pc is None:
if n_cores is None:
# MPI: Initialize the ParallelContext class
pc = h.ParallelContext()
else:
pc = h.ParallelContext(n_cores)

nhosts = int(pc.nhost()) # Find number of hosts
rank = int(pc.id()) # rank or node number (0 will be the master)
cvode = h.CVode()

# be explicit about using fixed step integration
cvode.active(0)

# use cache_efficient mode for allocating elements in contiguous order
# cvode.cache_efficient(1)
else:
# ParallelContext() has already been called. Don't start more workers.
# Just tell old nrniv workers to quit.
pc.done()


def set_current_net(net):
global last_network

if last_network is not None:
last_network.clear_neuron_objects()

net.clear_neuron_objects()
last_network = net


def _parallel_func(func, n_jobs):
Expand Down

0 comments on commit 9e16101

Please sign in to comment.