diff --git a/hnn_core/viz.py b/hnn_core/viz.py index da8c3eca7..caa86d70a 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -60,9 +60,7 @@ def _decimate_plot_data(decim, data, times, sfreq=None): def plt_show(show=True, fig=None, **kwargs): """Show a figure while suppressing warnings. - NB copied from :func:`mne.viz.utils.plt_show`. - Parameters ---------- show : bool @@ -82,7 +80,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, ax=None, decim=None, color='cividis', voltage_offset=50, voltage_scalebar=200, show=True): """Plot laminar extracellular electrode array voltage time series. - Parameters ---------- times : array-like, shape (n_times,) @@ -117,7 +114,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, :func:`~matplotlib.axes.Axes.set_yticklabels`. show : bool If True, show the figure - Returns ------- fig : instance of plt.fig @@ -223,7 +219,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, color='k', label="average", average=False, show=True): """Simple layer-specific plot function. - Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -250,7 +245,6 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, If True, render the average across all dpls. show : bool If True, show the figure - Returns ------- fig : instance of plt.fig @@ -322,7 +316,6 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, show=True): """Plot the histogram of spiking activity across trials. - Parameters ---------- cell_response : instance of CellResponse @@ -334,25 +327,17 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, a new figure is created. spike_types: string | list | dictionary | None String input of a valid spike type is plotted individually. - | Ex: ``'poisson'``, ``'evdist'``, ``'evprox'``, ... - List of valid string inputs will plot each spike type individually. - | Ex: ``['poisson', 'evdist']`` - Dictionary of valid lists will plot list elements as a group. - | Ex: ``{'Evoked': ['evdist', 'evprox'], 'Tonic': ['poisson']}`` - If None, all input spike types are plotted individually if any are present. Otherwise spikes from all cells are plotted. Valid strings also include leading characters of spike types - | Ex: ``'ev'`` is equivalent to ``['evdist', 'evprox']`` show : bool If True, show the figure. - Returns ------- fig : instance of matplotlib Figure @@ -444,7 +429,6 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): """Plot the aggregate spiking activity according to cell type. - Parameters ---------- cell_response : instance of CellResponse @@ -455,7 +439,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): An axis object from matplotlib. If None, a new figure is created. show : bool If True, show the figure. - Returns ------- fig : instance of matplotlib Figure @@ -520,7 +503,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): def plot_cells(net, ax=None, show=True): """Plot the cells using Network.pos_dict. - Parameters ---------- net : instance of Network @@ -530,7 +512,6 @@ def plot_cells(net, ax=None, show=True): a new figure is created. show : bool If True, show the figure. - Returns ------- fig : instance of matplotlib Figure @@ -577,7 +558,6 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, colormap='inferno', colorbar=True, colorbar_inside=False, show=True): """Plot Morlet time-frequency representation of dipole time course - Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -612,7 +592,6 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, Put the color inside the heatmap if True. show : bool If True, show the figure - Returns ------- fig : instance of matplotlib Figure @@ -716,13 +695,11 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', color=None, label=None, ax=None, show=True): """Plot power spectral density (PSD) of dipole time course - Applies `~scipy.signal.periodogram` from SciPy with ``window='hamming'``. Note that no spectral averaging is applied across time, as most ``hnn_core`` simulations are short-duration. However, passing a list of `Dipole` instances will plot their average (Hamming-windowed) power, which resembles the `Welch`-method applied over time. - Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -745,7 +722,6 @@ def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', The matplotlib axis. show : bool If True, show the figure - Returns ------- fig : instance of matplotlib Figure @@ -809,7 +785,6 @@ def _linewidth_from_data_units(ax, linewidth): def plot_cell_morphology(cell, ax, show=True): """Plot the cell morphology. - Parameters ---------- cell : instance of Cell @@ -818,7 +793,6 @@ def plot_cell_morphology(cell, ax, show=True): Matplotlib 3D axis show : bool If True, show the plot - Returns ------- axes : list of instance of Axes3D @@ -828,7 +802,7 @@ def plot_cell_morphology(cell, ax, show=True): from mpl_toolkits.mplot3d import Axes3D # noqa cell_list = list() colors = ['b', 'c', 'r', 'm'] - clr_index=0 + clr_index = 0 if ax is None: plt.figure() @@ -838,11 +812,11 @@ def plot_cell_morphology(cell, ax, show=True): for ind_cell in cell: cell_list = list(cell.values()) else: - cell_list[0]=cell + cell_list[0] = cell # Cell is in XZ plane - #ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150)) - #ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200)) + # ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150)) + # ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200)) cell_radii = list() cell_radii.append(clr_index) for clr_index, cell in enumerate(cell_list): @@ -871,7 +845,7 @@ def plot_cell_morphology(cell, ax, show=True): dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] - xs.append(pt[0] + dx + (radius + cell_radii[-1]+100)) + xs.append(pt[0] + dx + (radius + cell_radii[-1] + 100)) ys.append(pt[1] + dz) zs.append(pt[2] + dy) ax.plot(xs, ys, zs, color=colors[clr_index], linewidth=linewidth) @@ -887,7 +861,6 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, colorbar=True, colormap='Greys', show=True): """Plot connectivity matrix with color bar for synaptic weights - Parameters ---------- net : Instance of Network object @@ -906,7 +879,6 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, If True (default), adjust figure to include colorbar. show : bool If True, show the plot - Returns ------- fig : instance of matplotlib Figure @@ -1017,12 +989,10 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, colorbar=True, colormap='viridis', show=True): """Plot synaptic weight of connections. - This is an interactive plot with source cells shown in the left subplot and connectivity from a source cell to all the target cells in the right subplot. Click on the cells in the left subplot to explore how the connectivity pattern changes for different source cells. - Parameters ---------- net : Instance of Network object @@ -1041,12 +1011,10 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, If True (default), adjust figure to include colorbar. show : bool If True, show the plot - Returns ------- fig : instance of matplotlib Figure The matplotlib figure handle. - Notes ----- Target cells will be determined by the connections in @@ -1055,7 +1023,6 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, it will appear as a smaller black cross. Source cell is plotted as a red square. Source cell will not be plotted if the connection corresponds to a drive, ex: poisson, bursty, etc. - """ import matplotlib.pyplot as plt from .network import Network @@ -1150,55 +1117,44 @@ def _onclick(event): return ax.get_figure() -def _plot_cell(ax, cell_type=None, show=True): - """Plot the cell morphology of a specific cell type - - parameters - ---------- - cell_type : instance of net.cell_type[] - The type of cell to be plotted. If None, - generic cell type - ax : instance of Axes3D - Matplotlib 3D axis - show : bool - if True, show the plot - - """ - - import matplotlib.pyplot as plt - from mpl_toolkits.mplot3d import Axes3D - - if ax is none: - plt.figure() - ax = plt.axes(projection='3d') - - return ax - - -def plot_cell_morphologies(net, ax=None, show=true): - """Plot the morphology of the network cells - +def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, + show=True): + """Plot laminar current source density (CSD) estimation from LFP array. Parameters ---------- - net : instance of Network - The network object - ax : instance of matplotlib Axes3D | None - An axis object from matplotlib. If none, - a new figure is created. - Show : bool - If True, show the figure - + times : Numpy array, shape (n_times,) + Sampling times (in ms). + data : array-like, shape (n_channels, n_times) + CSD data, channels x time. + ax : instance of matplotlib figure | None + The matplotlib axis. + colorbar : bool + If the colorbar is presented. + contact_labels : list + Labels associated with the contacts to plot. Passed as-is to + :func:`~matplotlib.axes.Axes.set_yticklabels`. + show : bool + If True, show the plot. Returns ------- - fig : instance of matplotlib figure - The matplotlib figure handle + fig : instance of matplotlib Figure + The matplotlib figure handle. """ - import matplotlib.pyplot as plt - from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import - if ax is None: - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') + _, ax = plt.subplots(1, 1, constrained_layout=True) + + im = ax.pcolormesh(times, contact_labels, np.array(data), + cmap="jet_r", shading='auto') + ax.set_title("CSD") + + if colorbar: + color_axis = ax.inset_axes([1.05, 0, 0.02, 1], transform=ax.transAxes) + plt.colorbar(im, ax=ax, cax=color_axis).set_label(r'$CSD (uV/um^{2})$') + + ax.set_xlabel('Time (ms)') + ax.set_ylabel('Electrode depth') + plt.tight_layout() + plt_show(show) return ax.get_figure()