diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 154ac0bd8..904011536 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -29,6 +29,8 @@ Changelog - Update plot_hist_input() to plot_spikes_hist() which can plot histogram of spikes for any cell type, by `Nick Tolley`_ in `#157 `_ +- Add function to compute mean spike rates with user specified calculation type, by `Nick Tolley`_ and `Mainak Jas`_ in `#155 `_ + Bug ~~~ diff --git a/examples/plot_simulate_evoked.py b/examples/plot_simulate_evoked.py index befc6b522..82202d229 100644 --- a/examples/plot_simulate_evoked.py +++ b/examples/plot_simulate_evoked.py @@ -66,6 +66,19 @@ spikes = read_spikes(op.join(tmp_dir_name, 'spk_*.txt')) spikes.plot() +############################################################################### +# We can additionally calculate the mean spike rates for each cell class by +# specifying a time window with tstart and tstop. +all_rates = spikes.mean_rates(tstart=0, tstop=170, gid_dict=net.gid_dict, + mean_type='all') +trial_rates = spikes.mean_rates(tstart=0, tstop=170, gid_dict=net.gid_dict, + mean_type='trial') +print('Mean spike rates across trials:') +print(all_rates) +print('Mean spike rates for individual trials:') +print(trial_rates) + + ############################################################################### # Now, let us try to make the exogenous driving inputs to the cells # synchronous and see what happens diff --git a/hnn_core/network.py b/hnn_core/network.py index cb71dda25..a1ec6786c 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -336,6 +336,9 @@ class Spikes(object): plot(ax=None, show=True) Plot and return a matplotlib Figure object showing the aggregate network spiking activity according to cell type. + mean_rates(tstart, tstop, gid_dict, mean_type='all') + Calculate mean firing rate for each cell type. Specify + averaging method with mean_type argument. write(fname) Write spiking activity to a collection of spike trial files. """ @@ -431,6 +434,72 @@ def update_types(self, gid_dict): spike_types += [list(spike_types_trial)] self._types = spike_types + def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'): + """Mean spike rates (Hz) by cell type. + + Parameters + ---------- + tstart : int | float | None + Value defining the start time of all trials. + tstop : int | float | None + Value defining the stop time of all trials. + gid_dict : dict of lists or range objects + Dictionary with keys 'evprox1', 'evdist1' etc. + containing the range of Cell or input IDs of different + cell or input types. + mean_type : str + 'all' : Average over trials and cells + Returns mean rate for cell types + 'trial' : Average over cell types + Returns trial mean rate for cell types + 'cell' : Average over individual cells + Returns trial mean rate for individual cells + + Returns + ------- + spike_rate : dict + Dictionary with keys 'L5_pyramidal', 'L5_basket', etc. + """ + cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] + spike_rates = dict() + + if mean_type not in ['all', 'trial', 'cell']: + raise ValueError("Invalid mean_type. Valid arguments include " + f"'all', 'trial', or 'cell'. Got {mean_type}") + + # Validate tstart, tstop + if not isinstance(tstart, (int, float)) or not isinstance( + tstop, (int, float)): + raise ValueError('tstart and tstop must be of type int or float') + elif tstop <= tstart: + raise ValueError('tstop must be greater than tstart') + + for cell_type in cell_types: + cell_type_gids = np.array(gid_dict[cell_type]) + n_trials, n_cells = len(self._times), len(cell_type_gids) + gid_spike_rate = np.zeros((n_trials, n_cells)) + + trial_data = zip(self._types, self._gids) + for trial_idx, (spike_types, spike_gids) in enumerate(trial_data): + trial_type_mask = np.in1d(spike_types, cell_type) + gids, gid_counts = np.unique(np.array( + spike_gids)[trial_type_mask], return_counts=True) + + gid_spike_rate[trial_idx, np.in1d(cell_type_gids, gids)] = ( + gid_counts / (tstop - tstart)) * 1000 + + if mean_type == 'all': + spike_rates[cell_type] = np.mean( + gid_spike_rate.mean(axis=1)) + if mean_type == 'trial': + spike_rates[cell_type] = np.mean( + gid_spike_rate, axis=1).tolist() + if mean_type == 'cell': + spike_rates[cell_type] = [gid_trial_rate.tolist() + for gid_trial_rate in gid_spike_rate] + + return spike_rates + def plot(self, ax=None, show=True): """Plot the aggregate spiking activity according to cell type. diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index e7a3bff47..dff21e894 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -88,6 +88,9 @@ def test_spikes(tmpdir): spiketimes = [[2.3456, 7.89], [4.2812, 93.2]] spikegids = [[1, 3], [5, 7]] spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']] + tstart, tstop = 0.1, 98.4 + gid_dict = {'L2_pyramidal': range(1, 2), 'L2_basket': range(3, 4), + 'L5_pyramidal': range(5, 6), 'L5_basket': range(7, 8)} spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) spikes.plot_hist(show=False) spikes.write(tmpdir.join('spk_%d.txt')) @@ -107,6 +110,8 @@ def test_spikes(tmpdir): spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids, types=spiketypes) + spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) + with pytest.raises(TypeError, match="spike_types should be str, " "list, dict, or None"): spikes.plot_hist(spike_types=1, show=False) @@ -123,6 +128,36 @@ def test_spikes(tmpdir): with pytest.raises(ValueError, match="No input types found for ABC"): spikes.plot_hist(spike_types='ABC', show=False) + with pytest.raises(ValueError, match="tstart and tstop must be of type " + "int or float"): + spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict=gid_dict) + + with pytest.raises(ValueError, match="tstop must be greater than tstart"): + spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict=gid_dict) + + with pytest.raises(ValueError, match="Invalid mean_type. Valid " + "arguments include 'all', 'trial', or 'cell'."): + spikes.mean_rates(tstart=tstart, tstop=tstop, gid_dict=gid_dict, + mean_type='ABC') + + test_rate = (1 / (tstop - tstart)) * 1000 + + assert spikes.mean_rates(tstart, tstop, gid_dict) == { + 'L5_pyramidal': test_rate / 2, + 'L5_basket': test_rate / 2, + 'L2_pyramidal': test_rate / 2, + 'L2_basket': test_rate / 2} + assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='trial') == { + 'L5_pyramidal': [0.0, test_rate], + 'L5_basket': [0.0, test_rate], + 'L2_pyramidal': [test_rate, 0.0], + 'L2_basket': [test_rate, 0.0]} + assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='cell') == { + 'L5_pyramidal': [[0.0], [test_rate]], + 'L5_basket': [[0.0], [test_rate]], + 'L2_pyramidal': [[test_rate], [0.0]], + 'L2_basket': [[test_rate], [0.0]]} + # Write spike file with no 'types' column # Check for gid_dict errors for fname in sorted(glob(str(tmpdir.join('spk_*.txt')))):