Skip to content

Commit

Permalink
MRG: Speed up browse figure creation if scalings are provided (#10109)
Browse files Browse the repository at this point in the history
* Omit unnecessary scalings computations

* Fix autoscale behavior, apply review suggestion

* Update mne/viz/utils.py

* Better logic

* Revert logic

* Drop deepcopy

* Drop unused import
  • Loading branch information
hoechenberger committed Dec 8, 2021
1 parent 151b8a6 commit 6a26583
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 50 deletions.
5 changes: 5 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ Enhancements

- Add ``infer_type`` argument to :func:`mne.io.read_raw_edf` and :func:`mne.io.read_raw_bdf` to automatically infer channel types from channel labels (:gh:`10058` by `Clemens Brunner`_)

- Reduce the time it takes to generate a :class:`mne.io.Raw`, :class:`~mne.Epochs`, or :class:`~mne.preprocessing.ICA` figure if a ``scalings`` parameter is provided. This also speeds up butterfly plot generation in :meth:`mne.Report.add_raw` (:gh:`10109` by `Richard Höchenberger`_ and `Eric Larson`_)

- :meth:`mne.Report.add_raw` gained a new ``scalings`` parameter to provide custom data scalings for the butterfly plots (:gh:`10109` by `Richard Höchenberger`_)


Bugs
~~~~

Expand Down
27 changes: 19 additions & 8 deletions mne/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..viz import (plot_events, plot_alignment, plot_cov, plot_projs_topomap,
plot_compare_evokeds, set_3d_view, get_3d_backend)
from ..viz.misc import _plot_mri_contours, _get_bem_plotting_surfaces
from ..viz.utils import _ndarray_to_fig, tight_layout
from ..viz.utils import _ndarray_to_fig, tight_layout, _compute_scalings
from ..forward import read_forward_solution, Forward
from ..epochs import read_epochs, BaseEpochs
from ..preprocessing.ica import read_ica
Expand Down Expand Up @@ -980,7 +980,8 @@ def add_evokeds(self, evokeds, *, titles=None, noise_cov=None, projs=None,

@fill_doc
def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
tags=('raw',), replace=False, topomap_kwargs=None):
scalings=None, tags=('raw',), replace=False,
topomap_kwargs=None):
"""Add `~mne.io.Raw` objects to the report.
Parameters
Expand All @@ -997,6 +998,7 @@ def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
butterfly : bool
Whether to add a butterfly plot of the (decimated) data. Can be
useful to spot segments marked as "bad" and problematic channels.
%(scalings)s
%(report_tags)s
%(report_replace)s
%(topomap_kwargs)s
Expand All @@ -1021,6 +1023,7 @@ def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
add_psd=add_psd,
add_projs=add_projs,
add_butterfly=butterfly,
butterfly_scalings=scalings,
image_format=self.image_format,
tags=tags,
topomap_kwargs=topomap_kwargs,
Expand Down Expand Up @@ -2545,17 +2548,24 @@ def _render_one_bem_axis(self, *, mri_fname, surfaces,

return html

def _render_raw_butterfly_segments(self, *, raw: BaseRaw, image_format,
tags):
def _render_raw_butterfly_segments(
self, *, raw: BaseRaw, scalings, image_format, tags
):
# Pick 10 1-second time slices
times = np.linspace(raw.times[0], raw.times[-1], 12)[1:-1]
scalings = _compute_scalings(
scalings, inst=raw, remove_dc=True, duration=1
)
scalings = _handle_default('scalings_plot_raw', scalings)
figs = []
for t in times:
tmin = max(t - 0.5, 0)
tmax = min(t + 0.5, raw.times[-1])
duration = tmax - tmin
fig = raw.plot(butterfly=True, show_scrollbars=False, start=tmin,
duration=duration, show=False)
fig = raw.plot(
butterfly=True, show_scrollbars=False, start=tmin,
duration=duration, scalings=scalings, show=False
)
figs.append(fig)

captions = [f'Segment {i+1} of {len(figs)}'
Expand All @@ -2569,7 +2579,7 @@ def _render_raw_butterfly_segments(self, *, raw: BaseRaw, image_format,
return html

def _render_raw(self, *, raw, add_psd, add_projs, add_butterfly,
image_format, tags, topomap_kwargs):
butterfly_scalings, image_format, tags, topomap_kwargs):
"""Render raw."""
if isinstance(raw, BaseRaw):
fname = raw.filenames[0]
Expand All @@ -2593,7 +2603,8 @@ def _render_raw(self, *, raw, add_psd, add_projs, add_butterfly,
# Butterfly plot
if add_butterfly:
butterfly_imgs_html = self._render_raw_butterfly_segments(
raw=raw, image_format=image_format, tags=tags
raw=raw, scalings=butterfly_scalings,
image_format=image_format, tags=tags
)
else:
butterfly_imgs_html = ''
Expand Down
26 changes: 23 additions & 3 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,28 @@
End time of the raw data to use in seconds (cannot exceed data duration).
"""

# Scaling of traces in plots
docdict['scalings'] = """
scalings : 'auto' | dict | None
Scaling factors for the traces. If a dictionary where any
value is ``'auto'``, the scaling factor is set to match the 99.5th
percentile of the respective data. If ``'auto'``, all scalings (for all
channel types) are set to ``'auto'``. If any values are ``'auto'`` and the
data is not preloaded, a subset up to 100 MB will be loaded. If ``None``,
defaults to::
dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
resp=1, chpi=1e-4, whitened=1e2)
.. note::
A particular scaling value ``s`` corresponds to half of the visualized
signal range around zero (i.e. from ``0`` to ``+s`` or from ``0`` to
``-s``). For example, the default scaling of ``20e-6`` (20µV) for EEG
signals means that the visualized range will be 40 µV (20 µV in the
positive direction and 20 µV in the negative direction).
"""

# Raw
docdict['standardize_names'] = """
standardize_names : bool
Expand Down Expand Up @@ -2364,9 +2386,7 @@
"""
docdict['topomap_kwargs'] = """
topomap_kwargs : dict | None
Keyword arguments to pass to topomap functions (
:func:`mne.viz.plot_evoked_topomap`, :func:`mne.viz.plot_projs_topomap`,
etc.).
Keyword arguments to pass to the topomap-generating functions.
"""

# Epochs
Expand Down
12 changes: 1 addition & 11 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,17 +652,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, n_channels=20,
epochs : instance of Epochs
The epochs object.
%(picks_good_data)s
scalings : dict | 'auto' | None
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded,
a subset of epochs up to 100 Mb will be loaded. If None, defaults to::
dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4,
whitened=10.)
%(scalings)s
n_epochs : int
The number of epochs per view. Defaults to 20.
n_channels : int
Expand Down
18 changes: 1 addition & 17 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
Color to make bad channels.
%(event_color)s
Defaults to ``'cyan'``.
scalings : 'auto' | dict | None
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded, a
subset of times up to 100mb will be loaded. If None, defaults to::
dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
resp=1, chpi=1e-4, whitened=1e2)
A particular scaling value ``s`` corresponds to half of the visualized
signal range around zero (i.e. from ``0`` to ``+s`` or from ``0`` to
``-s``). For example, the default scaling of ``20e-6`` (20µV) for EEG
signals means that the visualized range will be 40µV (20µV in the
positive direction and 20µV in the negative direction).
%(scalings)s
remove_dc : bool
If True remove DC component when plotting data.
order : array of int | None
Expand Down
31 changes: 20 additions & 11 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import tempfile
import math
import numpy as np
from copy import deepcopy
import warnings
from datetime import datetime

Expand Down Expand Up @@ -1158,14 +1157,32 @@ def _compute_scalings(scalings, inst, remove_dc=False, duration=10):
"""
from ..io.base import BaseRaw
from ..epochs import BaseEpochs

scalings = _handle_default('scalings_plot_raw', scalings)
if not isinstance(inst, (BaseRaw, BaseEpochs)):
raise ValueError('Must supply either Raw or Epochs')

for key, value in scalings.items():
if not (isinstance(value, str) and value == 'auto'):
try:
scalings[key] = float(value)
except Exception:
raise ValueError(
f'scalings must be "auto" or float, got '
f'scalings[{key!r}]={value!r} which could not be '
f'converted to float'
)

# If there are no "auto" scalings, we can return early!
if all(
[scalings[ch_type] != 'auto'
for ch_type in inst.get_channel_types(unique=True)]
):
return scalings

ch_types = channel_indices_by_type(inst.info)
ch_types = {i_type: i_ixs
for i_type, i_ixs in ch_types.items() if len(i_ixs) != 0}
scalings = deepcopy(scalings)

if inst.preload is False:
if isinstance(inst, BaseRaw):
Expand All @@ -1191,15 +1208,7 @@ def _compute_scalings(scalings, inst, remove_dc=False, duration=10):
data = inst._data.swapaxes(0, 1).reshape([len(inst.ch_names), -1])
# Iterate through ch types and update scaling if ' auto'
for key, value in scalings.items():
if key not in ch_types:
continue
if not (isinstance(value, str) and value == 'auto'):
try:
scalings[key] = float(value)
except Exception:
raise ValueError(
f'scalings must be "auto" or float, got scalings[{key!r}]='
f'{value!r} which could not be converted to float')
if key not in ch_types or value != 'auto':
continue
this_data = data[ch_types[key]]
if remove_dc and (this_data.shape[1] / inst.info["sfreq"] >= duration):
Expand Down

0 comments on commit 6a26583

Please sign in to comment.