Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add ability to index individual trials when plotting rasters/histograms #472

Merged
merged 8 commits into from Mar 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -31,6 +31,9 @@ Bug
- Evoked drive optimization no longer assigns a default timing sigma value to
a drive if it is not already specified, by `Ryan Thorpe`_ in :gh:`446`.

- Subsets of trials can be indexed when using :func:`~hnn_core.viz.plot_spikes_raster`
and :func:`~hnn_core.viz.plot_spikes_hist`, by `Nick Tolley`_ in :gh:`472`.

API
~~~
- Optimization of the evoked drives can be conducted on any :class:`~hnn_core.Network`
Expand Down
22 changes: 13 additions & 9 deletions hnn_core/cell_response.py
Expand Up @@ -322,14 +322,15 @@ def mean_rates(self, tstart, tstop, gid_ranges, mean_type='all'):

return spike_rates

def plot_spikes_raster(self, ax=None, show=True):
def plot_spikes_raster(self, trial_idx=None, ax=None, show=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an API breaking change. It will break code which does:

plot_spikes_raster(ax)

as a general rule, you want to put new arguments towards the end, but maybe before show so API is consistent. Anyhow, I don't think we are doing deprecations at this point, so I'll let this pass but something to remember for the future :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, might be worth discussing the mechanics of deprecation warnings at some point. Especially identifying functions that are the most likely to change in the near future

"""Plot the aggregate spiking activity according to cell type.

Parameters
----------
trial_idx : int | list of int | None
Index of trials to be plotted. If None, all trials plotted.
ax : instance of matplotlib axis | None
An axis object from matplotlib. If None,
a new figure is created.
An axis object from matplotlib. If None, a new figure is created.
show : bool
If True, show the figure.

Expand All @@ -338,16 +339,19 @@ def plot_spikes_raster(self, ax=None, show=True):
fig : instance of matplotlib Figure
The matplotlib figure object.
"""
return plot_spikes_raster(cell_response=self, ax=ax, show=show)
return plot_spikes_raster(
cell_response=self, trial_idx=trial_idx, ax=ax, show=show)

def plot_spikes_hist(self, ax=None, spike_types=None, show=True):
def plot_spikes_hist(self, trial_idx=None, ax=None, spike_types=None,
show=True):
"""Plot the histogram of spiking activity across trials.

Parameters
----------
trial_idx : int | list of int | None
Index of trials to be plotted. If None, all trials plotted.
ax : instance of matplotlib axis | None
An axis object from matplotlib. If None,
a new figure is created.
An axis object from matplotlib. If None, a new figure is created.
spike_types: string | list | dictionary | None
String input of a valid spike type is plotted individually.
Ex: 'common', 'evdist', 'evprox', 'extgauss', 'extpois'
Expand All @@ -367,8 +371,8 @@ def plot_spikes_hist(self, ax=None, spike_types=None, show=True):
fig : instance of matplotlib Figure
The matplotlib figure handle.
"""
return plot_spikes_hist(
self, ax=ax, spike_types=spike_types, show=show)
return plot_spikes_hist(self, trial_idx=trial_idx, ax=ax,
spike_types=spike_types, show=show)

def write(self, fname):
"""Write spiking activity per trial to a collection of files.
Expand Down
22 changes: 17 additions & 5 deletions hnn_core/tests/test_viz.py
Expand Up @@ -66,7 +66,7 @@ def test_network_visualization():
'L5_basket', 'soma',
'ampa', 0.00025, 1.0, lamtha=3.0,
probability=0.8)
fig = plot_cell_connectivity(net, conn_idx)
fig = plot_cell_connectivity(net, conn_idx, show=False)
ax_src, ax_target, _ = fig.axes

pos = net.pos_dict['L2_pyramidal'][2]
Expand Down Expand Up @@ -99,17 +99,18 @@ def test_dipole_visualization():
dpls[0].copy().savgol_filter(h_freq=30).plot(ax=axes) # on top

# test decimation options
plot_dipole(dpls[0], decim=2)
plot_dipole(dpls[0], decim=2, show=False)
for dec in [-1, [2, 2.]]:
with pytest.raises(ValueError,
match='each decimation factor must be a positive'):
plot_dipole(dpls[0], decim=dec)
plot_dipole(dpls[0], decim=dec, show=False)

# test plotting multiple dipoles as overlay
fig = plot_dipole(dpls)
fig = plot_dipole(dpls, show=False)

# multiple TFRs get averaged
fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3)
fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3,
show=False)

with pytest.raises(RuntimeError,
match="All dipoles must be scaled equally!"):
Expand All @@ -122,3 +123,14 @@ def test_dipole_visualization():
dpl_sfreq = dpls[0].copy()
dpl_sfreq.sfreq /= 10
plot_psd([dpls[0], dpl_sfreq])

# test cell response plotting
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)
net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)

with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_hist(trial_idx='blah')
net.cell_response.plot_spikes_hist(trial_idx=0, show=False)
net.cell_response.plot_spikes_hist(trial_idx=[0, 1], show=False)
62 changes: 48 additions & 14 deletions hnn_core/viz.py
Expand Up @@ -284,13 +284,16 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
return ax.get_figure()


def plot_spikes_hist(cell_response, ax=None, spike_types=None, show=True):
def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
show=True):
"""Plot the histogram of spiking activity across trials.

Parameters
----------
cell_response : instance of CellResponse
The CellResponse object from net.cell_response
trial_idx : int | list of int | None
Index of trials to be plotted. If None, all trials plotted.
ax : instance of matplotlib axis | None
An axis object from matplotlib. If None,
a new figure is created.
Expand All @@ -314,8 +317,23 @@ def plot_spikes_hist(cell_response, ax=None, spike_types=None, show=True):
The matplotlib figure handle.
"""
import matplotlib.pyplot as plt
spike_times = np.array(sum(cell_response._spike_times, []))
spike_types_data = np.array(sum(cell_response._spike_types, []))
n_trials = len(cell_response.spike_times)
if trial_idx is None:
trial_idx = list(range(n_trials))

if isinstance(trial_idx, int):
trial_idx = [trial_idx]
_validate_type(trial_idx, list, 'trial_idx', 'int, list of int')

# Extract desired trials
if len(cell_response._spike_times[0]) > 0:
spike_times = np.concatenate(
np.array(cell_response._spike_times)[trial_idx])
spike_types_data = np.concatenate(
np.array(cell_response._spike_types)[trial_idx])
else:
spike_times = np.array([])
spike_types_data = np.array([])

unique_types = np.unique(spike_types_data)
spike_types_mask = {s_type: np.in1d(spike_types_data, s_type)
Expand Down Expand Up @@ -382,16 +400,17 @@ def plot_spikes_hist(cell_response, ax=None, spike_types=None, show=True):
return ax.get_figure()


def plot_spikes_raster(cell_response, ax=None, show=True):
def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
"""Plot the aggregate spiking activity according to cell type.

Parameters
----------
cell_response : instance of CellResponse
The CellResponse object from net.cell_response
trial_idx : int | list of int | None
Index of trials to be plotted. If None, all trials plotted
ax : instance of matplotlib axis | None
An axis object from matplotlib. If None,
a new figure is created.
An axis object from matplotlib. If None, a new figure is created.
show : bool
If True, show the figure.

Expand All @@ -402,9 +421,27 @@ def plot_spikes_raster(cell_response, ax=None, show=True):
"""

import matplotlib.pyplot as plt
spike_times = np.array(sum(cell_response._spike_times, []))
spike_types = np.array(sum(cell_response._spike_types, []))
spike_gids = np.array(sum(cell_response._spike_gids, []))
n_trials = len(cell_response.spike_times)
if trial_idx is None:
trial_idx = list(range(n_trials))

if isinstance(trial_idx, int):
trial_idx = [trial_idx]
ntolley marked this conversation as resolved.
Show resolved Hide resolved
_validate_type(trial_idx, list, 'trial_idx', 'int, list of int')

# Extract desired trials
if len(cell_response._spike_times[0]) > 0:
spike_times = np.concatenate(
np.array(cell_response._spike_times)[trial_idx])
spike_types = np.concatenate(
np.array(cell_response._spike_types)[trial_idx])
spike_gids = np.concatenate(
np.array(cell_response._spike_gids)[trial_idx])
ntolley marked this conversation as resolved.
Show resolved Hide resolved
else:
spike_times = np.array([])
spike_types = np.array([])
spike_gids = np.array([])

cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal']
cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b',
'L2_pyramidal': 'g', 'L2_basket': 'w'}
Expand All @@ -425,12 +462,9 @@ def plot_spikes_raster(cell_response, ax=None, show=True):
if cell_type_times:
cell_type_times = np.concatenate(cell_type_times)
cell_type_ypos = np.concatenate(cell_type_ypos)
else:
cell_type_times = []
cell_type_ypos = []

ax.scatter(cell_type_times, cell_type_ypos, label=cell_type,
color=cell_type_colors[cell_type])
ax.scatter(cell_type_times, cell_type_ypos, label=cell_type,
ntolley marked this conversation as resolved.
Show resolved Hide resolved
color=cell_type_colors[cell_type])

ax.legend(loc=1)
ax.set_facecolor('k')
Expand Down