Skip to content

Commit

Permalink
Add tstart and tstop attributes to Spikes class, update Network/Spike…
Browse files Browse the repository at this point in the history
…s functions and tests accordingly
  • Loading branch information
ntolley committed Sep 7, 2020
1 parent 929e3f9 commit b2ac4fe
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 19 deletions.
97 changes: 83 additions & 14 deletions hnn_core/network.py
Expand Up @@ -12,7 +12,7 @@
from .viz import plot_spikes_hist, plot_spikes_raster, plot_cells


def read_spikes(fname, gid_dict=None):
def read_spikes(fname, gid_dict=None, tstart=None, tstop=None):
"""Read spiking activity from a collection of spike trial files.
Parameters
Expand All @@ -25,6 +25,8 @@ def read_spikes(fname, gid_dict=None):
containing the range of Cell or input IDs of different
cell or input types. If None, each spike file must contain
a 3rd column for spike type.
tstart, tstop : float | None
Values defining the start and stop times of all trials.
Returns
----------
Expand All @@ -35,27 +37,54 @@ def read_spikes(fname, gid_dict=None):
spike_times = []
spike_gids = []
spike_types = []
spike_tstart = []
spike_tstop = []
for file in sorted(glob(fname)):
spike_trial = np.loadtxt(file, dtype=str)
spike_times += [list(spike_trial[:, 0].astype(float))]
spike_gids += [list(spike_trial[:, 1].astype(int))]

# Note that legacy HNN 'spk.txt' files don't contain a 3rd column for
# spike type. If reading a legacy version, validate that a gid_dict is
# provided.
if spike_trial.shape[1] == 3:
# spike type, or a 4th and 5th column for tstart and tstop. If reading
# a legacy version, validate that gid_dict, tstart, and tstop provided.
if spike_trial.shape[1] >= 3:
spike_types += [list(spike_trial[:, 2].astype(str))]
else:
elif spike_trial.shape[1] == 2:
if gid_dict is None:
raise ValueError("gid_dict must be provided if spike types "
"are unspecified in the file %s" % (file,))
spike_types += [[]]

spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types)
if spike_trial.shape[1] == 5:
spike_tstart += np.unique(spike_trial[:, 3].astype(float)).tolist()
spike_tstop += np.unique(spike_trial[:, 4].astype(float)).tolist()
else:
if tstart is None or tstop is None:
raise ValueError("tstart and tstop must be provided if values "
"are unspecified in the file %s" % (file,))

if len(np.unique(spike_tstart)) > 1 or len(np.unique(spike_tstop)) > 1:
raise ValueError("tstart and tstop must match across files.")

elif len(np.unique(spike_tstart)) == 1 and len(
np.unique(spike_tstop)) == 1:
spike_tstart = spike_tstart[0]
spike_tstop = spike_tstop[0]
else:
spike_tstart = np.min(sum(spike_times, []))
spike_tstop = np.max(sum(spike_times, []))

spikes = Spikes(times=spike_times, gids=spike_gids, types=spike_types,
tstart=spike_tstart, tstop=spike_tstop)

if gid_dict is not None:
spikes.update_types(gid_dict)

return Spikes(times=spike_times, gids=spike_gids, types=spike_types)
if tstart is not None and tstop is not None:
spikes.update_trial_bounds(spike_tstart, spike_tstop)

return Spikes(times=spike_times, gids=spike_gids, types=spike_types,
tstart=spike_tstart, tstop=spike_tstop)


def _create_coords(n_pyr_x, n_pyr_y, n_common_feeds, p_unique_keys,
Expand Down Expand Up @@ -330,6 +359,8 @@ class Spikes(object):
The inner list contains the type of spike (e.g., evprox1
or L2_pyramidal) that occured at the corresonding time stamp.
Each gid corresponds to a type via Network().gid_dict.
tstart, tstop : int | float | None
Values defining the start and stop times of all trials.
Attributes
----------
Expand All @@ -345,6 +376,8 @@ class Spikes(object):
The inner list contains the type of spike (e.g., evprox1
or L2_pyramidal) that occured at the corresonding time stamp.
Each gid corresponds to a type via Network::gid_dict.
tstart, tstop : int | float | None
Values defining the start and stop times of all trials.
Methods
-------
Expand All @@ -353,11 +386,15 @@ class Spikes(object):
plot(ax=None, show=True)
Plot and return a matplotlib Figure object showing the
aggregate network spiking activity according to cell type.
mean_rates(mean_type='all')
Calculate mean firing rate for each cell type. Specify
averaging method with mean_type argument.
write(fname)
Write spiking activity to a collection of spike trial files.
"""

def __init__(self, times=None, gids=None, types=None):
def __init__(self, times=None, gids=None, types=None,
tstart=None, tstop=None):
if times is None:
times = list()
if gids is None:
Expand Down Expand Up @@ -387,6 +424,8 @@ def __init__(self, times=None, gids=None, types=None):
self._times = times
self._gids = gids
self._types = types
self._tstart = tstart
self._tstop = tstop

def __repr__(self):
class_name = self.__class__.__name__
Expand All @@ -403,7 +442,9 @@ def __eq__(self, other):
for trial in other._times]
return (times_self == times_other and
self._gids == other._gids and
self._types == other._types)
self._types == other._types and
self._tstart == other._tstart and
self._tstop == other._tstop)

@property
def times(self):
Expand All @@ -417,6 +458,14 @@ def gids(self):
def types(self):
return self._types

@property
def tstart(self):
return self._tstart

@property
def tstop(self):
return self._tstop

def update_types(self, gid_dict):
"""Update spike types in the current instance of Spikes.
Expand Down Expand Up @@ -448,6 +497,25 @@ def update_types(self, gid_dict):
spike_types += [list(spike_types_trial)]
self._types = spike_types

def update_trial_bounds(self, tstart, tstop):
"""Update tstart and tstop in the current instance of Spikes.
Parameters
----------
tstart, tstop : float | None
Values defining the start and stop times of all trials.
"""

# Validate tstart, tstop
if not isinstance(tstart, (int, float)) or not isinstance(
tstop, (int, float)):
raise ValueError('tstart and tstop must be of type int or float')
elif tstop <= tstart:
raise ValueError('tstop must be greater than tstart')
self._tstart = tstart
self._tstop = tstop

def mean_rates(self, mean_type='all'):
"""Mean spike rates (Hz) by cell type.
Expand All @@ -468,10 +536,8 @@ def mean_rates(self, mean_type='all'):
"""
cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket']
spike_rates = dict()
all_spike_times = np.array(sum(self._times, []))
all_spike_types = np.array(sum(self._types, []))
all_spike_gids = np.array(sum(self._gids, []))
tstart, tstop = min(all_spike_times), max(all_spike_times)

if mean_type not in ['all', 'trial', 'cell']:
raise ValueError("Invalid mean_type. Valid arguments include "
Expand All @@ -489,7 +555,7 @@ def mean_rates(self, mean_type='all'):
spike_gids)[trial_type_mask], return_counts=True)

gid_spike_rate[trial, cell_type_gids == gid] = (gid_counts / (
tstop - tstart)) * 1000
self._tstop - self._tstart)) * 1000

if mean_type == 'all':
spike_rates[cell_type] = np.mean(
Expand Down Expand Up @@ -567,12 +633,15 @@ def write(self, fname):
1) spike time (s),
2) spike gid, and
3) gid type
4) tstart
5) tstop
"""

for trial_idx in range(len(self._times)):
with open(fname % (trial_idx,), 'w') as f:
for spike_idx in range(len(self._times[trial_idx])):
f.write('{:.3f}\t{}\t{}\n'.format(
f.write('{:.3f}\t{}\t{}\t{:.3f}\t{:.3f}\n'.format(
self._times[trial_idx][spike_idx],
int(self._gids[trial_idx][spike_idx]),
self._types[trial_idx][spike_idx]))
self._types[trial_idx][spike_idx],
float(self._tstart), float(self._tstop)))
2 changes: 2 additions & 0 deletions hnn_core/parallel_backends.py
Expand Up @@ -23,6 +23,8 @@ def _gather_trial_data(sim_data, net, n_trials):
"""
dpls = []

net.spikes.update_trial_bounds(0.0, net.params['tstop'])

for idx in range(n_trials):
dpls.append(sim_data[idx][0])
spikedata = sim_data[idx][1]
Expand Down
36 changes: 31 additions & 5 deletions hnn_core/tests/test_network.py
Expand Up @@ -62,23 +62,37 @@ def test_spikes():
spiketimes = [[2.3456, 7.89], [4.2812, 93.2]]
spikegids = [[1, 3], [5, 7]]
spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']]
spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes)
spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes,
tstart=0.1, tstop=98.4)
spikes.plot_hist(show=False)
spikes.write('/tmp/spk_%d.txt')
assert spikes == read_spikes('/tmp/spk_*.txt')
assert ("Spikes | 2 simulation trials" in repr(spikes))

with pytest.raises(TypeError, match="times should be a list of lists"):
spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids,
types=spiketypes)
types=spiketypes, tstart=0.1, tstop=98.4)

with pytest.raises(TypeError, match="times should be a list of lists"):
spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes)
spikes = Spikes(times=[1, 2], gids=spikegids, types=spiketypes,
tstart=0.1, tstop=98.4)

with pytest.raises(ValueError, match="times, gids, and types should be "
"lists of the same length"):
spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids,
types=spiketypes)
types=spiketypes, tstart=0.1, tstop=98.4)

with pytest.raises(ValueError, match="tstart and tstop must be of type "
"int or float"):
spikes = Spikes()
spikes.update_trial_bounds(tstart=0.1, tstop='ABC')

with pytest.raises(ValueError, match="tstop must be greater than tstart"):
spikes = Spikes()
spikes.update_trial_bounds(tstart=0.1, tstop=-1.0)

spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes,
tstart=0.1, tstop=98.4)

with pytest.raises(TypeError, match="spike_types should be str, "
"list, dict, or None"):
Expand Down Expand Up @@ -107,9 +121,21 @@ def test_spikes():
np.savetxt(fname, times_gids_only, delimiter='\t', fmt='%s')
with pytest.raises(ValueError, match="gid_dict must be provided if spike "
"types are unspecified in the file /tmp/spk_0.txt"):
spikes = read_spikes('/tmp/spk_*.txt')
spikes = read_spikes('/tmp/spk_*.txt', tstart=0.1, tstop=98.4)
with pytest.raises(ValueError, match="gid_dict should contain only "
"disjoint sets of gid values"):
gid_dict = {'L2_pyramidal': range(3), 'L2_basket': range(2, 4),
'L5_pyramidal': range(4, 6), 'L5_basket': range(6, 8)}
spikes = read_spikes('/tmp/spk_*.txt', gid_dict=gid_dict,
tstart=0.1, tstop=98.4)

# Write spike file with no 'tstart' or 'tstop' columns
# Check for gid_dict errors
for fname in sorted(glob('/tmp/spk_*.txt')):
times_gids_types_only = np.loadtxt(fname, dtype=str)[:, (0, 1)]
np.savetxt(fname, times_gids_types_only, delimiter='\t', fmt='%s')
with pytest.raises(ValueError, match="tstart and tstop must be provided "
"if values are unspecified in the file /tmp/spk_0.txt"):
gid_dict = {'L2_pyramidal': 1, 'L2_basket': 3,
'L5_pyramidal': 5, 'L5_basket': 7}
spikes = read_spikes('/tmp/spk_*.txt', gid_dict=gid_dict)

0 comments on commit b2ac4fe

Please sign in to comment.