diff --git a/hnn_core/network.py b/hnn_core/network.py index 859de8563..c1e8f785a 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -448,8 +448,18 @@ def update_types(self, gid_dict): spike_types += [list(spike_types_trial)] self._types = spike_types - def get_spike_rates(self): - """Spike rates by cell type. + def mean_rates(self, mean_type='all'): + """Mean spike rates (Hz) by cell type. + + Parameters + ---------- + mean_type : str + 'all' : Average over trials and cells + Returns mean rate for cell types + 'trial' : Average over cell types + Returns trial mean rate for cell types + 'cell' : Average over individual cells + Returns trial mean rate for individual cells Returns ------- @@ -459,14 +469,34 @@ def get_spike_rates(self): cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] spike_rates = dict() all_spike_times = np.array(sum(self._times, [])) + all_spike_types = np.array(sum(self._types, [])) + all_spike_gids = np.array(sum(self._gids, [])) tstart, tstop = min(all_spike_times), max(all_spike_times) + for cell_type in cell_types: - trial_spike_rate = list() - for spike_times, spike_types in zip(self._times, self._types): - spike_times_cells = spike_times[spike_types == cell_type] - trial_spike_rate.append( - len(spike_times_cells) / (tstop - tstart)) - spike_rates[cell_type] = np.mean(trial_spike_rate) + type_mask = np.in1d(all_spike_types, cell_type) + cell_type_gids = np.unique(all_spike_gids[type_mask]) + gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids))) + + trial_data = zip(self._types, self._gids) + for trial, (spike_types, spike_gids) in enumerate(trial_data): + trial_type_mask = np.in1d(spike_types, cell_type) + gid, gid_counts = np.unique(np.array( + spike_gids)[trial_type_mask], return_counts=True) + + gid_spike_rate[trial, cell_type_gids == gid] = (gid_counts / ( + tstop - tstart)) * 1000 + + if mean_type == 'all': + spike_rates[cell_type] = np.mean( + np.mean(gid_spike_rate, axis=1)) + if mean_type == 'trial': + spike_rates[cell_type] = np.mean( + gid_spike_rate, axis=1).tolist() + if mean_type == 'cell': + spike_rates[cell_type] = [gid_trial_rate.tolist() + for gid_trial_rate in gid_spike_rate] + return spike_rates def plot(self, ax=None, show=True):