Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: Speed up browse figure creation if scalings are provided #10109

Merged
merged 8 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ Enhancements

- Add ``infer_type`` argument to :func:`mne.io.read_raw_edf` and :func:`mne.io.read_raw_bdf` to automatically infer channel types from channel labels (:gh:`10058` by `Clemens Brunner`_)

- Reduce the time it takes to generate a `~mne.io.Raw`, `~mne.Epochs`, or `~mne.preprocessing.ICA` figure if a ``scalings`` parameter is provided. This also speeds up butterfly plot generation in :meth:`mne.Report.add_raw` (:gh:`10109` by `Richard Höchenberger`_ and `Eric Larson`_)
hoechenberger marked this conversation as resolved.
Show resolved Hide resolved

- :meth:`mne.Report.add_raw` gained a new ``scalings`` parameter to provide custom data scalings for the butterfly plots (:gh:`10109` by `Richard Höchenberger`_)


Bugs
~~~~

Expand Down
27 changes: 19 additions & 8 deletions mne/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..viz import (plot_events, plot_alignment, plot_cov, plot_projs_topomap,
plot_compare_evokeds, set_3d_view, get_3d_backend)
from ..viz.misc import _plot_mri_contours, _get_bem_plotting_surfaces
from ..viz.utils import _ndarray_to_fig, tight_layout
from ..viz.utils import _ndarray_to_fig, tight_layout, _compute_scalings
from ..forward import read_forward_solution, Forward
from ..epochs import read_epochs, BaseEpochs
from ..preprocessing.ica import read_ica
Expand Down Expand Up @@ -980,7 +980,8 @@ def add_evokeds(self, evokeds, *, titles=None, noise_cov=None, projs=None,

@fill_doc
def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
tags=('raw',), replace=False, topomap_kwargs=None):
scalings=None, tags=('raw',), replace=False,
topomap_kwargs=None):
"""Add `~mne.io.Raw` objects to the report.

Parameters
Expand All @@ -997,6 +998,7 @@ def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
butterfly : bool
Whether to add a butterfly plot of the (decimated) data. Can be
useful to spot segments marked as "bad" and problematic channels.
%(scalings)s
%(report_tags)s
%(report_replace)s
%(topomap_kwargs)s
Expand All @@ -1021,6 +1023,7 @@ def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
add_psd=add_psd,
add_projs=add_projs,
add_butterfly=butterfly,
butterfly_scalings=scalings,
image_format=self.image_format,
tags=tags,
topomap_kwargs=topomap_kwargs,
Expand Down Expand Up @@ -2545,17 +2548,24 @@ def _render_one_bem_axis(self, *, mri_fname, surfaces,

return html

def _render_raw_butterfly_segments(self, *, raw: BaseRaw, image_format,
tags):
def _render_raw_butterfly_segments(
self, *, raw: BaseRaw, scalings, image_format, tags
):
# Pick 10 1-second time slices
times = np.linspace(raw.times[0], raw.times[-1], 12)[1:-1]
scalings = _compute_scalings(
scalings, inst=raw, remove_dc=True, duration=1
)
scalings = _handle_default('scalings_plot_raw', scalings)
figs = []
for t in times:
tmin = max(t - 0.5, 0)
tmax = min(t + 0.5, raw.times[-1])
duration = tmax - tmin
fig = raw.plot(butterfly=True, show_scrollbars=False, start=tmin,
duration=duration, show=False)
fig = raw.plot(
butterfly=True, show_scrollbars=False, start=tmin,
duration=duration, scalings=scalings, show=False
)
figs.append(fig)

captions = [f'Segment {i+1} of {len(figs)}'
Expand All @@ -2569,7 +2579,7 @@ def _render_raw_butterfly_segments(self, *, raw: BaseRaw, image_format,
return html

def _render_raw(self, *, raw, add_psd, add_projs, add_butterfly,
image_format, tags, topomap_kwargs):
butterfly_scalings, image_format, tags, topomap_kwargs):
"""Render raw."""
if isinstance(raw, BaseRaw):
fname = raw.filenames[0]
Expand All @@ -2593,7 +2603,8 @@ def _render_raw(self, *, raw, add_psd, add_projs, add_butterfly,
# Butterfly plot
if add_butterfly:
butterfly_imgs_html = self._render_raw_butterfly_segments(
raw=raw, image_format=image_format, tags=tags
raw=raw, scalings=butterfly_scalings,
image_format=image_format, tags=tags
)
else:
butterfly_imgs_html = ''
Expand Down
26 changes: 23 additions & 3 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,28 @@
End time of the raw data to use in seconds (cannot exceed data duration).
"""

# Scaling of traces in plots
docdict['scalings'] = """
scalings : 'auto' | dict | None
Scaling factors for the traces. If a dictionary where any
value is ``'auto'``, the scaling factor is set to match the 99.5th
percentile of the respective data. If ``'auto'``, all scalings (for all
channel types) are set to ``'auto'``. If any values are ``'auto'`` and the
data is not preloaded, a subset up to 100 MB will be loaded. If ``None``,
defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
resp=1, chpi=1e-4, whitened=1e2)

.. note::
A particular scaling value ``s`` corresponds to half of the visualized
signal range around zero (i.e. from ``0`` to ``+s`` or from ``0`` to
``-s``). For example, the default scaling of ``20e-6`` (20µV) for EEG
signals means that the visualized range will be 40 µV (20 µV in the
positive direction and 20 µV in the negative direction).
"""

# Raw
docdict['standardize_names'] = """
standardize_names : bool
Expand Down Expand Up @@ -2364,9 +2386,7 @@
"""
docdict['topomap_kwargs'] = """
topomap_kwargs : dict | None
Keyword arguments to pass to topomap functions (
:func:`mne.viz.plot_evoked_topomap`, :func:`mne.viz.plot_projs_topomap`,
etc.).
Keyword arguments to pass to the topomap-generating functions.
"""

# Epochs
Expand Down
12 changes: 1 addition & 11 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,17 +652,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, n_channels=20,
epochs : instance of Epochs
The epochs object.
%(picks_good_data)s
scalings : dict | 'auto' | None
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded,
a subset of epochs up to 100 Mb will be loaded. If None, defaults to::
dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4,
whitened=10.)
%(scalings)s
n_epochs : int
The number of epochs per view. Defaults to 20.
n_channels : int
Expand Down
18 changes: 1 addition & 17 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
Color to make bad channels.
%(event_color)s
Defaults to ``'cyan'``.
scalings : 'auto' | dict | None
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded, a
subset of times up to 100mb will be loaded. If None, defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
resp=1, chpi=1e-4, whitened=1e2)

A particular scaling value ``s`` corresponds to half of the visualized
signal range around zero (i.e. from ``0`` to ``+s`` or from ``0`` to
``-s``). For example, the default scaling of ``20e-6`` (20µV) for EEG
signals means that the visualized range will be 40µV (20µV in the
positive direction and 20µV in the negative direction).

%(scalings)s
remove_dc : bool
If True remove DC component when plotting data.
order : array of int | None
Expand Down
5 changes: 5 additions & 0 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,11 @@ def _compute_scalings(scalings, inst, remove_dc=False, duration=10):
if not isinstance(inst, (BaseRaw, BaseEpochs)):
raise ValueError('Must supply either Raw or Epochs')

# If there are no "auto" scalings, we can return early!
if all([scalings[ch_type] != 'auto'
for ch_type in set(inst.get_channel_types())]):
hoechenberger marked this conversation as resolved.
Show resolved Hide resolved
return scalings

ch_types = channel_indices_by_type(inst.info)
ch_types = {i_type: i_ixs
for i_type, i_ixs in ch_types.items() if len(i_ixs) != 0}
Expand Down