diff --git a/hnn_core/network.py b/hnn_core/network.py index 43a2c30d6..2e95b4aff 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -12,7 +12,7 @@ from .viz import plot_spikes_hist, plot_spikes_raster, plot_cells -def read_spikes(fname, gid_dict=None): +def read_spikes(fname, gid_dict=None, tstart=None, tstop=None): """Read spiking activity from a collection of spike trial files. Parameters @@ -25,6 +25,8 @@ def read_spikes(fname, gid_dict=None): containing the range of Cell or input IDs of different cell or input types. If None, each spike file must contain a 3rd column for spike type. + tstart, tstop : float | None + Values defining the start and stop times of all trials. Returns ---------- @@ -35,27 +37,54 @@ def read_spikes(fname, gid_dict=None): spike_times = [] spike_gids = [] spike_types = [] + spike_tstart = [] + spike_tstop = [] for file in sorted(glob(fname)): spike_trial = np.loadtxt(file, dtype=str) spike_times += [list(spike_trial[:, 0].astype(float))] spike_gids += [list(spike_trial[:, 1].astype(int))] # Note that legacy HNN 'spk.txt' files don't contain a 3rd column for - # spike type. If reading a legacy version, validate that a gid_dict is - # provided. - if spike_trial.shape[1] == 3: + # spike type, or a 4th and 5th column for tstart and tstop. If reading + # a legacy version, validate that gid_dict, tstart, and tstop provided. + if spike_trial.shape[1] >= 3: spike_types += [list(spike_trial[:, 2].astype(str))] - else: + elif spike_trial.shape[1] == 2: if gid_dict is None: raise ValueError("gid_dict must be provided if spike types " "are unspecified in the file %s" % (file,)) spike_types += [[]] - spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types) + if spike_trial.shape[1] == 5: + spike_tstart += np.unique(spike_trial[:, 3].astype(float)).tolist() + spike_tstop += np.unique(spike_trial[:, 4].astype(float)).tolist() + else: + if tstart is None or tstop is None: + raise ValueError("tstart and tstop must be provided if values " + "are unspecified in the file %s" % (file,)) + + if len(np.unique(spike_tstart)) > 1 or len(np.unique(spike_tstop)) > 1: + raise ValueError("tstart and tstop must match across files.") + + elif len(np.unique(spike_tstart)) == 1 and len( + np.unique(spike_tstop)) == 1: + spike_tstart = spike_tstart[0] + spike_tstop = spike_tstop[0] + else: + spike_tstart = np.min(sum(spike_times, [])) + spike_tstop = np.max(sum(spike_times, [])) + + spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types, + tstart=spike_tstart, tstop=spike_tstop) + if gid_dict is not None: spikes.update_types(gid_dict) - return Spikes(times=spike_times, gids=spike_gids, types=spike_types) + if tstart is not None and tstop is not None: + spikes.update_trial_bounds(spike_tstart, spike_tstop) + + return Spikes(times=spike_times, gids=spike_gids, types=spike_types, + tstart=spike_tstart, tstop=spike_tstop) def _create_coords(n_pyr_x, n_pyr_y, n_common_feeds, p_unique_keys, @@ -330,6 +359,8 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network().gid_dict. + tstart, tstop : int | float | None + Values defining the start and stop times of all trials. Attributes ---------- @@ -345,6 +376,8 @@ class Spikes(object): The inner list contains the type of spike (e.g., evprox1 or L2_pyramidal) that occured at the corresonding time stamp. Each gid corresponds to a type via Network::gid_dict. + tstart, tstop : int | float | None + Values defining the start and stop times of all trials. Methods ------- @@ -353,11 +386,15 @@ class Spikes(object): plot(ax=None, show=True) Plot and return a matplotlib Figure object showing the aggregate network spiking activity according to cell type. + mean_rates(mean_type='all') + Calculate mean firing rate for each cell type. Specify + averaging method with mean_type argument. write(fname) Write spiking activity to a collection of spike trial files. """ - def __init__(self, times=None, gids=None, types=None): + def __init__(self, times=None, gids=None, types=None, + tstart=None, tstop=None): if times is None: times = list() if gids is None: @@ -387,6 +424,8 @@ def __init__(self, times=None, gids=None, types=None): self._times = times self._gids = gids self._types = types + self._tstart = tstart + self._tstop = tstop def __repr__(self): class_name = self.__class__.__name__ @@ -403,7 +442,9 @@ def __eq__(self, other): for trial in other._times] return (times_self == times_other and self._gids == other._gids and - self._types == other._types) + self._types == other._types and + self._tstart == other._tstart and + self._tstop == other._tstop) @property def times(self): @@ -417,6 +458,14 @@ def gids(self): def types(self): return self._types + @property + def tstart(self): + return self._tstart + + @property + def tstop(self): + return self._tstop + def update_types(self, gid_dict): """Update spike types in the current instance of Spikes. @@ -448,6 +497,25 @@ def update_types(self, gid_dict): spike_types += [list(spike_types_trial)] self._types = spike_types + def update_trial_bounds(self, tstart, tstop): + """Update tstart and tstop in the current instance of Spikes. + + Parameters + ---------- + tstart, tstop : float | None + Values defining the start and stop times of all trials. + + """ + + # Validate tstart, tstop + if not isinstance(tstart, (int, float)) or not isinstance( + tstop, (int, float)): + raise ValueError('tstart and tstop must be of type int or float') + elif tstop <= tstart: + raise ValueError('tstop must be greater than tstart') + self._tstart = tstart + self._tstop = tstop + def mean_rates(self, mean_type='all'): """Mean spike rates (Hz) by cell type. @@ -468,10 +536,8 @@ def mean_rates(self, mean_type='all'): """ cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] spike_rates = dict() - all_spike_times = np.array(sum(self._times, [])) all_spike_types = np.array(sum(self._types, [])) all_spike_gids = np.array(sum(self._gids, [])) - tstart, tstop = min(all_spike_times), max(all_spike_times) if mean_type not in ['all', 'trial', 'cell']: raise ValueError("Invalid mean_type. Valid arguments include " @@ -489,7 +555,7 @@ def mean_rates(self, mean_type='all'): spike_gids)[trial_type_mask], return_counts=True) gid_spike_rate[trial, cell_type_gids == gid] = (gid_counts / ( - tstop - tstart)) * 1000 + self._tstop - self._tstart)) * 1000 if mean_type == 'all': spike_rates[cell_type] = np.mean( @@ -567,12 +633,15 @@ def write(self, fname): 1) spike time (s), 2) spike gid, and 3) gid type + 4) tstart + 5) tstop """ for trial_idx in range(len(self._times)): with open(fname % (trial_idx,), 'w') as f: for spike_idx in range(len(self._times[trial_idx])): - f.write('{:.3f}\t{}\t{}\n'.format( + f.write('{:.3f}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format( self._times[trial_idx][spike_idx], int(self._gids[trial_idx][spike_idx]), - self._types[trial_idx][spike_idx])) + self._types[trial_idx][spike_idx], + float(self._tstart), float(self._tstop))) diff --git a/hnn_core/parallel_backends.py b/hnn_core/parallel_backends.py index 2ce714cc7..409a3c884 100644 --- a/hnn_core/parallel_backends.py +++ b/hnn_core/parallel_backends.py @@ -23,6 +23,8 @@ def _gather_trial_data(sim_data, net, n_trials): """ dpls = [] + net.spikes.update_trial_bounds(0.0, net.params['tstop']) + for idx in range(n_trials): dpls.append(sim_data[idx][0]) spikedata = sim_data[idx][1] diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index e9a68a83b..ce8e4a2d9 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -62,7 +62,8 @@ def test_spikes(): spiketimes = [[2.3456, 7.89], [4.2812, 93.2]] spikegids = [[1, 3], [5, 7]] spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']] - spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) + spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes, + tstart=0.1, tstop=98.4) spikes.plot_hist(show=False) spikes.write('/tmp/spk_%d.txt') assert spikes == read_spikes('/tmp/spk_*.txt') @@ -70,15 +71,28 @@ def test_spikes(): with pytest.raises(TypeError, match="times should be a list of lists"): spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids, - types=spiketypes) + types=spiketypes, tstart=0.1, tstop=98.4) with pytest.raises(TypeError, match="times should be a list of lists"): - spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes) + spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes, + tstart=0.1, tstop=98.4) with pytest.raises(ValueError, match="times, gids, and types should be " "lists of the same length"): spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids, - types=spiketypes) + types=spiketypes, tstart=0.1, tstop=98.4) + + with pytest.raises(ValueError, match="tstart and tstop must be of type " + "int or float"): + spikes = Spikes() + spikes.update_trial_bounds(tstart=0.1, tstop='ABC') + + with pytest.raises(ValueError, match="tstop must be greater than tstart"): + spikes = Spikes() + spikes.update_trial_bounds(tstart=0.1, tstop=-1.0) + + spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes, + tstart=0.1, tstop=98.4) with pytest.raises(TypeError, match="spike_types should be str, " "list, dict, or None"): @@ -107,9 +121,21 @@ def test_spikes(): np.savetxt(fname, times_gids_only, delimiter='\t', fmt='%s') with pytest.raises(ValueError, match="gid_dict must be provided if spike " "types are unspecified in the file /tmp/spk_0.txt"): - spikes = read_spikes('/tmp/spk_*.txt') + spikes = read_spikes('/tmp/spk_*.txt', tstart=0.1, tstop=98.4) with pytest.raises(ValueError, match="gid_dict should contain only " "disjoint sets of gid values"): gid_dict = {'L2_pyramidal': range(3), 'L2_basket': range(2, 4), 'L5_pyramidal': range(4, 6), 'L5_basket': range(6, 8)} + spikes = read_spikes('/tmp/spk_*.txt', gid_dict=gid_dict, + tstart=0.1, tstop=98.4) + + # Write spike file with no 'tstart' or 'tstop' columns + # Check for gid_dict errors + for fname in sorted(glob('/tmp/spk_*.txt')): + times_gids_types_only = np.loadtxt(fname, dtype=str)[:, (0, 1)] + np.savetxt(fname, times_gids_types_only, delimiter='\t', fmt='%s') + with pytest.raises(ValueError, match="tstart and tstop must be provided " + "if values are unspecified in the file /tmp/spk_0.txt"): + gid_dict = {'L2_pyramidal': 1, 'L2_basket': 3, + 'L5_pyramidal': 5, 'L5_basket': 7} spikes = read_spikes('/tmp/spk_*.txt', gid_dict=gid_dict)