Skip to content

Commit

Permalink
Rename/reorganize to compute spike rates for individual gids, add opt…
Browse files Browse the repository at this point in the history
…ion to select averaging performed
  • Loading branch information
ntolley committed Sep 4, 2020
1 parent 008c05e commit a930b8e
Showing 1 changed file with 38 additions and 8 deletions.
46 changes: 38 additions & 8 deletions hnn_core/network.py
Expand Up @@ -448,8 +448,18 @@ def update_types(self, gid_dict):
spike_types += [list(spike_types_trial)]
self._types = spike_types

def get_spike_rates(self):
"""Spike rates by cell type.
def mean_rates(self, mean_type='all'):
"""Mean spike rates (Hz) by cell type.
Parameters
----------
mean_type : str
'all' : Average over trials and cells
Returns mean rate for cell types
'trial' : Average over cell types
Returns trial mean rate for cell types
'cell' : Average over individual cells
Returns trial mean rate for individual cells
Returns
-------
Expand All @@ -459,14 +469,34 @@ def get_spike_rates(self):
cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket']
spike_rates = dict()
all_spike_times = np.array(sum(self._times, []))
all_spike_types = np.array(sum(self._types, []))
all_spike_gids = np.array(sum(self._gids, []))
tstart, tstop = min(all_spike_times), max(all_spike_times)

for cell_type in cell_types:
trial_spike_rate = list()
for spike_times, spike_types in zip(self._times, self._types):
spike_times_cells = spike_times[spike_types == cell_type]
trial_spike_rate.append(
len(spike_times_cells) / (tstop - tstart))
spike_rates[cell_type] = np.mean(trial_spike_rate)
type_mask = np.in1d(all_spike_types, cell_type)
cell_type_gids = np.unique(all_spike_gids[type_mask])
gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids)))

trial_data = zip(self._types, self._gids)
for trial, (spike_types, spike_gids) in enumerate(trial_data):
trial_type_mask = np.in1d(spike_types, cell_type)
gid, gid_counts = np.unique(np.array(
spike_gids)[trial_type_mask], return_counts=True)

gid_spike_rate[trial, cell_type_gids == gid] = (gid_counts / (
tstop - tstart)) * 1000

if mean_type == 'all':
spike_rates[cell_type] = np.mean(
np.mean(gid_spike_rate, axis=1))
if mean_type == 'trial':
spike_rates[cell_type] = np.mean(
gid_spike_rate, axis=1).tolist()
if mean_type == 'cell':
spike_rates[cell_type] = [gid_trial_rate.tolist()
for gid_trial_rate in gid_spike_rate]

return spike_rates

def plot(self, ax=None, show=True):
Expand Down

0 comments on commit a930b8e

Please sign in to comment.