Skip to content

Commit

Permalink
Merge pull request #152 from neurodsp-tools/plts
Browse files Browse the repository at this point in the history
Add Plotting Funcs
  • Loading branch information
TomDonoghue committed Apr 12, 2019
2 parents ab1f9d8 + e44fe87 commit 083f3fd
Show file tree
Hide file tree
Showing 20 changed files with 637 additions and 269 deletions.
1 change: 1 addition & 0 deletions doc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ clean:
## make install:
# Build the html site, and push it to gh-pages branch of repo to deploy
install:
make clean
rm -rf _build/doctrees _build/tmp_html
# Clone, specifically, the gh-pages branch:
# --no-checkout just fetches the root folder without content
Expand Down
54 changes: 54 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,57 @@ Combined Signals
:toctree: generated/

sim_combined

Plots
-----

Functions for plotting time series and analysis outputs.

Time Series
~~~~~~~~~~~

.. currentmodule:: neurodsp.plts.time_series

.. autosummary::
:toctree: generated/

plot_time_series
plot_instantaneous_measure
plot_bursts

Spectral
~~~~~~~~

.. currentmodule:: neurodsp.plts.spectral

.. autosummary::
:toctree: generated/

plot_power_spectra
plot_scv
plot_scv_rs_lines
plot_scv_rs_matrix
plot_spectral_hist

Filter
~~~~~~

.. currentmodule:: neurodsp.plts.filt

.. autosummary::
:toctree: generated/

plot_filter_properties
plot_frequency_response
plot_impulse_response

Rhythm
~~~~~~

.. currentmodule:: neurodsp.plts.rhythm

.. autosummary::
:toctree: generated/

plot_swm_pattern
plot_lagged_coherence
22 changes: 8 additions & 14 deletions examples/plot_mne_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from neurodsp.burst import detect_bursts_dual_threshold
from neurodsp.rhythm import lagged_coherence

# Import NeuroDSP plotting functions
from neurodsp.plts import plot_time_series, plot_power_spectra, plot_bursts, plot_lagged_coherence

###################################################################################################
# Load & Check MNE Data
# ---------------------
Expand Down Expand Up @@ -87,8 +90,7 @@
###################################################################################################

# Plot a segment of the extracted time series data
plt.figure(figsize=(16, 3))
plt.plot(times, sig, 'k')
plot_time_series(times, sig)

###################################################################################################
# Calculate Power Spectra
Expand All @@ -112,9 +114,8 @@

###################################################################################################

# Plot the power spectra
plt.figure(figsize=(8, 8))
plt.semilogy(freqs, powers)
# Plot the power spectra, and note the peak power
plot_power_spectra(freqs, powers)
plt.plot(freqs[np.argmax(powers)], np.max(powers), '.r', ms=12)

###################################################################################################
Expand All @@ -139,10 +140,7 @@
###################################################################################################

# Plot original signal and burst activity
plt.figure(figsize=(16, 3))
plt.plot(times, sig, 'k', label='Raw Data')
plt.plot(times[bursting], sig[bursting], 'r', label='Detected Bursts')
plt.legend(loc='best')
plot_bursts(times, sig, bursting, labels=['Raw Data', 'Detected Bursts'])

###################################################################################################
# Measure Rhythmicity with Lagged Coherence
Expand Down Expand Up @@ -174,11 +172,7 @@
###################################################################################################

# Visualize lagged coherence across all frequencies
plt.figure(figsize=(6, 3))
plt.plot(freqs, lcs, 'k.-')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Lagged coherence')
plt.tight_layout()
plot_lagged_coherence(freqs, lcs)

###################################################################################################

Expand Down
5 changes: 5 additions & 0 deletions neurodsp/plts/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
"""Plotting functions."""

from .time_series import plot_time_series, plot_bursts
from .filt import plot_filter_properties, plot_frequency_response, plot_impulse_response
from .rhythm import plot_swm_pattern, plot_lagged_coherence
from .spectral import plot_power_spectra, plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix, plot_spectral_hist
14 changes: 10 additions & 4 deletions neurodsp/plts/filt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig

###################################################################################################
###################################################################################################

@savefig
def plot_filter_properties(f_db, db, fs, impulse_response):
"""Plot filter properties, including frequency response and filter kernel.
Expand All @@ -25,6 +29,8 @@ def plot_filter_properties(f_db, db, fs, impulse_response):
plot_impulse_response(fs, impulse_response, ax=ax[1])


@savefig
@style_plot
def plot_frequency_response(f_db, db, ax=None):
"""Plot the frequency response of a filter.
Expand All @@ -38,8 +44,7 @@ def plot_frequency_response(f_db, db, ax=None):
Figure axes upon which to plot.
"""

if not ax:
_, ax = plt.subplots(figsize=(5, 5))
ax = check_ax(ax, (5, 5))

ax.plot(f_db, db, 'k')

Expand All @@ -48,6 +53,8 @@ def plot_frequency_response(f_db, db, ax=None):
ax.set_ylabel('Attenuation (dB)')


@savefig
@style_plot
def plot_impulse_response(fs, impulse_response, ax=None):
"""Plot the impulse response of a filter.
Expand All @@ -59,8 +66,7 @@ def plot_impulse_response(fs, impulse_response, ax=None):
Figure axes upon which to plot.
"""

if not ax:
_, ax = plt.subplots(figsize=(5, 5))
ax = check_ax(ax, (5, 5))

# Create a samples vector, center to zero, and convert to time
samples = np.arange(len(impulse_response))
Expand Down
53 changes: 53 additions & 0 deletions neurodsp/plts/rhythm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Plotting functions for neurodsp.rhythm."""

import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig

###################################################################################################
###################################################################################################

@savefig
@style_plot
def plot_swm_pattern(pattern, ax=None):
"""Plot the resulting pattern from a sliding window matching analysis.
Parameters
----------
pattern : 1d array
The resulting average pattern from applying sliding window matching.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
"""

ax = check_ax(ax, (4, 4))

plt.plot(pattern, 'k')

plt.title('Average Pattern')
plt.xlabel('Time (samples)')
plt.ylabel('Voltage (a.u.)')


@savefig
@style_plot
def plot_lagged_coherence(freqs, lcs, ax=None):
"""Plot lagged coherence values across frequencies.
Parameters
----------
freqs : 1d array
Vector of frequencies at which lagged coherence was computed.
lcs : 1d array
Lagged coherence values across the computed frequencies.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
"""

ax = check_ax(ax, (6, 3))

plt.plot(freqs, lcs, 'k.-')

plt.xlabel('Frequency (Hz)')
plt.ylabel('Lagged Coherence')
125 changes: 123 additions & 2 deletions neurodsp/plts/spectral.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,130 @@
"""Plotting functions for neurodsp.spectral."""

from itertools import repeat

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig

###################################################################################################
###################################################################################################

@savefig
@style_plot
def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None):
"""Plot power spectra.
Parameters
----------
freqs : 1d array or list of 1d array
Frequency vector.
powers : 1d array or list of 1d array
Power values.
labels : str or list of str, optional
Labels for each time series.
colors : str or list of str
Colors to use to plot lines.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
"""

ax = check_ax(ax, (6, 6))

freqs = repeat(freqs) if isinstance(freqs, np.ndarray) else freqs
powers = [powers] if isinstance(powers, np.ndarray) else powers

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)

if colors is not None:
colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for freq, power, label in zip(freqs, powers, labels):
ax.loglog(freq, power, label=label)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power (V^2/Hz)')


@savefig
@style_plot
def plot_scv(freqs, scv, ax=None):
"""Plot the SCV.
Parameters
----------
freqs : 1d array
Frequency vector.
scv : 1d array
Spectral coefficient of variation.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
"""

ax = check_ax(ax, (5, 5))

ax.loglog(freqs, scv)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('SCV')


@savefig
@style_plot
def plot_scv_rs_lines(freqs, scv_rs, ax=None):
"""Plot the SCV, from the resampling method.
Parameters
----------
freqs : 1d array
Frequency vector.
scv_rs :
Spectral coefficient of variation, from resampling procedure.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
"""

ax = check_ax(ax, (8, 8))

ax.loglog(freqs, scv_rs, 'k', alpha=0.1)
ax.loglog(freqs, np.mean(scv_rs, axis=1), lw=2)
ax.loglog(freqs, len(freqs)*[1.])

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('SCV')


@savefig
@style_plot
def plot_scv_rs_matrix(freqs, t_inds, scv_rs):
"""Plot the SCV, from the resampling method.
Parameters
----------
freqs : 1d array
Frequency vector.
t_inds : 1d array
Time indices
scv_rs : 1d array
Spectral coefficient of variation, from resampling procedure.
"""

fig, ax = plt.subplots(figsize=(10, 5))

plt.imshow(np.log10(scv_rs), aspect='auto',
extent=(t_inds[0], t_inds[-1], freqs[-1], freqs[0]))
plt.colorbar(label='SCV')

plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')


@savefig
@style_plot
def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None, spectrum=None):
"""Plot the spectral histogram.
Expand All @@ -27,12 +146,14 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None, sp
plt.figure(figsize=(8, 12 * len(power_bins) / len(freqs)))

# Plot histogram intensity as image and automatically adjust aspect ratio
plt.imshow(spectral_hist, extent=[freqs[0], freqs[-1],
power_bins[0], power_bins[-1]], aspect='auto')
plt.imshow(spectral_hist, extent=[freqs[0], freqs[-1], power_bins[0], power_bins[-1]], aspect='auto')
plt.xlabel('Frequency (Hz)', fontsize=15)
plt.ylabel('Log10 Power', fontsize=15)
plt.colorbar(label='Probability')

plt.xlabel('Frequency (Hz)')
plt.ylabel('Log10 Power')

# If a PSD is provided, plot over the histogram data
if spectrum is not None:
plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1])
Expand Down
Loading

0 comments on commit 083f3fd

Please sign in to comment.