Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] ENH: Function to get spike rates #155

Merged
merged 20 commits into from Sep 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
2636083
Function to get spike rates
jasmainak Aug 27, 2020
d7ec743
FIX: compute spike rate per trial and then average across trials
jasmainak Aug 27, 2020
4a29344
Rename/reorganize to compute spike rates for individual gids, add opt…
ntolley Sep 4, 2020
4f082a7
Add ValueError check and corresponding test for mean_type argument
ntolley Sep 7, 2020
48665d9
Add tstart and tstop attributes to Spikes class, update Network/Spike…
ntolley Sep 7, 2020
9821a09
Change variable name
ntolley Sep 11, 2020
760feb2
DOC: Fix tstart/tstop descriptions
ntolley Sep 11, 2020
b40cf73
Remove unecessary else clause from read_spikes()
ntolley Sep 11, 2020
2338feb
DOC: Add description clarifying tstart/tstop only necessary for legac…
ntolley Sep 12, 2020
9eba0c3
TST: Raise error if tstart/tstop present in file and user provides va…
ntolley Sep 12, 2020
7777891
TST: Add tests asserting correct calculation of spikes.mean_rates()
ntolley Sep 12, 2020
37f9ba7
DOC: Update whats_new.rst
ntolley Sep 17, 2020
343b68a
Remove tstart and tstop as attributes, instead pass as parameters to …
ntolley Sep 26, 2020
4c4c92f
TST: Add tests validating correct tstart/tstop entry
ntolley Sep 26, 2020
d695663
TST: Add tests asserting correct calculation of mean rates
ntolley Sep 26, 2020
4afc94a
Make trial and cell counts explicit for gid_spike_rate preallocation
ntolley Sep 26, 2020
7f53609
Minor changes to increase legibility, remove leftover tstart/tstop code
ntolley Sep 26, 2020
a180b7d
TST: Print invalid input during mean_type value error
ntolley Sep 26, 2020
8573140
Update example with mean spike rate calculation
ntolley Sep 26, 2020
f24257d
DOC: Fix hyperlink
ntolley Sep 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -29,6 +29,8 @@ Changelog

- Update plot_hist_input() to plot_spikes_hist() which can plot histogram of spikes for any cell type, by `Nick Tolley`_ in `#157 <https://github.com/jonescompneurolab/hnn-core/pull/157>`_

- Add function to compute mean spike rates with user specified calculation type, by `Nick Tolley`_ and `Mainak Jas`_ in `#155 <https://github.com/jonescompneurolab/hnn-core/pull/155>`_

Bug
~~~

Expand Down
13 changes: 13 additions & 0 deletions examples/plot_simulate_evoked.py
Expand Up @@ -66,6 +66,19 @@
spikes = read_spikes(op.join(tmp_dir_name, 'spk_*.txt'))
spikes.plot()

###############################################################################
# We can additionally calculate the mean spike rates for each cell class by
# specifying a time window with tstart and tstop.
all_rates = spikes.mean_rates(tstart=0, tstop=170, gid_dict=net.gid_dict,
mean_type='all')
trial_rates = spikes.mean_rates(tstart=0, tstop=170, gid_dict=net.gid_dict,
mean_type='trial')
print('Mean spike rates across trials:')
print(all_rates)
print('Mean spike rates for individual trials:')
print(trial_rates)


###############################################################################
# Now, let us try to make the exogenous driving inputs to the cells
# synchronous and see what happens
Expand Down
69 changes: 69 additions & 0 deletions hnn_core/network.py
Expand Up @@ -336,6 +336,9 @@ class Spikes(object):
plot(ax=None, show=True)
Plot and return a matplotlib Figure object showing the
aggregate network spiking activity according to cell type.
mean_rates(tstart, tstop, gid_dict, mean_type='all')
Calculate mean firing rate for each cell type. Specify
averaging method with mean_type argument.
write(fname)
Write spiking activity to a collection of spike trial files.
"""
Expand Down Expand Up @@ -431,6 +434,72 @@ def update_types(self, gid_dict):
spike_types += [list(spike_types_trial)]
self._types = spike_types

def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'):
"""Mean spike rates (Hz) by cell type.

Parameters
----------
tstart : int | float | None
Value defining the start time of all trials.
tstop : int | float | None
Value defining the stop time of all trials.
gid_dict : dict of lists or range objects
Dictionary with keys 'evprox1', 'evdist1' etc.
containing the range of Cell or input IDs of different
cell or input types.
mean_type : str
'all' : Average over trials and cells
Returns mean rate for cell types
'trial' : Average over cell types
Returns trial mean rate for cell types
'cell' : Average over individual cells
Returns trial mean rate for individual cells

Returns
-------
spike_rate : dict
Dictionary with keys 'L5_pyramidal', 'L5_basket', etc.
"""
cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket']
spike_rates = dict()

if mean_type not in ['all', 'trial', 'cell']:
raise ValueError("Invalid mean_type. Valid arguments include "
f"'all', 'trial', or 'cell'. Got {mean_type}")

# Validate tstart, tstop
if not isinstance(tstart, (int, float)) or not isinstance(
tstop, (int, float)):
raise ValueError('tstart and tstop must be of type int or float')
elif tstop <= tstart:
raise ValueError('tstop must be greater than tstart')

for cell_type in cell_types:
cell_type_gids = np.array(gid_dict[cell_type])
n_trials, n_cells = len(self._times), len(cell_type_gids)
gid_spike_rate = np.zeros((n_trials, n_cells))

trial_data = zip(self._types, self._gids)
for trial_idx, (spike_types, spike_gids) in enumerate(trial_data):
trial_type_mask = np.in1d(spike_types, cell_type)
gids, gid_counts = np.unique(np.array(
spike_gids)[trial_type_mask], return_counts=True)

gid_spike_rate[trial_idx, np.in1d(cell_type_gids, gids)] = (
gid_counts / (tstop - tstart)) * 1000

if mean_type == 'all':
spike_rates[cell_type] = np.mean(
gid_spike_rate.mean(axis=1))
if mean_type == 'trial':
spike_rates[cell_type] = np.mean(
gid_spike_rate, axis=1).tolist()
if mean_type == 'cell':
spike_rates[cell_type] = [gid_trial_rate.tolist()
for gid_trial_rate in gid_spike_rate]
ntolley marked this conversation as resolved.
Show resolved Hide resolved

return spike_rates

def plot(self, ax=None, show=True):
"""Plot the aggregate spiking activity according to cell type.

Expand Down
35 changes: 35 additions & 0 deletions hnn_core/tests/test_network.py
Expand Up @@ -88,6 +88,9 @@ def test_spikes(tmpdir):
spiketimes = [[2.3456, 7.89], [4.2812, 93.2]]
spikegids = [[1, 3], [5, 7]]
spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']]
tstart, tstop = 0.1, 98.4
gid_dict = {'L2_pyramidal': range(1, 2), 'L2_basket': range(3, 4),
'L5_pyramidal': range(5, 6), 'L5_basket': range(7, 8)}
spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes)
spikes.plot_hist(show=False)
spikes.write(tmpdir.join('spk_%d.txt'))
Expand All @@ -107,6 +110,8 @@ def test_spikes(tmpdir):
spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids,
types=spiketypes)

spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes)

with pytest.raises(TypeError, match="spike_types should be str, "
"list, dict, or None"):
spikes.plot_hist(spike_types=1, show=False)
Expand All @@ -123,6 +128,36 @@ def test_spikes(tmpdir):
with pytest.raises(ValueError, match="No input types found for ABC"):
spikes.plot_hist(spike_types='ABC', show=False)

with pytest.raises(ValueError, match="tstart and tstop must be of type "
"int or float"):
spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict=gid_dict)

with pytest.raises(ValueError, match="tstop must be greater than tstart"):
spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict=gid_dict)

with pytest.raises(ValueError, match="Invalid mean_type. Valid "
"arguments include 'all', 'trial', or 'cell'."):
spikes.mean_rates(tstart=tstart, tstop=tstop, gid_dict=gid_dict,
mean_type='ABC')

test_rate = (1 / (tstop - tstart)) * 1000

assert spikes.mean_rates(tstart, tstop, gid_dict) == {
'L5_pyramidal': test_rate / 2,
'L5_basket': test_rate / 2,
'L2_pyramidal': test_rate / 2,
'L2_basket': test_rate / 2}
assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='trial') == {
'L5_pyramidal': [0.0, test_rate],
'L5_basket': [0.0, test_rate],
'L2_pyramidal': [test_rate, 0.0],
'L2_basket': [test_rate, 0.0]}
assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='cell') == {
'L5_pyramidal': [[0.0], [test_rate]],
'L5_basket': [[0.0], [test_rate]],
'L2_pyramidal': [[test_rate], [0.0]],
'L2_basket': [[test_rate], [0.0]]}

# Write spike file with no 'types' column
# Check for gid_dict errors
for fname in sorted(glob(str(tmpdir.join('spk_*.txt')))):
Expand Down