New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
postprocessing #188
postprocessing #188
Changes from 16 commits
b9f1172
18dcb56
82d8556
5887f9c
cf9c7b7
8f97112
691301d
a7f634f
b6bd7a5
8e1e80c
da7a7bc
cfaf74e
9619ea8
f2e82e4
e516d34
e05356a
b97f452
3ad7abd
b304836
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -18,7 +18,7 @@ def _hammfilt(x, winsz): | |||||
|
||||||
|
||||||
def simulate_dipole(net, n_trials=None, record_vsoma=False, | ||||||
record_isoma=False): | ||||||
record_isoma=False, postproc=True): | ||||||
"""Simulate a dipole given the experiment parameters. | ||||||
|
||||||
Parameters | ||||||
|
@@ -33,6 +33,8 @@ def simulate_dipole(net, n_trials=None, record_vsoma=False, | |||||
Option to record somatic voltages from cells | ||||||
record_isoma : bool | ||||||
Option to record somatic currents from cells | ||||||
postproc : bool | ||||||
If False, no postprocessing applied to the dipole | ||||||
|
||||||
Returns | ||||||
------- | ||||||
|
@@ -67,7 +69,7 @@ def simulate_dipole(net, n_trials=None, record_vsoma=False, | |||||
raise TypeError("record_isoma must be bool, got %s" | ||||||
% type(record_isoma).__name__) | ||||||
|
||||||
dpls = _BACKEND.simulate(net) | ||||||
dpls = _BACKEND.simulate(net, postproc) | ||||||
|
||||||
return dpls | ||||||
|
||||||
|
@@ -165,6 +167,25 @@ def __init__(self, times, data, nave=1): # noqa: D102 | |||||
self.data = {'agg': data[:, 0], 'L2': data[:, 1], 'L5': data[:, 2]} | ||||||
self.nave = nave | ||||||
|
||||||
def post_proc(self, N_pyr_x, N_pyr_y, winsz, fctr): | ||||||
jasmainak marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" Apply baseline, unit conversion, scaling and smoothing | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
N_pyr_x : int | ||||||
Number of Pyramidal cells in x direction | ||||||
N_pyr_y : int | ||||||
Number of Pyramidal cells in y direction | ||||||
winsz : int | ||||||
Smoothing window | ||||||
fctr : int | ||||||
Scaling factor | ||||||
""" | ||||||
self.baseline_renormalize(N_pyr_x, N_pyr_y) | ||||||
self.convert_fAm_to_nAm() | ||||||
self.scale(fctr) | ||||||
self.smooth(winsz) | ||||||
|
||||||
def convert_fAm_to_nAm(self): | ||||||
""" must be run after baseline_renormalization() | ||||||
""" | ||||||
|
@@ -205,21 +226,21 @@ def plot(self, ax=None, layer='agg', show=True): | |||||
""" | ||||||
return plot_dipole(dpl=self, ax=ax, layer=layer, show=show) | ||||||
|
||||||
def baseline_renormalize(self, params): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lovely! |
||||||
def baseline_renormalize(self, N_pyr_x, N_pyr_y): | ||||||
"""Only baseline renormalize if the units are fAm. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
params : dict | ||||||
The parameters | ||||||
N_pyr_x : int | ||||||
Nr of cells (x) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add this change in next commit? |
||||||
N_pyr_y : int | ||||||
Nr of cells (y) | ||||||
""" | ||||||
if self.units != 'fAm': | ||||||
print("Warning, no dipole renormalization done because units" | ||||||
" were in %s" % (self.units)) | ||||||
return | ||||||
|
||||||
N_pyr_x = params['N_pyr_x'] | ||||||
N_pyr_y = params['N_pyr_y'] | ||||||
# N_pyr cells in grid. This is PER LAYER | ||||||
N_pyr = N_pyr_x * N_pyr_y | ||||||
# dipole offset calculation: increasing number of pyr | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,11 @@ | |
from typing import Dict, Tuple | ||
import pytest | ||
|
||
import os.path as op | ||
import hnn_core | ||
from hnn_core import read_params, Network, simulate_dipole | ||
from hnn_core import MPIBackend, JoblibBackend | ||
|
||
# store history of failures per test class name and per index in parametrize | ||
# (if parametrize used) | ||
_test_failed_incremental: Dict[str, Dict[Tuple[int, ...], str]] = {} | ||
|
@@ -63,3 +68,42 @@ def pytest_runtest_setup(item): | |
# and test name | ||
if test_name is not None: | ||
pytest.xfail("previous test failed ({})".format(test_name)) | ||
|
||
|
||
@pytest.fixture(scope='module') | ||
def run_hnn_core(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about putting "fixture" in this function name? Otherwise it looks like a mysterious parameter to other testing functions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 @kohl-carmen can you make the changes yourself? |
||
def _run_hnn_core(backend=None, n_procs=None, n_jobs=1, reduced=False, | ||
record_vsoma=False, record_isoma=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kohl-carmen if you add a |
||
hnn_core_root = op.dirname(hnn_core.__file__) | ||
|
||
# default params | ||
params_fname = op.join(hnn_core_root, 'param', 'default.json') | ||
params = read_params(params_fname) | ||
|
||
if reduced: | ||
params.update({'N_pyr_x': 3, | ||
'N_pyr_y': 3, | ||
'tstop': 25, | ||
't_evprox_1': 5, | ||
't_evdist_1': 10, | ||
't_evprox_2': 20, | ||
'N_trials': 2}) | ||
net = Network(params) | ||
|
||
# number of trials simulated | ||
assert len(net.trial_event_times) == params['N_trials'] | ||
|
||
if backend == 'mpi': | ||
with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'): | ||
dpls = simulate_dipole(net, record_vsoma=record_isoma, | ||
record_isoma=record_vsoma) | ||
elif backend == 'joblib': | ||
with JoblibBackend(n_jobs=n_jobs): | ||
dpls = simulate_dipole(net, record_vsoma=record_isoma, | ||
record_isoma=record_vsoma) | ||
else: | ||
dpls = simulate_dipole(net, record_vsoma=record_isoma, | ||
record_isoma=record_vsoma) | ||
|
||
return dpls, net | ||
return _run_hnn_core |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,9 +6,10 @@ | |
import pytest | ||
|
||
import hnn_core | ||
from hnn_core import read_params, read_dipole, average_dipoles, viz, Network | ||
from hnn_core import JoblibBackend, MPIBackend | ||
from hnn_core import read_params, read_dipole, average_dipoles, Network | ||
from hnn_core.viz import plot_dipole | ||
from hnn_core.dipole import Dipole, simulate_dipole | ||
from hnn_core.parallel_backends import requires_mpi4py | ||
|
||
matplotlib.use('agg') | ||
|
||
|
@@ -22,12 +23,12 @@ def test_dipole(tmpdir): | |
times = np.random.random(6000) | ||
data = np.random.random((6000, 3)) | ||
dipole = Dipole(times, data) | ||
dipole.baseline_renormalize(params) | ||
dipole.baseline_renormalize(params['N_pyr_x'], params['N_pyr_y']) | ||
dipole.convert_fAm_to_nAm() | ||
dipole.scale(params['dipole_scalefctr']) | ||
dipole.smooth(params['dipole_smooth_win'] / params['dt']) | ||
dipole.plot(show=False) | ||
viz.plot_dipole([dipole, dipole], show=False) | ||
plot_dipole([dipole, dipole], show=False) | ||
dipole.write(dpl_out_fname) | ||
dipole_read = read_dipole(dpl_out_fname) | ||
assert_allclose(dipole_read.times, dipole.times, rtol=0, atol=0.00051) | ||
|
@@ -45,6 +46,29 @@ def test_dipole(tmpdir): | |
"average of 2 trials"): | ||
dipole_avg = average_dipoles([dipole_avg, dipole_read]) | ||
|
||
# test postproc | ||
hnn_core_root = op.dirname(hnn_core.__file__) | ||
params_fname = op.join(hnn_core_root, 'param', 'default.json') | ||
params = read_params(params_fname) | ||
params_reduced = params.copy() | ||
params_reduced.update({'N_pyr_x': 3, | ||
'N_pyr_y': 3, | ||
'tstop': 25, | ||
't_evprox_1': 5, | ||
't_evdist_1': 10, | ||
't_evprox_2': 20, | ||
'N_trials': 2}) | ||
net = Network(params_reduced) | ||
dpls_raw = simulate_dipole(net, postproc=False) | ||
dpls = simulate_dipole(net, postproc=True) | ||
rythorpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with pytest.raises(AssertionError): | ||
assert_allclose(dpls[0].data['agg'], dpls_raw[0].data['agg']) | ||
jasmainak marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dpls_raw[0].post_proc(params_reduced['N_pyr_x'], params_reduced['N_pyr_y'], | ||
params_reduced['dipole_smooth_win'] / | ||
params_reduced['dt'], | ||
params_reduced['dipole_scalefctr']) | ||
assert_allclose(dpls_raw[0].data['agg'], dpls[0].data['agg']) | ||
rythorpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def test_dipole_simulation(): | ||
"""Test data produced from simulate_dipole() call.""" | ||
|
@@ -65,36 +89,40 @@ def test_dipole_simulation(): | |
with pytest.raises(TypeError, match="record_isoma must be bool, got int"): | ||
simulate_dipole(net, n_trials=1, record_vsoma=False, record_isoma=0) | ||
|
||
trial, n_trials, gid = 0, 2, 7 | ||
n_times = np.arange(0., params['tstop'] + params['dt'], params['dt']).size | ||
mpi_net = Network(params) | ||
with MPIBackend(n_procs=None, mpi_cmd='mpiexec'): | ||
simulate_dipole(mpi_net, n_trials=n_trials, record_vsoma=True, | ||
record_isoma=True) | ||
assert len(mpi_net.cell_response.vsoma) == n_trials | ||
assert len(mpi_net.cell_response.isoma) == n_trials | ||
assert len(mpi_net.cell_response.vsoma[trial][gid]) == n_times | ||
assert len(mpi_net.cell_response.isoma[ | ||
trial][gid]['soma_gabaa']) == n_times | ||
|
||
joblib_net = Network(params) | ||
with JoblibBackend(n_jobs=1): | ||
simulate_dipole(joblib_net, n_trials=n_trials, record_vsoma=True, | ||
record_isoma=True) | ||
assert len(joblib_net.cell_response.vsoma) == n_trials | ||
assert len(joblib_net.cell_response.isoma) == n_trials | ||
assert len(joblib_net.cell_response.vsoma[trial][gid]) == n_times | ||
assert len(joblib_net.cell_response.isoma[ | ||
trial][gid]['soma_gabaa']) == n_times | ||
|
||
@requires_mpi4py | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, I didn't even realize we had a MPIBackend test in another file besides parallel_backends.py. Thanks @jasmainak for making the testing changes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah those would be my changes, thanks for fixing my ad-hoc MPI tests @jasmainak. |
||
def test_cell_response_backends(run_hnn_core): | ||
"""Test cell_response outputs across backends.""" | ||
|
||
# reduced simulation has n_trials=2 | ||
trial_idx, n_trials, gid = 0, 2, 7 | ||
_, joblib_net = run_hnn_core(backend='joblib', n_jobs=1, reduced=True, | ||
record_isoma=True, record_vsoma=True) | ||
_, mpi_net = run_hnn_core(backend='mpi', n_jobs=1, reduced=True, | ||
blakecaldwell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
record_isoma=True, record_vsoma=True) | ||
n_times = len(joblib_net.cell_response.times) | ||
|
||
assert len(joblib_net.cell_response.vsoma) == n_trials | ||
assert len(joblib_net.cell_response.isoma) == n_trials | ||
assert len(joblib_net.cell_response.vsoma[trial_idx][gid]) == n_times | ||
assert len(joblib_net.cell_response.isoma[ | ||
trial_idx][gid]['soma_gabaa']) == n_times | ||
|
||
assert len(mpi_net.cell_response.vsoma) == n_trials | ||
assert len(mpi_net.cell_response.isoma) == n_trials | ||
assert len(mpi_net.cell_response.vsoma[trial_idx][gid]) == n_times | ||
assert len(mpi_net.cell_response.isoma[ | ||
trial_idx][gid]['soma_gabaa']) == n_times | ||
assert mpi_net.cell_response.vsoma == joblib_net.cell_response.vsoma | ||
assert mpi_net.cell_response.isoma == joblib_net.cell_response.isoma | ||
|
||
# Test if spike time falls within depolarization window above v_thresh | ||
v_thresh = 0.0 | ||
times = np.array(joblib_net.cell_response.times) | ||
spike_times = np.array(joblib_net.cell_response.spike_times[trial]) | ||
spike_gids = np.array(joblib_net.cell_response.spike_gids[trial]) | ||
vsoma = np.array(joblib_net.cell_response.vsoma[trial][gid]) | ||
spike_times = np.array(joblib_net.cell_response.spike_times[trial_idx]) | ||
spike_gids = np.array(joblib_net.cell_response.spike_gids[trial_idx]) | ||
vsoma = np.array(joblib_net.cell_response.vsoma[trial_idx][gid]) | ||
|
||
v_mask = vsoma > v_thresh | ||
assert np.all([spike_times[spike_gids == gid] > times[v_mask][0], | ||
spike_times[spike_gids == gid] < times[v_mask][-1]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the other conflicting spot. I think post-processing should be listed first so the docstring would become
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with
postproc
being listed last as long as it matches the function signature. You have less chance of breaking existing user code when you add an argument to the end of the function. E.g., if someone did:now if you add
postproc
in the middle, it will get passed asrecord_vsoma
when the user intended otherwise