Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: Speed up browse figure creation if scalings are provided #10109

Merged
merged 8 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ Enhancements

- Add ``infer_type`` argument to :func:`mne.io.read_raw_edf` and :func:`mne.io.read_raw_bdf` to automatically infer channel types from channel labels (:gh:`10058` by `Clemens Brunner`_)

- Reduce the time it takes to generate a :class:`mne.io.Raw`, :class:`~mne.Epochs`, or :class:`~mne.preprocessing.ICA` figure if a ``scalings`` parameter is provided. This also speeds up butterfly plot generation in :meth:`mne.Report.add_raw` (:gh:`10109` by `Richard Höchenberger`_ and `Eric Larson`_)

- :meth:`mne.Report.add_raw` gained a new ``scalings`` parameter to provide custom data scalings for the butterfly plots (:gh:`10109` by `Richard Höchenberger`_)


Bugs
~~~~

Expand Down
27 changes: 19 additions & 8 deletions mne/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..viz import (plot_events, plot_alignment, plot_cov, plot_projs_topomap,
plot_compare_evokeds, set_3d_view, get_3d_backend)
from ..viz.misc import _plot_mri_contours, _get_bem_plotting_surfaces
from ..viz.utils import _ndarray_to_fig, tight_layout
from ..viz.utils import _ndarray_to_fig, tight_layout, _compute_scalings
from ..forward import read_forward_solution, Forward
from ..epochs import read_epochs, BaseEpochs
from ..preprocessing.ica import read_ica
Expand Down Expand Up @@ -980,7 +980,8 @@ def add_evokeds(self, evokeds, *, titles=None, noise_cov=None, projs=None,

@fill_doc
def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
tags=('raw',), replace=False, topomap_kwargs=None):
scalings=None, tags=('raw',), replace=False,
topomap_kwargs=None):
"""Add `~mne.io.Raw` objects to the report.

Parameters
Expand All @@ -997,6 +998,7 @@ def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
butterfly : bool
Whether to add a butterfly plot of the (decimated) data. Can be
useful to spot segments marked as "bad" and problematic channels.
%(scalings)s
%(report_tags)s
%(report_replace)s
%(topomap_kwargs)s
Expand All @@ -1021,6 +1023,7 @@ def add_raw(self, raw, title, *, psd=None, projs=None, butterfly=True,
add_psd=add_psd,
add_projs=add_projs,
add_butterfly=butterfly,
butterfly_scalings=scalings,
image_format=self.image_format,
tags=tags,
topomap_kwargs=topomap_kwargs,
Expand Down Expand Up @@ -2545,17 +2548,24 @@ def _render_one_bem_axis(self, *, mri_fname, surfaces,

return html

def _render_raw_butterfly_segments(self, *, raw: BaseRaw, image_format,
tags):
def _render_raw_butterfly_segments(
self, *, raw: BaseRaw, scalings, image_format, tags
):
# Pick 10 1-second time slices
times = np.linspace(raw.times[0], raw.times[-1], 12)[1:-1]
scalings = _compute_scalings(
scalings, inst=raw, remove_dc=True, duration=1
)
scalings = _handle_default('scalings_plot_raw', scalings)
figs = []
for t in times:
tmin = max(t - 0.5, 0)
tmax = min(t + 0.5, raw.times[-1])
duration = tmax - tmin
fig = raw.plot(butterfly=True, show_scrollbars=False, start=tmin,
duration=duration, show=False)
fig = raw.plot(
butterfly=True, show_scrollbars=False, start=tmin,
duration=duration, scalings=scalings, show=False
)
figs.append(fig)

captions = [f'Segment {i+1} of {len(figs)}'
Expand All @@ -2569,7 +2579,7 @@ def _render_raw_butterfly_segments(self, *, raw: BaseRaw, image_format,
return html

def _render_raw(self, *, raw, add_psd, add_projs, add_butterfly,
image_format, tags, topomap_kwargs):
butterfly_scalings, image_format, tags, topomap_kwargs):
"""Render raw."""
if isinstance(raw, BaseRaw):
fname = raw.filenames[0]
Expand All @@ -2593,7 +2603,8 @@ def _render_raw(self, *, raw, add_psd, add_projs, add_butterfly,
# Butterfly plot
if add_butterfly:
butterfly_imgs_html = self._render_raw_butterfly_segments(
raw=raw, image_format=image_format, tags=tags
raw=raw, scalings=butterfly_scalings,
image_format=image_format, tags=tags
)
else:
butterfly_imgs_html = ''
Expand Down
26 changes: 23 additions & 3 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,28 @@
End time of the raw data to use in seconds (cannot exceed data duration).
"""

# Scaling of traces in plots
docdict['scalings'] = """
scalings : 'auto' | dict | None
Scaling factors for the traces. If a dictionary where any
value is ``'auto'``, the scaling factor is set to match the 99.5th
percentile of the respective data. If ``'auto'``, all scalings (for all
channel types) are set to ``'auto'``. If any values are ``'auto'`` and the
data is not preloaded, a subset up to 100 MB will be loaded. If ``None``,
defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
resp=1, chpi=1e-4, whitened=1e2)

.. note::
A particular scaling value ``s`` corresponds to half of the visualized
signal range around zero (i.e. from ``0`` to ``+s`` or from ``0`` to
``-s``). For example, the default scaling of ``20e-6`` (20µV) for EEG
signals means that the visualized range will be 40 µV (20 µV in the
positive direction and 20 µV in the negative direction).
"""

# Raw
docdict['standardize_names'] = """
standardize_names : bool
Expand Down Expand Up @@ -2364,9 +2386,7 @@
"""
docdict['topomap_kwargs'] = """
topomap_kwargs : dict | None
Keyword arguments to pass to topomap functions (
:func:`mne.viz.plot_evoked_topomap`, :func:`mne.viz.plot_projs_topomap`,
etc.).
Keyword arguments to pass to the topomap-generating functions.
"""

# Epochs
Expand Down
12 changes: 1 addition & 11 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,17 +652,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, n_channels=20,
epochs : instance of Epochs
The epochs object.
%(picks_good_data)s
scalings : dict | 'auto' | None
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded,
a subset of epochs up to 100 Mb will be loaded. If None, defaults to::
dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4,
whitened=10.)
%(scalings)s
n_epochs : int
The number of epochs per view. Defaults to 20.
n_channels : int
Expand Down
18 changes: 1 addition & 17 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
Color to make bad channels.
%(event_color)s
Defaults to ``'cyan'``.
scalings : 'auto' | dict | None
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded, a
subset of times up to 100mb will be loaded. If None, defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
resp=1, chpi=1e-4, whitened=1e2)

A particular scaling value ``s`` corresponds to half of the visualized
signal range around zero (i.e. from ``0`` to ``+s`` or from ``0`` to
``-s``). For example, the default scaling of ``20e-6`` (20µV) for EEG
signals means that the visualized range will be 40µV (20µV in the
positive direction and 20µV in the negative direction).

%(scalings)s
remove_dc : bool
If True remove DC component when plotting data.
order : array of int | None
Expand Down
31 changes: 21 additions & 10 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,14 +1158,33 @@ def _compute_scalings(scalings, inst, remove_dc=False, duration=10):
"""
from ..io.base import BaseRaw
from ..epochs import BaseEpochs

scalings = deepcopy(scalings)
Copy link
Member

@larsoner larsoner Dec 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need this line. _handle_default should EDIT: already make a copy internally

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(at least if you're worried about the user passing a dict and it being modified, it shouldn't be -- the values in principle could be if they're mutable and modified inplace but in practice this shouldn't really happen I think)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the deepcopy now.

scalings = _handle_default('scalings_plot_raw', scalings)
if not isinstance(inst, (BaseRaw, BaseEpochs)):
raise ValueError('Must supply either Raw or Epochs')

for key, value in scalings.items():
if isinstance(value, str) and value != 'auto':
Copy link
Member

@larsoner larsoner Dec 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic seems wrong now as the not was omitted, should probably still be

Suggested change
if isinstance(value, str) and value != 'auto':
if not (isinstance(value, str) and value == 'auto'):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... if you were trying to distribute the not, you probably meant:

Suggested change
if isinstance(value, str) and value != 'auto':
if not isinstance(value, str) or value != 'auto':

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I believe we could probably change this this to

        if value != 'auto' and not isinstance(value, float):

WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed a respective change in cbc14e8

Copy link
Member

@larsoner larsoner Dec 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the old conditional better, for example:

>>> import numpy as np
>>> np.array([1.]).item()
1.0
>>> float(np.array([1.]))
1.0
>>> np.array([1.]) == 'auto'
<stdin>:1: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
False

That warning would not happen with the isinstance(thing, str) call version of the code

To me the logic of if not (string and equal to 'auto'): then convert to float is clearer and more general

try:
scalings[key] = float(value)
except Exception:
raise ValueError(
f'scalings must be "auto" or float, got '
f'scalings[{key!r}]={value!r} which could not be '
f'converted to float'
)

# If there are no "auto" scalings, we can return early!
if all(
[scalings[ch_type] != 'auto'
for ch_type in inst.get_channel_types(unique=True)]
):
return scalings

ch_types = channel_indices_by_type(inst.info)
ch_types = {i_type: i_ixs
for i_type, i_ixs in ch_types.items() if len(i_ixs) != 0}
scalings = deepcopy(scalings)

if inst.preload is False:
if isinstance(inst, BaseRaw):
Expand All @@ -1191,15 +1210,7 @@ def _compute_scalings(scalings, inst, remove_dc=False, duration=10):
data = inst._data.swapaxes(0, 1).reshape([len(inst.ch_names), -1])
# Iterate through ch types and update scaling if ' auto'
for key, value in scalings.items():
if key not in ch_types:
continue
if not (isinstance(value, str) and value == 'auto'):
try:
scalings[key] = float(value)
except Exception:
raise ValueError(
f'scalings must be "auto" or float, got scalings[{key!r}]='
f'{value!r} which could not be converted to float')
if (key not in ch_types or value != 'auto'):
hoechenberger marked this conversation as resolved.
Show resolved Hide resolved
continue
this_data = data[ch_types[key]]
if remove_dc and (this_data.shape[1] / inst.info["sfreq"] >= duration):
Expand Down