Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

postprocessing #188

Merged
merged 19 commits into from Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/plot_simulate_evoked.py
Expand Up @@ -46,7 +46,7 @@
from hnn_core import JoblibBackend

with JoblibBackend(n_jobs=1):
dpls = simulate_dipole(net, n_trials=2)
dpls = simulate_dipole(net, n_trials=2, postproc=True)

###############################################################################
# and then plot it
Expand Down
35 changes: 28 additions & 7 deletions hnn_core/dipole.py
Expand Up @@ -18,7 +18,7 @@ def _hammfilt(x, winsz):


def simulate_dipole(net, n_trials=None, record_vsoma=False,
record_isoma=False):
record_isoma=False, postproc=True):
"""Simulate a dipole given the experiment parameters.

Parameters
Expand All @@ -33,6 +33,8 @@ def simulate_dipole(net, n_trials=None, record_vsoma=False,
Option to record somatic voltages from cells
record_isoma : bool
Option to record somatic currents from cells
postproc : bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the other conflicting spot. I think post-processing should be listed first so the docstring would become

postproc : bool
     If False, no postprocessing applied to the dipole
record_vsoma : bool
     Option to record somatic voltages from cells
record_isoma : bool
     Option to record somatic currents from cells

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with postproc being listed last as long as it matches the function signature. You have less chance of breaking existing user code when you add an argument to the end of the function. E.g., if someone did:

record_vsoma = True
simulate_dipole(net, n_trials, record_vsoma, record_isoma)

now if you add postproc in the middle, it will get passed as record_vsoma when the user intended otherwise

If False, no postprocessing applied to the dipole

Returns
-------
Expand Down Expand Up @@ -67,7 +69,7 @@ def simulate_dipole(net, n_trials=None, record_vsoma=False,
raise TypeError("record_isoma must be bool, got %s"
% type(record_isoma).__name__)

dpls = _BACKEND.simulate(net)
dpls = _BACKEND.simulate(net, postproc)

return dpls

Expand Down Expand Up @@ -165,6 +167,25 @@ def __init__(self, times, data, nave=1): # noqa: D102
self.data = {'agg': data[:, 0], 'L2': data[:, 1], 'L5': data[:, 2]}
self.nave = nave

def post_proc(self, N_pyr_x, N_pyr_y, winsz, fctr):
jasmainak marked this conversation as resolved.
Show resolved Hide resolved
""" Apply baseline, unit conversion, scaling and smoothing

Parameters
----------
N_pyr_x : int
Number of Pyramidal cells in x direction
N_pyr_y : int
Number of Pyramidal cells in y direction
winsz : int
Smoothing window
fctr : int
Scaling factor
"""
self.baseline_renormalize(N_pyr_x, N_pyr_y)
self.convert_fAm_to_nAm()
self.scale(fctr)
self.smooth(winsz)

def convert_fAm_to_nAm(self):
""" must be run after baseline_renormalization()
"""
Expand Down Expand Up @@ -205,21 +226,21 @@ def plot(self, ax=None, layer='agg', show=True):
"""
return plot_dipole(dpl=self, ax=ax, layer=layer, show=show)

def baseline_renormalize(self, params):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lovely!

def baseline_renormalize(self, N_pyr_x, N_pyr_y):
"""Only baseline renormalize if the units are fAm.

Parameters
----------
params : dict
The parameters
N_pyr_x : int
Nr of cells (x)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Nr of cells (x)
Number of Pyramidal cells in x direction

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add this change in next commit?

N_pyr_y : int
Nr of cells (y)
"""
if self.units != 'fAm':
print("Warning, no dipole renormalization done because units"
" were in %s" % (self.units))
return

N_pyr_x = params['N_pyr_x']
N_pyr_y = params['N_pyr_y']
# N_pyr cells in grid. This is PER LAYER
N_pyr = N_pyr_x * N_pyr_y
# dipole offset calculation: increasing number of pyr
Expand Down
8 changes: 0 additions & 8 deletions hnn_core/network_builder.py
Expand Up @@ -105,14 +105,6 @@ def simulation_time():
np.array(neuron_net.dipoles['L5_pyramidal'].to_python())]

dpl = Dipole(times, dpl_data)
if rank == 0:
if neuron_net.net.params['save_dpl']:
dpl.write('rawdpl.txt')

dpl.baseline_renormalize(neuron_net.net.params)
dpl.convert_fAm_to_nAm()
dpl.scale(neuron_net.net.params['dipole_scalefctr'])
dpl.smooth(neuron_net.net.params['dipole_smooth_win'] / h.dt)

return dpl

Expand Down
42 changes: 36 additions & 6 deletions hnn_core/parallel_backends.py
Expand Up @@ -15,6 +15,7 @@
import binascii
from time import sleep


_BACKEND = None


Expand All @@ -38,7 +39,7 @@ def _clone_and_simulate(net, trial_idx):
return dpl, spikedata


def _gather_trial_data(sim_data, net, n_trials):
def _gather_trial_data(sim_data, net, n_trials, postproc):
"""Arrange data by trial

To be called after simulate(). Returns list of Dipoles, one for each trial,
Expand All @@ -56,6 +57,13 @@ def _gather_trial_data(sim_data, net, n_trials):
net.cell_response._vsoma.append(spikedata[3])
net.cell_response._isoma.append(spikedata[4])

if postproc:
N_pyr_x = net.params['N_pyr_x']
N_pyr_y = net.params['N_pyr_y']
winsz = net.params['dipole_smooth_win'] / net.params['dt']
fctr = net.params['dipole_scalefctr']
dpls[-1].post_proc(N_pyr_x, N_pyr_y, winsz, fctr)

return dpls


Expand All @@ -70,6 +78,23 @@ def _read_all_bytes(fd, chunk_size=4096):
return all_data


def requires_mpi4py(function):
"""Decorator for testing functions that require MPI."""
import pytest

try:
import mpi4py
assert hasattr(mpi4py, '__version__')
skip = False
except (ImportError, ModuleNotFoundError) as err:
if "TRAVIS_OS_NAME" not in os.environ:
skip = True
else:
raise ImportError(err)
reason = 'mpi4py not available'
return pytest.mark.skipif(skip, reason=reason)(function)


class JoblibBackend(object):
"""The JoblibBackend class.

Expand Down Expand Up @@ -117,14 +142,16 @@ def __exit__(self, type, value, traceback):

_BACKEND = self._old_backend

def simulate(self, net):
def simulate(self, net, postproc=True):
"""Simulate the HNN model

Parameters
----------
net : Network object
The Network object specifying how cells are
connected.
postproc : bool
If False, no postprocessing applied to the dipole

Returns
-------
Expand All @@ -138,7 +165,8 @@ def simulate(self, net):
parallel, myfunc = self._parallel_func(_clone_and_simulate)
sim_data = parallel(myfunc(net, idx) for idx in range(n_trials))

dpls = _gather_trial_data(sim_data, net, n_trials)
dpls = _gather_trial_data(sim_data, net, n_trials, postproc)

return dpls


Expand Down Expand Up @@ -323,14 +351,16 @@ def _process_child_data(self, data_bytes, data_len):
# unpickle the data
return pickle.loads(data_pickled)

def simulate(self, net):
def simulate(self, net, postproc=True):
"""Simulate the HNN model in parallel on all cores

Parameters
----------
net : Network object
The Network object specifying how cells are
connected.
postproc: bool
If False, no postprocessing applied to the dipole

Returns
-------
Expand All @@ -340,7 +370,7 @@ def simulate(self, net):

# just use the joblib backend for a single core
if self.n_procs == 1:
return JoblibBackend(n_jobs=1).simulate(net)
return JoblibBackend(n_jobs=1).simulate(net, postproc)
blakecaldwell marked this conversation as resolved.
Show resolved Hide resolved

n_trials = net.params['N_trials']
print("Running %d trials..." % (n_trials))
Expand Down Expand Up @@ -446,5 +476,5 @@ def simulate(self, net):

sim_data = self._process_child_data(self.proc_data_bytes, data_len)

dpls = _gather_trial_data(sim_data, net, n_trials)
dpls = _gather_trial_data(sim_data, net, n_trials, postproc)
return dpls
44 changes: 44 additions & 0 deletions hnn_core/tests/conftest.py
Expand Up @@ -6,6 +6,11 @@
from typing import Dict, Tuple
import pytest

import os.path as op
import hnn_core
from hnn_core import read_params, Network, simulate_dipole
from hnn_core import MPIBackend, JoblibBackend

# store history of failures per test class name and per index in parametrize
# (if parametrize used)
_test_failed_incremental: Dict[str, Dict[Tuple[int, ...], str]] = {}
Expand Down Expand Up @@ -63,3 +68,42 @@ def pytest_runtest_setup(item):
# and test name
if test_name is not None:
pytest.xfail("previous test failed ({})".format(test_name))


@pytest.fixture(scope='module')
def run_hnn_core():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about putting "fixture" in this function name? Otherwise it looks like a mysterious parameter to other testing functions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 @kohl-carmen can you make the changes yourself?

def _run_hnn_core(backend=None, n_procs=None, n_jobs=1, reduced=False,
record_vsoma=False, record_isoma=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kohl-carmen if you add a postproc parameter here, then you can avoid duplication of code

hnn_core_root = op.dirname(hnn_core.__file__)

# default params
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)

if reduced:
params.update({'N_pyr_x': 3,
'N_pyr_y': 3,
'tstop': 25,
't_evprox_1': 5,
't_evdist_1': 10,
't_evprox_2': 20,
'N_trials': 2})
net = Network(params)

# number of trials simulated
assert len(net.trial_event_times) == params['N_trials']

if backend == 'mpi':
with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'):
dpls = simulate_dipole(net, record_vsoma=record_isoma,
record_isoma=record_vsoma)
elif backend == 'joblib':
with JoblibBackend(n_jobs=n_jobs):
dpls = simulate_dipole(net, record_vsoma=record_isoma,
record_isoma=record_vsoma)
else:
dpls = simulate_dipole(net, record_vsoma=record_isoma,
record_isoma=record_vsoma)

return dpls, net
return _run_hnn_core
84 changes: 56 additions & 28 deletions hnn_core/tests/test_dipole.py
Expand Up @@ -6,9 +6,10 @@
import pytest

import hnn_core
from hnn_core import read_params, read_dipole, average_dipoles, viz, Network
from hnn_core import JoblibBackend, MPIBackend
from hnn_core import read_params, read_dipole, average_dipoles, Network
from hnn_core.viz import plot_dipole
from hnn_core.dipole import Dipole, simulate_dipole
from hnn_core.parallel_backends import requires_mpi4py

matplotlib.use('agg')

Expand All @@ -22,12 +23,12 @@ def test_dipole(tmpdir):
times = np.random.random(6000)
data = np.random.random((6000, 3))
dipole = Dipole(times, data)
dipole.baseline_renormalize(params)
dipole.baseline_renormalize(params['N_pyr_x'], params['N_pyr_y'])
dipole.convert_fAm_to_nAm()
dipole.scale(params['dipole_scalefctr'])
dipole.smooth(params['dipole_smooth_win'] / params['dt'])
dipole.plot(show=False)
viz.plot_dipole([dipole, dipole], show=False)
plot_dipole([dipole, dipole], show=False)
dipole.write(dpl_out_fname)
dipole_read = read_dipole(dpl_out_fname)
assert_allclose(dipole_read.times, dipole.times, rtol=0, atol=0.00051)
Expand All @@ -45,6 +46,29 @@ def test_dipole(tmpdir):
"average of 2 trials"):
dipole_avg = average_dipoles([dipole_avg, dipole_read])

# test postproc
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
params_reduced = params.copy()
params_reduced.update({'N_pyr_x': 3,
'N_pyr_y': 3,
'tstop': 25,
't_evprox_1': 5,
't_evdist_1': 10,
't_evprox_2': 20,
'N_trials': 2})
net = Network(params_reduced)
dpls_raw = simulate_dipole(net, postproc=False)
dpls = simulate_dipole(net, postproc=True)
rythorpe marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(AssertionError):
assert_allclose(dpls[0].data['agg'], dpls_raw[0].data['agg'])
jasmainak marked this conversation as resolved.
Show resolved Hide resolved
dpls_raw[0].post_proc(params_reduced['N_pyr_x'], params_reduced['N_pyr_y'],
params_reduced['dipole_smooth_win'] /
params_reduced['dt'],
params_reduced['dipole_scalefctr'])
assert_allclose(dpls_raw[0].data['agg'], dpls[0].data['agg'])
rythorpe marked this conversation as resolved.
Show resolved Hide resolved


def test_dipole_simulation():
"""Test data produced from simulate_dipole() call."""
Expand All @@ -65,36 +89,40 @@ def test_dipole_simulation():
with pytest.raises(TypeError, match="record_isoma must be bool, got int"):
simulate_dipole(net, n_trials=1, record_vsoma=False, record_isoma=0)

trial, n_trials, gid = 0, 2, 7
n_times = np.arange(0., params['tstop'] + params['dt'], params['dt']).size
mpi_net = Network(params)
with MPIBackend(n_procs=None, mpi_cmd='mpiexec'):
simulate_dipole(mpi_net, n_trials=n_trials, record_vsoma=True,
record_isoma=True)
assert len(mpi_net.cell_response.vsoma) == n_trials
assert len(mpi_net.cell_response.isoma) == n_trials
assert len(mpi_net.cell_response.vsoma[trial][gid]) == n_times
assert len(mpi_net.cell_response.isoma[
trial][gid]['soma_gabaa']) == n_times

joblib_net = Network(params)
with JoblibBackend(n_jobs=1):
simulate_dipole(joblib_net, n_trials=n_trials, record_vsoma=True,
record_isoma=True)
assert len(joblib_net.cell_response.vsoma) == n_trials
assert len(joblib_net.cell_response.isoma) == n_trials
assert len(joblib_net.cell_response.vsoma[trial][gid]) == n_times
assert len(joblib_net.cell_response.isoma[
trial][gid]['soma_gabaa']) == n_times

@requires_mpi4py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I didn't even realize we had a MPIBackend test in another file besides parallel_backends.py. Thanks @jasmainak for making the testing changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah those would be my changes, thanks for fixing my ad-hoc MPI tests @jasmainak.

def test_cell_response_backends(run_hnn_core):
"""Test cell_response outputs across backends."""

# reduced simulation has n_trials=2
trial_idx, n_trials, gid = 0, 2, 7
_, joblib_net = run_hnn_core(backend='joblib', n_jobs=1, reduced=True,
record_isoma=True, record_vsoma=True)
_, mpi_net = run_hnn_core(backend='mpi', n_jobs=1, reduced=True,
blakecaldwell marked this conversation as resolved.
Show resolved Hide resolved
record_isoma=True, record_vsoma=True)
n_times = len(joblib_net.cell_response.times)

assert len(joblib_net.cell_response.vsoma) == n_trials
assert len(joblib_net.cell_response.isoma) == n_trials
assert len(joblib_net.cell_response.vsoma[trial_idx][gid]) == n_times
assert len(joblib_net.cell_response.isoma[
trial_idx][gid]['soma_gabaa']) == n_times

assert len(mpi_net.cell_response.vsoma) == n_trials
assert len(mpi_net.cell_response.isoma) == n_trials
assert len(mpi_net.cell_response.vsoma[trial_idx][gid]) == n_times
assert len(mpi_net.cell_response.isoma[
trial_idx][gid]['soma_gabaa']) == n_times
assert mpi_net.cell_response.vsoma == joblib_net.cell_response.vsoma
assert mpi_net.cell_response.isoma == joblib_net.cell_response.isoma

# Test if spike time falls within depolarization window above v_thresh
v_thresh = 0.0
times = np.array(joblib_net.cell_response.times)
spike_times = np.array(joblib_net.cell_response.spike_times[trial])
spike_gids = np.array(joblib_net.cell_response.spike_gids[trial])
vsoma = np.array(joblib_net.cell_response.vsoma[trial][gid])
spike_times = np.array(joblib_net.cell_response.spike_times[trial_idx])
spike_gids = np.array(joblib_net.cell_response.spike_gids[trial_idx])
vsoma = np.array(joblib_net.cell_response.vsoma[trial_idx][gid])

v_mask = vsoma > v_thresh
assert np.all([spike_times[spike_gids == gid] > times[v_mask][0],
spike_times[spike_gids == gid] < times[v_mask][-1]])