Skip to content

Commit

Permalink
Rebase fix
Browse files Browse the repository at this point in the history
Co-authored-by: mjpelah <mjpelah@gmail.com>
  • Loading branch information
ntolley and mjpelah committed Jan 15, 2023
1 parent 9c799dd commit c64ced3
Showing 1 changed file with 37 additions and 81 deletions.
118 changes: 37 additions & 81 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def _decimate_plot_data(decim, data, times, sfreq=None):

def plt_show(show=True, fig=None, **kwargs):
"""Show a figure while suppressing warnings.
NB copied from :func:`mne.viz.utils.plt_show`.
Parameters
----------
show : bool
Expand All @@ -82,7 +80,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
ax=None, decim=None, color='cividis',
voltage_offset=50, voltage_scalebar=200, show=True):
"""Plot laminar extracellular electrode array voltage time series.
Parameters
----------
times : array-like, shape (n_times,)
Expand Down Expand Up @@ -117,7 +114,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
:func:`~matplotlib.axes.Axes.set_yticklabels`.
show : bool
If True, show the figure
Returns
-------
fig : instance of plt.fig
Expand Down Expand Up @@ -223,7 +219,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
color='k', label="average", average=False, show=True):
"""Simple layer-specific plot function.
Parameters
----------
dpl : instance of Dipole | list of Dipole instances
Expand All @@ -250,7 +245,6 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
If True, render the average across all dpls.
show : bool
If True, show the figure
Returns
-------
fig : instance of plt.fig
Expand Down Expand Up @@ -322,7 +316,6 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
show=True):
"""Plot the histogram of spiking activity across trials.
Parameters
----------
cell_response : instance of CellResponse
Expand All @@ -334,25 +327,17 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
a new figure is created.
spike_types: string | list | dictionary | None
String input of a valid spike type is plotted individually.
| Ex: ``'poisson'``, ``'evdist'``, ``'evprox'``, ...
List of valid string inputs will plot each spike type individually.
| Ex: ``['poisson', 'evdist']``
Dictionary of valid lists will plot list elements as a group.
| Ex: ``{'Evoked': ['evdist', 'evprox'], 'Tonic': ['poisson']}``
If None, all input spike types are plotted individually if any
are present. Otherwise spikes from all cells are plotted.
Valid strings also include leading characters of spike types
| Ex: ``'ev'`` is equivalent to ``['evdist', 'evprox']``
show : bool
If True, show the figure.
Returns
-------
fig : instance of matplotlib Figure
Expand Down Expand Up @@ -444,7 +429,6 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,

def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
"""Plot the aggregate spiking activity according to cell type.
Parameters
----------
cell_response : instance of CellResponse
Expand All @@ -455,7 +439,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
An axis object from matplotlib. If None, a new figure is created.
show : bool
If True, show the figure.
Returns
-------
fig : instance of matplotlib Figure
Expand Down Expand Up @@ -520,7 +503,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):

def plot_cells(net, ax=None, show=True):
"""Plot the cells using Network.pos_dict.
Parameters
----------
net : instance of Network
Expand All @@ -530,7 +512,6 @@ def plot_cells(net, ax=None, show=True):
a new figure is created.
show : bool
If True, show the figure.
Returns
-------
fig : instance of matplotlib Figure
Expand Down Expand Up @@ -577,7 +558,6 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None,
colormap='inferno', colorbar=True, colorbar_inside=False,
show=True):
"""Plot Morlet time-frequency representation of dipole time course
Parameters
----------
dpl : instance of Dipole | list of Dipole instances
Expand Down Expand Up @@ -612,7 +592,6 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None,
Put the color inside the heatmap if True.
show : bool
If True, show the figure
Returns
-------
fig : instance of matplotlib Figure
Expand Down Expand Up @@ -716,13 +695,11 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None,
def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg',
color=None, label=None, ax=None, show=True):
"""Plot power spectral density (PSD) of dipole time course
Applies `~scipy.signal.periodogram` from SciPy with ``window='hamming'``.
Note that no spectral averaging is applied across time, as most
``hnn_core`` simulations are short-duration. However, passing a list of
`Dipole` instances will plot their average (Hamming-windowed) power, which
resembles the `Welch`-method applied over time.
Parameters
----------
dpl : instance of Dipole | list of Dipole instances
Expand All @@ -745,7 +722,6 @@ def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg',
The matplotlib axis.
show : bool
If True, show the figure
Returns
-------
fig : instance of matplotlib Figure
Expand Down Expand Up @@ -809,7 +785,6 @@ def _linewidth_from_data_units(ax, linewidth):

def plot_cell_morphology(cell, ax, show=True):
"""Plot the cell morphology.
Parameters
----------
cell : instance of Cell
Expand All @@ -818,7 +793,6 @@ def plot_cell_morphology(cell, ax, show=True):
Matplotlib 3D axis
show : bool
If True, show the plot
Returns
-------
axes : list of instance of Axes3D
Expand All @@ -828,7 +802,7 @@ def plot_cell_morphology(cell, ax, show=True):
from mpl_toolkits.mplot3d import Axes3D # noqa
cell_list = list()
colors = ['b', 'c', 'r', 'm']
clr_index=0
clr_index = 0

if ax is None:
plt.figure()
Expand All @@ -838,11 +812,11 @@ def plot_cell_morphology(cell, ax, show=True):
for ind_cell in cell:
cell_list = list(cell.values())
else:
cell_list[0]=cell
cell_list[0] = cell

# Cell is in XZ plane
#ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150))
#ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200))
# ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150))
# ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200))
cell_radii = list()
cell_radii.append(clr_index)
for clr_index, cell in enumerate(cell_list):
Expand Down Expand Up @@ -871,7 +845,7 @@ def plot_cell_morphology(cell, ax, show=True):
dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0]
dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1]
dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2]
xs.append(pt[0] + dx + (radius + cell_radii[-1]+100))
xs.append(pt[0] + dx + (radius + cell_radii[-1] + 100))
ys.append(pt[1] + dz)
zs.append(pt[2] + dy)
ax.plot(xs, ys, zs, color=colors[clr_index], linewidth=linewidth)
Expand All @@ -887,7 +861,6 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True,
colorbar=True, colormap='Greys',
show=True):
"""Plot connectivity matrix with color bar for synaptic weights
Parameters
----------
net : Instance of Network object
Expand All @@ -906,7 +879,6 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True,
If True (default), adjust figure to include colorbar.
show : bool
If True, show the plot
Returns
-------
fig : instance of matplotlib Figure
Expand Down Expand Up @@ -1017,12 +989,10 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos,
def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None,
colorbar=True, colormap='viridis', show=True):
"""Plot synaptic weight of connections.
This is an interactive plot with source cells shown in the left
subplot and connectivity from a source cell to all the target cells
in the right subplot. Click on the cells in the left subplot to
explore how the connectivity pattern changes for different source cells.
Parameters
----------
net : Instance of Network object
Expand All @@ -1041,12 +1011,10 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None,
If True (default), adjust figure to include colorbar.
show : bool
If True, show the plot
Returns
-------
fig : instance of matplotlib Figure
The matplotlib figure handle.
Notes
-----
Target cells will be determined by the connections in
Expand All @@ -1055,7 +1023,6 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None,
it will appear as a smaller black cross.
Source cell is plotted as a red square. Source cell will not be plotted if
the connection corresponds to a drive, ex: poisson, bursty, etc.
"""
import matplotlib.pyplot as plt
from .network import Network
Expand Down Expand Up @@ -1150,55 +1117,44 @@ def _onclick(event):
return ax.get_figure()


def _plot_cell(ax, cell_type=None, show=True):
"""Plot the cell morphology of a specific cell type
parameters
----------
cell_type : instance of net.cell_type[]
The type of cell to be plotted. If None,
generic cell type
ax : instance of Axes3D
Matplotlib 3D axis
show : bool
if True, show the plot
"""

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

if ax is none:
plt.figure()
ax = plt.axes(projection='3d')

return ax


def plot_cell_morphologies(net, ax=None, show=true):
"""Plot the morphology of the network cells
def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,
show=True):
"""Plot laminar current source density (CSD) estimation from LFP array.
Parameters
----------
net : instance of Network
The network object
ax : instance of matplotlib Axes3D | None
An axis object from matplotlib. If none,
a new figure is created.
Show : bool
If True, show the figure
times : Numpy array, shape (n_times,)
Sampling times (in ms).
data : array-like, shape (n_channels, n_times)
CSD data, channels x time.
ax : instance of matplotlib figure | None
The matplotlib axis.
colorbar : bool
If the colorbar is presented.
contact_labels : list
Labels associated with the contacts to plot. Passed as-is to
:func:`~matplotlib.axes.Axes.set_yticklabels`.
show : bool
If True, show the plot.
Returns
-------
fig : instance of matplotlib figure
The matplotlib figure handle
fig : instance of matplotlib Figure
The matplotlib figure handle.
"""

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
_, ax = plt.subplots(1, 1, constrained_layout=True)

im = ax.pcolormesh(times, contact_labels, np.array(data),
cmap="jet_r", shading='auto')
ax.set_title("CSD")

if colorbar:
color_axis = ax.inset_axes([1.05, 0, 0.02, 1], transform=ax.transAxes)
plt.colorbar(im, ax=ax, cax=color_axis).set_label(r'$CSD (uV/um^{2})$')

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Electrode depth')
plt.tight_layout()
plt_show(show)

return ax.get_figure()

0 comments on commit c64ced3

Please sign in to comment.