From 2636083d4dec486f2ed9c9f8f26b35877de795bc Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Thu, 27 Aug 2020 15:45:05 -0400 Subject: [PATCH 01/20] Function to get spike rates --- hnn_core/network.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/hnn_core/network.py b/hnn_core/network.py index cb71dda25..34c35398e 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -431,6 +431,24 @@ def update_types(self, gid_dict): spike_types += [list(spike_types_trial)] self._types = spike_types + def get_spike_rates(self): + """Spike rates by cell type. + + Returns + ------- + spike_rate : dict + Dictionary with keys 'L5_pyramidal', 'L5_basket', etc. + """ + cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] + spike_times = np.array(sum(spikes._times, [])) + spike_types = np.array(sum(spikes._types, [])) + spike_rates = dict() + tstart, tstop = min(spike_times), max(spike_times) + for cell_type in cell_types: + spike_times_cells = spike_times[spike_types == cell_type] + spike_rates[cell_type] = len(spike_times_cells) / (tstop - tstart) + return spike_rates + def plot(self, ax=None, show=True): """Plot the aggregate spiking activity according to cell type. From d7ec743f2fd88a572abffe50fd50ad278dc9a8d3 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Thu, 27 Aug 2020 16:07:16 -0400 Subject: [PATCH 02/20] FIX: compute spike rate per trial and then average across trials --- hnn_core/network.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 34c35398e..b4cf4c003 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -440,13 +440,16 @@ def get_spike_rates(self): Dictionary with keys 'L5_pyramidal', 'L5_basket', etc. """ cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] - spike_times = np.array(sum(spikes._times, [])) - spike_types = np.array(sum(spikes._types, [])) spike_rates = dict() - tstart, tstop = min(spike_times), max(spike_times) + all_spike_times = np.array(sum(self._times, [])) + tstart, tstop = min(all_spike_times), max(all_spike_times) for cell_type in cell_types: - spike_times_cells = spike_times[spike_types == cell_type] - spike_rates[cell_type] = len(spike_times_cells) / (tstop - tstart) + trial_spike_rate = list() + for spike_times, spike_types in zip(self._times, self._types): + spike_times_cells = spike_times[spike_types == cell_type] + trial_spike_rate.append( + len(spike_times_cells) / (tstop - tstart)) + spike_rates[cell_type] = np.mean(trial_spike_rate) return spike_rates def plot(self, ax=None, show=True): From 4a293446860994996d4d51fa53f8fe8de43e1434 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Thu, 3 Sep 2020 20:32:18 -0400 Subject: [PATCH 03/20] Rename/reorganize to compute spike rates for individual gids, add option to select averaging performed --- hnn_core/network.py | 46 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index b4cf4c003..dbdfa5f98 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -431,8 +431,18 @@ def update_types(self, gid_dict): spike_types += [list(spike_types_trial)] self._types = spike_types - def get_spike_rates(self): - """Spike rates by cell type. + def mean_rates(self, mean_type='all'): + """Mean spike rates (Hz) by cell type. + + Parameters + ---------- + mean_type : str + 'all' : Average over trials and cells + Returns mean rate for cell types + 'trial' : Average over cell types + Returns trial mean rate for cell types + 'cell' : Average over individual cells + Returns trial mean rate for individual cells Returns ------- @@ -442,14 +452,34 @@ def get_spike_rates(self): cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] spike_rates = dict() all_spike_times = np.array(sum(self._times, [])) + all_spike_types = np.array(sum(self._types, [])) + all_spike_gids = np.array(sum(self._gids, [])) tstart, tstop = min(all_spike_times), max(all_spike_times) + for cell_type in cell_types: - trial_spike_rate = list() - for spike_times, spike_types in zip(self._times, self._types): - spike_times_cells = spike_times[spike_types == cell_type] - trial_spike_rate.append( - len(spike_times_cells) / (tstop - tstart)) - spike_rates[cell_type] = np.mean(trial_spike_rate) + type_mask = np.in1d(all_spike_types, cell_type) + cell_type_gids = np.unique(all_spike_gids[type_mask]) + gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids))) + + trial_data = zip(self._types, self._gids) + for trial, (spike_types, spike_gids) in enumerate(trial_data): + trial_type_mask = np.in1d(spike_types, cell_type) + gid, gid_counts = np.unique(np.array( + spike_gids)[trial_type_mask], return_counts=True) + + gid_spike_rate[trial, cell_type_gids == gid] = (gid_counts / ( + tstop - tstart)) * 1000 + + if mean_type == 'all': + spike_rates[cell_type] = np.mean( + np.mean(gid_spike_rate, axis=1)) + if mean_type == 'trial': + spike_rates[cell_type] = np.mean( + gid_spike_rate, axis=1).tolist() + if mean_type == 'cell': + spike_rates[cell_type] = [gid_trial_rate.tolist() + for gid_trial_rate in gid_spike_rate] + return spike_rates def plot(self, ax=None, show=True): From 4f082a74ceb487d7b9999843fae4d06c197e5c51 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sun, 6 Sep 2020 20:39:04 -0400 Subject: [PATCH 04/20] Add ValueError check and corresponding test for mean_type argument --- hnn_core/network.py | 4 ++++ hnn_core/tests/test_network.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/hnn_core/network.py b/hnn_core/network.py index dbdfa5f98..0718dfc3e 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -456,6 +456,10 @@ def mean_rates(self, mean_type='all'): all_spike_gids = np.array(sum(self._gids, [])) tstart, tstop = min(all_spike_times), max(all_spike_times) + if mean_type not in ['all', 'trial', 'cell']: + raise ValueError("Invalid mean_type. Valid arguments include " + "'all', 'trial', or 'cell'.") + for cell_type in cell_types: type_mask = np.in1d(all_spike_types, cell_type) cell_type_gids = np.unique(all_spike_gids[type_mask]) diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index e7a3bff47..763d595ba 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -123,6 +123,10 @@ def test_spikes(tmpdir): with pytest.raises(ValueError, match="No input types found for ABC"): spikes.plot_hist(spike_types='ABC', show=False) + with pytest.raises(ValueError, match="Invalid mean_type. Valid " + "arguments include 'all', 'trial', or 'cell'."): + spikes.mean_rates(mean_type='ABC') + # Write spike file with no 'types' column # Check for gid_dict errors for fname in sorted(glob(str(tmpdir.join('spk_*.txt')))): From 48665d9cabbf01c6149c28ed61f02566f9899677 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Mon, 7 Sep 2020 12:18:15 -0400 Subject: [PATCH 05/20] Add tstart and tstop attributes to Spikes class, update Network/Spikes functions and tests accordingly --- hnn_core/network.py | 101 ++++++++++++++++++++++++++++----- hnn_core/parallel_backends.py | 2 + hnn_core/tests/test_network.py | 22 +++++-- 3 files changed, 107 insertions(+), 18 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 0718dfc3e..5cad75273 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -12,7 +12,7 @@ from .viz import plot_spikes_hist, plot_spikes_raster, plot_cells -def read_spikes(fname, gid_dict=None): +def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): """Read spiking activity from a collection of spike trial files. Parameters @@ -25,6 +25,8 @@ def read_spikes(fname, gid_dict=None): containing the range of Cell or input IDs of different cell or input types. If None, each spike file must contain a 3rd column for spike type. + tstart, tstop : float | None + Values defining the start and stop times of all trials. Returns ---------- @@ -35,27 +37,58 @@ def read_spikes(fname, gid_dict=None): spike_times = [] spike_gids = [] spike_types = [] +<<<<<<< HEAD for file in sorted(glob(str(fname))): +======= + spike_tstart = [] + spike_tstop = [] + for file in sorted(glob(fname)): +>>>>>>> Add tstart and tstop attributes to Spikes class, update Network/Spikes functions and tests accordingly spike_trial = np.loadtxt(file, dtype=str) spike_times += [list(spike_trial[:, 0].astype(float))] spike_gids += [list(spike_trial[:, 1].astype(int))] # Note that legacy HNN 'spk.txt' files don't contain a 3rd column for - # spike type. If reading a legacy version, validate that a gid_dict is - # provided. - if spike_trial.shape[1] == 3: + # spike type, or a 4th and 5th column for tstart and tstop. If reading + # a legacy version, validate that gid_dict, tstart, and tstop provided. + if spike_trial.shape[1] >= 3: spike_types += [list(spike_trial[:, 2].astype(str))] - else: + elif spike_trial.shape[1] == 2: if gid_dict is None: raise ValueError("gid_dict must be provided if spike types " "are unspecified in the file %s" % (file,)) spike_types += [[]] - spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types) + if spike_trial.shape[1] == 5: + spike_tstart += np.unique(spike_trial[:, 3].astype(float)).tolist() + spike_tstop += np.unique(spike_trial[:, 4].astype(float)).tolist() + else: + if tstart is None or tstop is None: + raise ValueError("tstart and tstop must be provided if values " + "are unspecified in the file %s" % (file,)) + + if len(np.unique(spike_tstart)) > 1 or len(np.unique(spike_tstop)) > 1: + raise ValueError("tstart and tstop must match across files.") + + elif len(np.unique(spike_tstart)) == 1 and len( + np.unique(spike_tstop)) == 1: + spike_tstart = spike_tstart[0] + spike_tstop = spike_tstop[0] + else: + spike_tstart = np.min(sum(spike_times, [])) + spike_tstop = np.max(sum(spike_times, [])) + + spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types, + tstart=spike_tstart, tstop=spike_tstop) + if gid_dict is not None: spikes.update_types(gid_dict) - return Spikes(times=spike_times, gids=spike_gids, types=spike_types) + if tstart is not None and tstop is not None: + spikes.update_trial_bounds(spike_tstart, spike_tstop) + + return Spikes(times=spike_times, gids=spike_gids, types=spike_types, + tstart=spike_tstart, tstop=spike_tstop) def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff=1307.4): @@ -313,6 +346,8 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network().gid_dict. + tstart, tstop : int | float | None + Values defining the start and stop times of all trials. Attributes ---------- @@ -328,6 +363,8 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network::gid_dict. + tstart, tstop : int | float | None + Values defining the start and stop times of all trials. Methods ------- @@ -336,11 +373,15 @@ class Spikes(object): plot(ax=None, show=True) Plot and return a matplotlib Figure object showing the aggregate network spiking activity according to cell type. + mean_rates(mean_type='all') + Calculate mean firing rate for each cell type. Specify + averaging method with mean_type argument. write(fname) Write spiking activity to a collection of spike trial files. """ - def __init__(self, times=None, gids=None, types=None): + def __init__(self, times=None, gids=None, types=None, + tstart=None, tstop=None): if times is None: times = list() if gids is None: @@ -370,6 +411,8 @@ def __init__(self, times=None, gids=None, types=None): self._times = times self._gids = gids self._types = types + self._tstart = tstart + self._tstop = tstop def __repr__(self): class_name = self.__class__.__name__ @@ -386,7 +429,9 @@ def __eq__(self, other): for trial in other._times] return (times_self == times_other and self._gids == other._gids and - self._types == other._types) + self._types == other._types and + self._tstart == other._tstart and + self._tstop == other._tstop) @property def times(self): @@ -400,6 +445,14 @@ def gids(self): def types(self): return self._types + @property + def tstart(self): + return self._tstart + + @property + def tstop(self): + return self._tstop + def update_types(self, gid_dict): """Update spike types in the current instance of Spikes. @@ -431,6 +484,25 @@ def update_types(self, gid_dict): spike_types += [list(spike_types_trial)] self._types = spike_types + def update_trial_bounds(self, tstart, tstop): + """Update tstart and tstop in the current instance of Spikes. + + Parameters + ---------- + tstart, tstop : float | None + Values defining the start and stop times of all trials. + + """ + + # Validate tstart, tstop + if not isinstance(tstart, (int, float)) or not isinstance( + tstop, (int, float)): + raise ValueError('tstart and tstop must be of type int or float') + elif tstop <= tstart: + raise ValueError('tstop must be greater than tstart') + self._tstart = tstart + self._tstop = tstop + def mean_rates(self, mean_type='all'): """Mean spike rates (Hz) by cell type. @@ -451,10 +523,8 @@ def mean_rates(self, mean_type='all'): """ cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] spike_rates = dict() - all_spike_times = np.array(sum(self._times, [])) all_spike_types = np.array(sum(self._types, [])) all_spike_gids = np.array(sum(self._gids, [])) - tstart, tstop = min(all_spike_times), max(all_spike_times) if mean_type not in ['all', 'trial', 'cell']: raise ValueError("Invalid mean_type. Valid arguments include " @@ -472,7 +542,7 @@ def mean_rates(self, mean_type='all'): spike_gids)[trial_type_mask], return_counts=True) gid_spike_rate[trial, cell_type_gids == gid] = (gid_counts / ( - tstop - tstart)) * 1000 + self._tstop - self._tstart)) * 1000 if mean_type == 'all': spike_rates[cell_type] = np.mean( @@ -550,12 +620,15 @@ def write(self, fname): 1) spike time (s), 2) spike gid, and 3) gid type + 4) tstart + 5) tstop """ for trial_idx in range(len(self._times)): with open(str(fname) % (trial_idx,), 'w') as f: for spike_idx in range(len(self._times[trial_idx])): - f.write('{:.3f}\t{}\t{}\n'.format( + f.write('{:.3f}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format( self._times[trial_idx][spike_idx], int(self._gids[trial_idx][spike_idx]), - self._types[trial_idx][spike_idx])) + self._types[trial_idx][spike_idx], + float(self._tstart), float(self._tstop))) diff --git a/hnn_core/parallel_backends.py b/hnn_core/parallel_backends.py index e2d753bbe..798938702 100644 --- a/hnn_core/parallel_backends.py +++ b/hnn_core/parallel_backends.py @@ -24,6 +24,8 @@ def _gather_trial_data(sim_data, net, n_trials): """ dpls = [] + net.spikes.update_trial_bounds(0.0, net.params['tstop']) + for idx in range(n_trials): dpls.append(sim_data[idx][0]) spikedata = sim_data[idx][1] diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 763d595ba..14ac92ed5 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -88,7 +88,8 @@ def test_spikes(tmpdir): spiketimes = [[2.3456, 7.89], [4.2812, 93.2]] spikegids = [[1, 3], [5, 7]] spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']] - spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) + spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes, + tstart=0.1, tstop=98.4) spikes.plot_hist(show=False) spikes.write(tmpdir.join('spk_%d.txt')) assert spikes == read_spikes(tmpdir.join('spk_*.txt')) @@ -97,15 +98,28 @@ def test_spikes(tmpdir): with pytest.raises(TypeError, match="times should be a list of lists"): spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids, - types=spiketypes) + types=spiketypes, tstart=0.1, tstop=98.4) with pytest.raises(TypeError, match="times should be a list of lists"): - spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes) + spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes, + tstart=0.1, tstop=98.4) with pytest.raises(ValueError, match="times, gids, and types should be " "lists of the same length"): spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids, - types=spiketypes) + types=spiketypes, tstart=0.1, tstop=98.4) + + with pytest.raises(ValueError, match="tstart and tstop must be of type " + "int or float"): + spikes = Spikes() + spikes.update_trial_bounds(tstart=0.1, tstop='ABC') + + with pytest.raises(ValueError, match="tstop must be greater than tstart"): + spikes = Spikes() + spikes.update_trial_bounds(tstart=0.1, tstop=-1.0) + + spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes, + tstart=0.1, tstop=98.4) with pytest.raises(TypeError, match="spike_types should be str, " "list, dict, or None"): From 9821a09dcba1f2e5ada8ec414c56dd77b33b5a86 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Thu, 10 Sep 2020 20:25:57 -0400 Subject: [PATCH 06/20] Change variable name --- hnn_core/network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 5cad75273..4c0b0f96c 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -536,13 +536,13 @@ def mean_rates(self, mean_type='all'): gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids))) trial_data = zip(self._types, self._gids) - for trial, (spike_types, spike_gids) in enumerate(trial_data): + for trial_idx, (spike_types, spike_gids) in enumerate(trial_data): trial_type_mask = np.in1d(spike_types, cell_type) gid, gid_counts = np.unique(np.array( spike_gids)[trial_type_mask], return_counts=True) - gid_spike_rate[trial, cell_type_gids == gid] = (gid_counts / ( - self._tstop - self._tstart)) * 1000 + gid_spike_rate[trial_idx, cell_type_gids == gid] = ( + gid_counts / (self._tstop - self._tstart)) * 1000 if mean_type == 'all': spike_rates[cell_type] = np.mean( From 760feb22451c91bc933ea4a4afd4118d5d248f56 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Thu, 10 Sep 2020 20:30:07 -0400 Subject: [PATCH 07/20] DOC: Fix tstart/tstop descriptions --- hnn_core/network.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 4c0b0f96c..a330b51c9 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -346,8 +346,10 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network().gid_dict. - tstart, tstop : int | float | None - Values defining the start and stop times of all trials. + tstart : int | float | None + Value defining the start time of all trials. + tstop : int | float | None + Value defining the stop time of all trials. Attributes ---------- @@ -363,8 +365,10 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network::gid_dict. - tstart, tstop : int | float | None - Values defining the start and stop times of all trials. + tstart : int | float | None + Value defining the start time of all trials. + tstop : int | float | None + Value defining the stop time of all trials. Methods ------- @@ -489,8 +493,10 @@ def update_trial_bounds(self, tstart, tstop): Parameters ---------- - tstart, tstop : float | None - Values defining the start and stop times of all trials. + tstart : int | float | None + Value defining the start time of all trials. + tstop : int | float | None + Value defining the stop time of all trials. """ From b40cf73c35c7d4c016befe8ea026e4319e5cb6f6 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Thu, 10 Sep 2020 20:42:45 -0400 Subject: [PATCH 08/20] Remove unecessary else clause from read_spikes() --- hnn_core/network.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index a330b51c9..6a0597d9f 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -74,9 +74,6 @@ def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): np.unique(spike_tstop)) == 1: spike_tstart = spike_tstart[0] spike_tstop = spike_tstop[0] - else: - spike_tstart = np.min(sum(spike_times, [])) - spike_tstop = np.max(sum(spike_times, [])) spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types, tstart=spike_tstart, tstop=spike_tstop) From 2338feb26165c196e2281b2fad87470bf6ff2f10 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 12 Sep 2020 15:53:15 -0400 Subject: [PATCH 09/20] DOC: Add description clarifying tstart/tstop only necessary for legacy files --- hnn_core/network.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 6a0597d9f..158c01341 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -25,8 +25,12 @@ def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): containing the range of Cell or input IDs of different cell or input types. If None, each spike file must contain a 3rd column for spike type. - tstart, tstop : float | None - Values defining the start and stop times of all trials. + tstart : int | float | None + Value defining the start time of all trials. + Only relevant for legacy files. + tstop : int | float | None + Value defining the stop time of all trials. + Only relevant for legacy files. Returns ---------- @@ -345,8 +349,10 @@ class Spikes(object): Each gid corresponds to a type via Network().gid_dict. tstart : int | float | None Value defining the start time of all trials. + Only relevant for legacy files. tstop : int | float | None Value defining the stop time of all trials. + Only relevant for legacy files. Attributes ---------- @@ -364,8 +370,10 @@ class Spikes(object): Each gid corresponds to a type via Network::gid_dict. tstart : int | float | None Value defining the start time of all trials. + Only relevant for legacy files. tstop : int | float | None Value defining the stop time of all trials. + Only relevant for legacy files. Methods ------- @@ -492,8 +500,10 @@ def update_trial_bounds(self, tstart, tstop): ---------- tstart : int | float | None Value defining the start time of all trials. + Only relevant for legacy files. tstop : int | float | None Value defining the stop time of all trials. + Only relevant for legacy files. """ From 9eba0c30588aea99619b473e5f454f4205b6bb11 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 12 Sep 2020 18:49:23 -0400 Subject: [PATCH 10/20] TST: Raise error if tstart/tstop present in file and user provides values manually --- hnn_core/network.py | 10 ++++++---- hnn_core/tests/test_network.py | 3 +++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 158c01341..ee66780ca 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -64,12 +64,14 @@ def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): spike_types += [[]] if spike_trial.shape[1] == 5: + if tstart is not None or tstop is not None: + raise ValueError("tstart and tstop are specified in both the " + "file %s and function call. " % (file,)) spike_tstart += np.unique(spike_trial[:, 3].astype(float)).tolist() spike_tstop += np.unique(spike_trial[:, 4].astype(float)).tolist() - else: - if tstart is None or tstop is None: - raise ValueError("tstart and tstop must be provided if values " - "are unspecified in the file %s" % (file,)) + elif tstart is None or tstop is None: + raise ValueError("tstart and tstop must be provided if values " + "are unspecified in the file %s" % (file,)) if len(np.unique(spike_tstart)) > 1 or len(np.unique(spike_tstop)) > 1: raise ValueError("tstart and tstop must match across files.") diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 14ac92ed5..c2c31c2e5 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -95,6 +95,9 @@ def test_spikes(tmpdir): assert spikes == read_spikes(tmpdir.join('spk_*.txt')) assert ("Spikes | 2 simulation trials" in repr(spikes)) + with pytest.raises(ValueError, match="tstart and tstop are specified in " + "both the file /tmp/spk_0.txt and function call."): + read_spikes('/tmp/spk_*.txt', tstart=0.1, tstop=98.4) with pytest.raises(TypeError, match="times should be a list of lists"): spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids, From 7777891c40157c83db83ce57b242815a20743931 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 12 Sep 2020 19:07:27 -0400 Subject: [PATCH 11/20] TST: Add tests asserting correct calculation of spikes.mean_rates() --- hnn_core/tests/test_network.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index c2c31c2e5..8ea20c4c1 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -95,6 +95,21 @@ def test_spikes(tmpdir): assert spikes == read_spikes(tmpdir.join('spk_*.txt')) assert ("Spikes | 2 simulation trials" in repr(spikes)) + assert spikes.mean_rates() == { + 'L5_pyramidal': 5.08646998982706, + 'L5_basket': 5.08646998982706, + 'L2_pyramidal': 5.08646998982706, + 'L2_basket': 5.08646998982706} + assert spikes.mean_rates(mean_type='trial') == { + 'L5_pyramidal': [0.0, 10.17293997965412], + 'L5_basket': [0.0, 10.17293997965412], + 'L2_pyramidal': [10.17293997965412, 0.0], + 'L2_basket': [10.17293997965412, 0.0]} + assert spikes.mean_rates(mean_type='cell') == { + 'L5_pyramidal': [[0.0], [10.17293997965412]], + 'L5_basket': [[0.0], [10.17293997965412]], + 'L2_pyramidal': [[10.17293997965412], [0.0]], + 'L2_basket': [[10.17293997965412], [0.0]]} with pytest.raises(ValueError, match="tstart and tstop are specified in " "both the file /tmp/spk_0.txt and function call."): read_spikes('/tmp/spk_*.txt', tstart=0.1, tstop=98.4) From 37f9ba76808ef13e312d114fa1a54c861d00b958 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Wed, 16 Sep 2020 19:10:46 -0700 Subject: [PATCH 12/20] DOC: Update whats_new.rst --- doc/whats_new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 154ac0bd8..ff1deb33d 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -29,6 +29,8 @@ Changelog - Update plot_hist_input() to plot_spikes_hist() which can plot histogram of spikes for any cell type, by `Nick Tolley`_ in `#157 `_ +- Add function to compute mean spike rates with user specified calculation type, by `Nick Tolley` and `Mainak Jas`_ in `#155 `_ + Bug ~~~ From 343b68a06eb2c06760966d4235053e7693e18b53 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Fri, 25 Sep 2020 20:14:28 -0700 Subject: [PATCH 13/20] Remove tstart and tstop as attributes, instead pass as parameters to mean_rates() --- hnn_core/network.py | 124 ++++++--------------------------- hnn_core/parallel_backends.py | 2 - hnn_core/tests/test_network.py | 60 +++++++--------- 3 files changed, 48 insertions(+), 138 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index ee66780ca..24c43c2ff 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -25,12 +25,6 @@ def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): containing the range of Cell or input IDs of different cell or input types. If None, each spike file must contain a 3rd column for spike type. - tstart : int | float | None - Value defining the start time of all trials. - Only relevant for legacy files. - tstop : int | float | None - Value defining the stop time of all trials. - Only relevant for legacy files. Returns ---------- @@ -41,57 +35,27 @@ def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): spike_times = [] spike_gids = [] spike_types = [] -<<<<<<< HEAD for file in sorted(glob(str(fname))): -======= - spike_tstart = [] - spike_tstop = [] - for file in sorted(glob(fname)): ->>>>>>> Add tstart and tstop attributes to Spikes class, update Network/Spikes functions and tests accordingly spike_trial = np.loadtxt(file, dtype=str) spike_times += [list(spike_trial[:, 0].astype(float))] spike_gids += [list(spike_trial[:, 1].astype(int))] # Note that legacy HNN 'spk.txt' files don't contain a 3rd column for - # spike type, or a 4th and 5th column for tstart and tstop. If reading - # a legacy version, validate that gid_dict, tstart, and tstop provided. - if spike_trial.shape[1] >= 3: + # spike type. If reading a legacy version, validate that a gid_dict is + # provided. + if spike_trial.shape[1] == 3: spike_types += [list(spike_trial[:, 2].astype(str))] - elif spike_trial.shape[1] == 2: + else: if gid_dict is None: raise ValueError("gid_dict must be provided if spike types " "are unspecified in the file %s" % (file,)) spike_types += [[]] - if spike_trial.shape[1] == 5: - if tstart is not None or tstop is not None: - raise ValueError("tstart and tstop are specified in both the " - "file %s and function call. " % (file,)) - spike_tstart += np.unique(spike_trial[:, 3].astype(float)).tolist() - spike_tstop += np.unique(spike_trial[:, 4].astype(float)).tolist() - elif tstart is None or tstop is None: - raise ValueError("tstart and tstop must be provided if values " - "are unspecified in the file %s" % (file,)) - - if len(np.unique(spike_tstart)) > 1 or len(np.unique(spike_tstop)) > 1: - raise ValueError("tstart and tstop must match across files.") - - elif len(np.unique(spike_tstart)) == 1 and len( - np.unique(spike_tstop)) == 1: - spike_tstart = spike_tstart[0] - spike_tstop = spike_tstop[0] - - spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types, - tstart=spike_tstart, tstop=spike_tstop) - + spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types) if gid_dict is not None: spikes.update_types(gid_dict) - if tstart is not None and tstop is not None: - spikes.update_trial_bounds(spike_tstart, spike_tstop) - - return Spikes(times=spike_times, gids=spike_gids, types=spike_types, - tstart=spike_tstart, tstop=spike_tstop) + return Spikes(times=spike_times, gids=spike_gids, types=spike_types) def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff=1307.4): @@ -349,12 +313,6 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network().gid_dict. - tstart : int | float | None - Value defining the start time of all trials. - Only relevant for legacy files. - tstop : int | float | None - Value defining the stop time of all trials. - Only relevant for legacy files. Attributes ---------- @@ -370,12 +328,6 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network::gid_dict. - tstart : int | float | None - Value defining the start time of all trials. - Only relevant for legacy files. - tstop : int | float | None - Value defining the stop time of all trials. - Only relevant for legacy files. Methods ------- @@ -384,15 +336,14 @@ class Spikes(object): plot(ax=None, show=True) Plot and return a matplotlib Figure object showing the aggregate network spiking activity according to cell type. - mean_rates(mean_type='all') + mean_rates(tstart, tstop, gid_dict, mean_type='all') Calculate mean firing rate for each cell type. Specify averaging method with mean_type argument. write(fname) Write spiking activity to a collection of spike trial files. """ - def __init__(self, times=None, gids=None, types=None, - tstart=None, tstop=None): + def __init__(self, times=None, gids=None, types=None): if times is None: times = list() if gids is None: @@ -422,8 +373,6 @@ def __init__(self, times=None, gids=None, types=None, self._times = times self._gids = gids self._types = types - self._tstart = tstart - self._tstop = tstop def __repr__(self): class_name = self.__class__.__name__ @@ -440,9 +389,7 @@ def __eq__(self, other): for trial in other._times] return (times_self == times_other and self._gids == other._gids and - self._types == other._types and - self._tstart == other._tstart and - self._tstop == other._tstop) + self._types == other._types) @property def times(self): @@ -456,14 +403,6 @@ def gids(self): def types(self): return self._types - @property - def tstart(self): - return self._tstart - - @property - def tstop(self): - return self._tstop - def update_types(self, gid_dict): """Update spike types in the current instance of Spikes. @@ -495,34 +434,19 @@ def update_types(self, gid_dict): spike_types += [list(spike_types_trial)] self._types = spike_types - def update_trial_bounds(self, tstart, tstop): - """Update tstart and tstop in the current instance of Spikes. + def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'): + """Mean spike rates (Hz) by cell type. Parameters ---------- tstart : int | float | None Value defining the start time of all trials. - Only relevant for legacy files. tstop : int | float | None Value defining the stop time of all trials. - Only relevant for legacy files. - - """ - - # Validate tstart, tstop - if not isinstance(tstart, (int, float)) or not isinstance( - tstop, (int, float)): - raise ValueError('tstart and tstop must be of type int or float') - elif tstop <= tstart: - raise ValueError('tstop must be greater than tstart') - self._tstart = tstart - self._tstop = tstop - - def mean_rates(self, mean_type='all'): - """Mean spike rates (Hz) by cell type. - - Parameters - ---------- + gid_dict : dict of lists or range objects + Dictionary with keys 'evprox1', 'evdist1' etc. + containing the range of Cell or input IDs of different + cell or input types. mean_type : str 'all' : Average over trials and cells Returns mean rate for cell types @@ -538,26 +462,23 @@ def mean_rates(self, mean_type='all'): """ cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] spike_rates = dict() - all_spike_types = np.array(sum(self._types, [])) - all_spike_gids = np.array(sum(self._gids, [])) if mean_type not in ['all', 'trial', 'cell']: raise ValueError("Invalid mean_type. Valid arguments include " "'all', 'trial', or 'cell'.") for cell_type in cell_types: - type_mask = np.in1d(all_spike_types, cell_type) - cell_type_gids = np.unique(all_spike_gids[type_mask]) + cell_type_gids = np.array(gid_dict[cell_type]) gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids))) trial_data = zip(self._types, self._gids) for trial_idx, (spike_types, spike_gids) in enumerate(trial_data): trial_type_mask = np.in1d(spike_types, cell_type) - gid, gid_counts = np.unique(np.array( + gids, gid_counts = np.unique(np.array( spike_gids)[trial_type_mask], return_counts=True) - gid_spike_rate[trial_idx, cell_type_gids == gid] = ( - gid_counts / (self._tstop - self._tstart)) * 1000 + gid_spike_rate[trial_idx, np.in1d(cell_type_gids, gids)] = ( + gid_counts / (tstop - tstart)) * 1000 if mean_type == 'all': spike_rates[cell_type] = np.mean( @@ -635,15 +556,12 @@ def write(self, fname): 1) spike time (s), 2) spike gid, and 3) gid type - 4) tstart - 5) tstop """ for trial_idx in range(len(self._times)): with open(str(fname) % (trial_idx,), 'w') as f: for spike_idx in range(len(self._times[trial_idx])): - f.write('{:.3f}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format( + f.write('{:.3f}\t{}\t{}\t\n'.format( self._times[trial_idx][spike_idx], int(self._gids[trial_idx][spike_idx]), - self._types[trial_idx][spike_idx], - float(self._tstart), float(self._tstop))) + self._types[trial_idx][spike_idx])) diff --git a/hnn_core/parallel_backends.py b/hnn_core/parallel_backends.py index 798938702..e2d753bbe 100644 --- a/hnn_core/parallel_backends.py +++ b/hnn_core/parallel_backends.py @@ -24,8 +24,6 @@ def _gather_trial_data(sim_data, net, n_trials): """ dpls = [] - net.spikes.update_trial_bounds(0.0, net.params['tstop']) - for idx in range(n_trials): dpls.append(sim_data[idx][0]) spikedata = sim_data[idx][1] diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 8ea20c4c1..14bac7c0f 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -88,56 +88,50 @@ def test_spikes(tmpdir): spiketimes = [[2.3456, 7.89], [4.2812, 93.2]] spikegids = [[1, 3], [5, 7]] spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']] - spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes, - tstart=0.1, tstop=98.4) + spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) spikes.plot_hist(show=False) spikes.write(tmpdir.join('spk_%d.txt')) assert spikes == read_spikes(tmpdir.join('spk_*.txt')) assert ("Spikes | 2 simulation trials" in repr(spikes)) - assert spikes.mean_rates() == { - 'L5_pyramidal': 5.08646998982706, - 'L5_basket': 5.08646998982706, - 'L2_pyramidal': 5.08646998982706, - 'L2_basket': 5.08646998982706} - assert spikes.mean_rates(mean_type='trial') == { - 'L5_pyramidal': [0.0, 10.17293997965412], - 'L5_basket': [0.0, 10.17293997965412], - 'L2_pyramidal': [10.17293997965412, 0.0], - 'L2_basket': [10.17293997965412, 0.0]} - assert spikes.mean_rates(mean_type='cell') == { - 'L5_pyramidal': [[0.0], [10.17293997965412]], - 'L5_basket': [[0.0], [10.17293997965412]], - 'L2_pyramidal': [[10.17293997965412], [0.0]], - 'L2_basket': [[10.17293997965412], [0.0]]} - with pytest.raises(ValueError, match="tstart and tstop are specified in " - "both the file /tmp/spk_0.txt and function call."): - read_spikes('/tmp/spk_*.txt', tstart=0.1, tstop=98.4) + # assert spikes.mean_rates() == { + # 'L5_pyramidal': 5.08646998982706, + # 'L5_basket': 5.08646998982706, + # 'L2_pyramidal': 5.08646998982706, + # 'L2_basket': 5.08646998982706} + # assert spikes.mean_rates(mean_type='trial') == { + # 'L5_pyramidal': [0.0, 10.17293997965412], + # 'L5_basket': [0.0, 10.17293997965412], + # 'L2_pyramidal': [10.17293997965412, 0.0], + # 'L2_basket': [10.17293997965412, 0.0]} + # assert spikes.mean_rates(mean_type='cell') == { + # 'L5_pyramidal': [[0.0], [10.17293997965412]], + # 'L5_basket': [[0.0], [10.17293997965412]], + # 'L2_pyramidal': [[10.17293997965412], [0.0]], + # 'L2_basket': [[10.17293997965412], [0.0]]} with pytest.raises(TypeError, match="times should be a list of lists"): spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids, - types=spiketypes, tstart=0.1, tstop=98.4) + types=spiketypes) with pytest.raises(TypeError, match="times should be a list of lists"): - spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes, - tstart=0.1, tstop=98.4) + spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes) with pytest.raises(ValueError, match="times, gids, and types should be " "lists of the same length"): spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids, - types=spiketypes, tstart=0.1, tstop=98.4) + types=spiketypes) - with pytest.raises(ValueError, match="tstart and tstop must be of type " - "int or float"): - spikes = Spikes() - spikes.update_trial_bounds(tstart=0.1, tstop='ABC') + # with pytest.raises(ValueError, match="tstart and tstop must be of type " + # "int or float"): + # spikes = Spikes() + # spikes.update_trial_bounds(tstart=0.1, tstop='ABC') - with pytest.raises(ValueError, match="tstop must be greater than tstart"): - spikes = Spikes() - spikes.update_trial_bounds(tstart=0.1, tstop=-1.0) + # with pytest.raises(ValueError, match="tstop must be greater than tstart"): + # spikes = Spikes() + # spikes.update_trial_bounds(tstart=0.1, tstop=-1.0) - spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes, - tstart=0.1, tstop=98.4) + spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) with pytest.raises(TypeError, match="spike_types should be str, " "list, dict, or None"): From 4c4c92fa7705035e8abaeeaf161a4020d1323798 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Fri, 25 Sep 2020 20:28:28 -0700 Subject: [PATCH 14/20] TST: Add tests validating correct tstart/tstop entry --- hnn_core/tests/test_network.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 14bac7c0f..4483b28bb 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -122,14 +122,14 @@ def test_spikes(tmpdir): spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids, types=spiketypes) - # with pytest.raises(ValueError, match="tstart and tstop must be of type " - # "int or float"): - # spikes = Spikes() - # spikes.update_trial_bounds(tstart=0.1, tstop='ABC') + with pytest.raises(ValueError, match="tstart and tstop must be of type " + "int or float"): + spikes = Spikes() + spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict={}) - # with pytest.raises(ValueError, match="tstop must be greater than tstart"): - # spikes = Spikes() - # spikes.update_trial_bounds(tstart=0.1, tstop=-1.0) + with pytest.raises(ValueError, match="tstop must be greater than tstart"): + spikes = Spikes() + spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict={}) spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) @@ -151,7 +151,7 @@ def test_spikes(tmpdir): with pytest.raises(ValueError, match="Invalid mean_type. Valid " "arguments include 'all', 'trial', or 'cell'."): - spikes.mean_rates(mean_type='ABC') + spikes.mean_rates(tstart=0.1, tstop=98.4, gid_dict={}, mean_type='ABC') # Write spike file with no 'types' column # Check for gid_dict errors From d69566307e86aaa886d80fb62b327e62f9d6f98d Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Fri, 25 Sep 2020 21:15:42 -0700 Subject: [PATCH 15/20] TST: Add tests asserting correct calculation of mean rates --- hnn_core/network.py | 7 +++++ hnn_core/tests/test_network.py | 55 ++++++++++++++++++---------------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 24c43c2ff..4c5fa7a6a 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -467,6 +467,13 @@ def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'): raise ValueError("Invalid mean_type. Valid arguments include " "'all', 'trial', or 'cell'.") + # Validate tstart, tstop + if not isinstance(tstart, (int, float)) or not isinstance( + tstop, (int, float)): + raise ValueError('tstart and tstop must be of type int or float') + elif tstop <= tstart: + raise ValueError('tstop must be greater than tstart') + for cell_type in cell_types: cell_type_gids = np.array(gid_dict[cell_type]) gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids))) diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 4483b28bb..dff21e894 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -88,27 +88,15 @@ def test_spikes(tmpdir): spiketimes = [[2.3456, 7.89], [4.2812, 93.2]] spikegids = [[1, 3], [5, 7]] spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']] + tstart, tstop = 0.1, 98.4 + gid_dict = {'L2_pyramidal': range(1, 2), 'L2_basket': range(3, 4), + 'L5_pyramidal': range(5, 6), 'L5_basket': range(7, 8)} spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) spikes.plot_hist(show=False) spikes.write(tmpdir.join('spk_%d.txt')) assert spikes == read_spikes(tmpdir.join('spk_*.txt')) assert ("Spikes | 2 simulation trials" in repr(spikes)) - # assert spikes.mean_rates() == { - # 'L5_pyramidal': 5.08646998982706, - # 'L5_basket': 5.08646998982706, - # 'L2_pyramidal': 5.08646998982706, - # 'L2_basket': 5.08646998982706} - # assert spikes.mean_rates(mean_type='trial') == { - # 'L5_pyramidal': [0.0, 10.17293997965412], - # 'L5_basket': [0.0, 10.17293997965412], - # 'L2_pyramidal': [10.17293997965412, 0.0], - # 'L2_basket': [10.17293997965412, 0.0]} - # assert spikes.mean_rates(mean_type='cell') == { - # 'L5_pyramidal': [[0.0], [10.17293997965412]], - # 'L5_basket': [[0.0], [10.17293997965412]], - # 'L2_pyramidal': [[10.17293997965412], [0.0]], - # 'L2_basket': [[10.17293997965412], [0.0]]} with pytest.raises(TypeError, match="times should be a list of lists"): spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids, @@ -122,15 +110,6 @@ def test_spikes(tmpdir): spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids, types=spiketypes) - with pytest.raises(ValueError, match="tstart and tstop must be of type " - "int or float"): - spikes = Spikes() - spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict={}) - - with pytest.raises(ValueError, match="tstop must be greater than tstart"): - spikes = Spikes() - spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict={}) - spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) with pytest.raises(TypeError, match="spike_types should be str, " @@ -149,9 +128,35 @@ def test_spikes(tmpdir): with pytest.raises(ValueError, match="No input types found for ABC"): spikes.plot_hist(spike_types='ABC', show=False) + with pytest.raises(ValueError, match="tstart and tstop must be of type " + "int or float"): + spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict=gid_dict) + + with pytest.raises(ValueError, match="tstop must be greater than tstart"): + spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict=gid_dict) + with pytest.raises(ValueError, match="Invalid mean_type. Valid " "arguments include 'all', 'trial', or 'cell'."): - spikes.mean_rates(tstart=0.1, tstop=98.4, gid_dict={}, mean_type='ABC') + spikes.mean_rates(tstart=tstart, tstop=tstop, gid_dict=gid_dict, + mean_type='ABC') + + test_rate = (1 / (tstop - tstart)) * 1000 + + assert spikes.mean_rates(tstart, tstop, gid_dict) == { + 'L5_pyramidal': test_rate / 2, + 'L5_basket': test_rate / 2, + 'L2_pyramidal': test_rate / 2, + 'L2_basket': test_rate / 2} + assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='trial') == { + 'L5_pyramidal': [0.0, test_rate], + 'L5_basket': [0.0, test_rate], + 'L2_pyramidal': [test_rate, 0.0], + 'L2_basket': [test_rate, 0.0]} + assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='cell') == { + 'L5_pyramidal': [[0.0], [test_rate]], + 'L5_basket': [[0.0], [test_rate]], + 'L2_pyramidal': [[test_rate], [0.0]], + 'L2_basket': [[test_rate], [0.0]]} # Write spike file with no 'types' column # Check for gid_dict errors From 4afc94aa1ca379326ac9e088cf97862af61ff04e Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 26 Sep 2020 10:46:35 -0700 Subject: [PATCH 16/20] Make trial and cell counts explicit for gid_spike_rate preallocation --- hnn_core/network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 4c5fa7a6a..ac4698523 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -476,7 +476,8 @@ def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'): for cell_type in cell_types: cell_type_gids = np.array(gid_dict[cell_type]) - gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids))) + n_trials, n_cells = len(self._times), len(cell_type_gids) + gid_spike_rate = np.zeros((n_trials, n_cells)) trial_data = zip(self._types, self._gids) for trial_idx, (spike_types, spike_gids) in enumerate(trial_data): From 7f536093b85c81465cbb46f23a8b463b2d0d3a70 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 26 Sep 2020 10:57:52 -0700 Subject: [PATCH 17/20] Minor changes to increase legibility, remove leftover tstart/tstop code --- hnn_core/network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index ac4698523..ee12f6716 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -12,7 +12,7 @@ from .viz import plot_spikes_hist, plot_spikes_raster, plot_cells -def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): +def read_spikes(fname, gid_dict=None): """Read spiking activity from a collection of spike trial files. Parameters @@ -490,7 +490,7 @@ def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'): if mean_type == 'all': spike_rates[cell_type] = np.mean( - np.mean(gid_spike_rate, axis=1)) + gid_spike_rate.mean(axis=1)) if mean_type == 'trial': spike_rates[cell_type] = np.mean( gid_spike_rate, axis=1).tolist() @@ -569,7 +569,7 @@ def write(self, fname): for trial_idx in range(len(self._times)): with open(str(fname) % (trial_idx,), 'w') as f: for spike_idx in range(len(self._times[trial_idx])): - f.write('{:.3f}\t{}\t{}\t\n'.format( + f.write('{:.3f}\t{}\t{}\n'.format( self._times[trial_idx][spike_idx], int(self._gids[trial_idx][spike_idx]), self._types[trial_idx][spike_idx])) From a180b7d97b510100c01a2c50651f3cf1d275be0b Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 26 Sep 2020 11:00:59 -0700 Subject: [PATCH 18/20] TST: Print invalid input during mean_type value error --- hnn_core/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index ee12f6716..a1ec6786c 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -465,7 +465,7 @@ def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'): if mean_type not in ['all', 'trial', 'cell']: raise ValueError("Invalid mean_type. Valid arguments include " - "'all', 'trial', or 'cell'.") + f"'all', 'trial', or 'cell'. Got {mean_type}") # Validate tstart, tstop if not isinstance(tstart, (int, float)) or not isinstance( From 85731408a2349b4d97c664ca37d0d5e4a24b3936 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 26 Sep 2020 11:33:04 -0700 Subject: [PATCH 19/20] Update example with mean spike rate calculation --- examples/plot_simulate_evoked.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/plot_simulate_evoked.py b/examples/plot_simulate_evoked.py index befc6b522..82202d229 100644 --- a/examples/plot_simulate_evoked.py +++ b/examples/plot_simulate_evoked.py @@ -66,6 +66,19 @@ spikes = read_spikes(op.join(tmp_dir_name, 'spk_*.txt')) spikes.plot() +############################################################################### +# We can additionally calculate the mean spike rates for each cell class by +# specifying a time window with tstart and tstop. +all_rates = spikes.mean_rates(tstart=0, tstop=170, gid_dict=net.gid_dict, + mean_type='all') +trial_rates = spikes.mean_rates(tstart=0, tstop=170, gid_dict=net.gid_dict, + mean_type='trial') +print('Mean spike rates across trials:') +print(all_rates) +print('Mean spike rates for individual trials:') +print(trial_rates) + + ############################################################################### # Now, let us try to make the exogenous driving inputs to the cells # synchronous and see what happens From f24257dfeed669a6600c471c915bb026b26191f7 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 26 Sep 2020 11:38:37 -0700 Subject: [PATCH 20/20] DOC: Fix hyperlink --- doc/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index ff1deb33d..904011536 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -29,7 +29,7 @@ Changelog - Update plot_hist_input() to plot_spikes_hist() which can plot histogram of spikes for any cell type, by `Nick Tolley`_ in `#157 `_ -- Add function to compute mean spike rates with user specified calculation type, by `Nick Tolley` and `Mainak Jas`_ in `#155 `_ +- Add function to compute mean spike rates with user specified calculation type, by `Nick Tolley`_ and `Mainak Jas`_ in `#155 `_ Bug ~~~