From 8f6af8068fa2281d9b6bc207bd46ae705a31eb2d Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 7 Dec 2022 16:42:32 -0500 Subject: [PATCH] ENH: Add webp support to Report (#11359) --- doc/changes/latest.inc | 2 + mne/report/report.py | 80 +++++++++++++++++++++++++-------- mne/report/tests/test_report.py | 45 +++++++++++++++++-- tutorials/intro/70_report.py | 4 +- 4 files changed, 107 insertions(+), 24 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index dc2f79318fe..09707b324be 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -28,6 +28,7 @@ Enhancements - Add a warning to the docstring of :func:`mne.channels.find_ch_adjacency` to encourage users to validate their outputs (:gh:`11236` by `Felix Klotzsche`_ and `Eric Larson`_) - Mixed, cortical + discrete source spaces with fixed orientations are now allowed. (:gh:`11241` by `Jevri Hanna`_) - Add size information to the ``repr`` of :class:`mne.Report` (:gh:`11357` by `Eric Larson`_) +- Add support for ``image_format='webp'`` to :class:`mne.Report` when using Matplotlib 3.6+, which can reduce file sizes by up to 50% compared to ``'png'``. The new default ``image_format='auto'`` will automatically use this format if it's available on the system (:gh:`11359` by `Eric Larson`_) - Add :func:`mne.beamformer.apply_dics_tfr_epochs` to apply a DICS beamformer to time-frequency resolved epochs (:gh:`11096` by `Alex Rockhill`_) - Check whether head radius (estimated from channel positions) is correct when reading EEGLAB data with :func:`~mne.io.read_raw_eeglab` and :func:`~mne.read_epochs_eeglab`. If head radius is not within likely values, warn informing about possible units mismatch and the new ``montage_units`` argument (:gh:`11283` by `MikoĊ‚aj Magnuski`_). - Add support for a callable passed in ``combine`` for `mne.time_frequency.AverageTFR.plot` and `mne.time_frequency.AverageTFR.plot_joint` (:gh:`11329` by `Mathieu Scheltienne`_) @@ -55,6 +56,7 @@ Bugs - Fix bug with :class:`mne.Report` with ``replace=True`` where the wrong content was replaced and ``section`` was not respected (:gh:`11318`, :gh:`11346` by `Eric Larson`_) - Fix bug with unit conversion when setting reference MEG as the channel type in :meth:`mne.io.Raw.set_channel_types` and related methods (:gh:`11344` by `Eric Larson`_) - Fix bug where reference MEG channels could not be plotted using :func:`mne.viz.plot_epochs_image` (:gh:`11344` by `Eric Larson`_) +- Fix bug where ``image_format='gif'`` was errantly documented as being supported by :class:`mne.Report`, it is now only supported in :meth:`mne.Report.add_image` (:gh:`11347` by `Eric Larson`_) - Multitaper spectral estimation now uses periodic (rather than symmetric) taper windows. This also necessitated changing the default ``max_iter`` of our cross-spectral density functions from 150 to 250. (:gh:`11293` by `Daniel McCloy`_) - Fix :meth:`mne.Epochs.plot_image` and :func:`mne.viz.plot_epochs_image` when using EMG signals (:gh:`11322` by `Alex Gramfort`_) diff --git a/mne/report/report.py b/mne/report/report.py index 708ea4f334d..d2f2db8568c 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -26,7 +26,6 @@ import numpy as np from .. import __version__ as MNE_VERSION -from ..fixes import _compare_version from .. import (read_evokeds, read_events, read_cov, read_source_estimate, read_trans, sys_info, Evoked, SourceEstimate, Covariance, Info, Transform) @@ -40,7 +39,8 @@ from ..utils import (logger, verbose, get_subjects_dir, warn, _ensure_int, fill_doc, _check_option, _validate_type, _safe_input, _path_like, use_log_level, _check_fname, _pl, - _check_ch_locs, _import_h5io_funcs, _verbose_safe_false) + _check_ch_locs, _import_h5io_funcs, _verbose_safe_false, + check_version) from ..viz import (plot_events, plot_alignment, plot_cov, plot_projs_topomap, plot_compare_evokeds, set_3d_view, get_3d_backend, Figure3D, use_browser_backend) @@ -345,21 +345,46 @@ def _fig_to_img(fig, *, image_format='png', own_figure=True): own_figure = True # close the fig we just created output = BytesIO() + dpi = fig.get_dpi() logger.debug( f'Saving figure with dimension {fig.get_size_inches()} inches with ' - f'{fig.get_dpi()} dpi' + f'{{dpi}} dpi' ) + # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html + mpl_kwargs = dict() + pil_kwargs = dict() + has_pillow = check_version('PIL') + if has_pillow: + if image_format == 'webp': + pil_kwargs.update(lossless=True, method=6) + elif image_format == 'png': + pil_kwargs.update(optimize=True, compress_level=9) + if pil_kwargs: + # matplotlib modifies the passed dict, which is a bug + mpl_kwargs['pil_kwargs'] = pil_kwargs.copy() with warnings.catch_warnings(): warnings.filterwarnings( action='ignore', message='.*Axes that are not compatible with tight_layout.*', category=UserWarning ) - fig.savefig(output, format=image_format, dpi=fig.get_dpi()) + fig.savefig(output, format=image_format, dpi=dpi, **mpl_kwargs) if own_figure: plt.close(fig) + + # Remove alpha + if image_format != 'svg' and has_pillow: + from PIL import Image + output.seek(0) + orig = Image.open(output) + if orig.mode == 'RGBA': + background = Image.new('RGBA', orig.size, (255, 255, 255)) + new = Image.alpha_composite(background, orig).convert('RGB') + output = BytesIO() + new.save(output, format=image_format, dpi=(dpi, dpi), **pil_kwargs) + output = output.getvalue() return (output.decode('utf-8') if image_format == 'svg' else base64.b64encode(output).decode('ascii')) @@ -405,9 +430,8 @@ def _get_bem_contour_figs_as_arrays( A list of NumPy arrays that represent the generated Matplotlib figures. """ # Matplotlib <3.2 doesn't work nicely with process-based parallelization - from matplotlib import __version__ as MPL_VERSION kwargs = dict() - if _compare_version(MPL_VERSION, '<', '3.2'): + if not check_version('matplotilb', '3.2'): kwargs['prefer'] = 'threads' parallel, p_fun, n_jobs = parallel_func( @@ -598,6 +622,16 @@ def open_report(fname, **params): mne_logo_path = Path(__file__).parents[1] / 'icons' / 'mne_icon-cropped.png' mne_logo = base64.b64encode(mne_logo_path.read_bytes()).decode('ascii') +_ALLOWED_IMAGE_FORMATS = ('png', 'svg', 'webp') + + +def _webp_supported(): + good = check_version('matplotlib', '3.6') and check_version('PIL') + if good: + from PIL import features + good = features.check('webp') + return good + def _check_scale(scale): """Ensure valid scale value is passed.""" @@ -608,10 +642,17 @@ def _check_scale(scale): def _check_image_format(rep, image_format): """Ensure fmt is valid.""" if rep is None or image_format is not None: - _check_option('image_format', image_format, - allowed_values=('png', 'svg', 'gif')) + allowed = list(_ALLOWED_IMAGE_FORMATS) + ['auto'] + extra = '' + if not _webp_supported(): + allowed.pop(allowed.index('webp')) + extra = '("webp" supported on matplotlib 3.6+ with PIL installed)' + _check_option( + 'image_format', image_format, allowed_values=allowed, extra=extra) else: image_format = rep.image_format + if image_format == 'auto': + image_format = 'webp' if _webp_supported() else 'png' return image_format @@ -632,13 +673,17 @@ class Report: Name of the file containing the noise covariance. %(baseline_report)s Defaults to ``None``, i.e. no baseline correction. - image_format : 'png' | 'svg' | 'gif' - Default image format to use (default is ``'png'``). + image_format : 'png' | 'svg' | 'webp' | 'auto' + Default image format to use (default is ``'auto'``, which will use + ``'webp'`` if available and ``'png'`` otherwise). ``'svg'`` uses vector graphics, so fidelity is higher but can increase file size and browser image rendering time as well. + ``'webp'`` format requires matplotlib >= 3.6. .. versionadded:: 0.15 - + .. versionchanged:: 1.3 + Added support for ``'webp'`` format, removed support for GIF, and + set the default to ``'auto'``. raw_psd : bool | dict If True, include PSD plots for raw files. Can be False (default) to omit, True to plot, or a dict to pass as ``kwargs`` to @@ -669,7 +714,6 @@ class Report: Default image format to use. .. versionadded:: 0.15 - raw_psd : bool | dict If True, include PSD plots for raw files. Can be False (default) to omit, True to plot, or a dict to pass as ``kwargs`` to @@ -703,7 +747,7 @@ class Report: @verbose def __init__(self, info_fname=None, subjects_dir=None, subject=None, title=None, cov_fname=None, baseline=None, - image_format='png', raw_psd=False, projs=False, *, + image_format='auto', raw_psd=False, projs=False, *, verbose=None): self.info_fname = str(info_fname) if info_fname is not None else None self.cov_fname = str(cov_fname) if cov_fname is not None else None @@ -2039,14 +2083,11 @@ def add_image( .. versionadded:: 0.24.0 """ tags = _check_tags(tags) - img_bytes = Path(image).expanduser().read_bytes() - img_base64 = base64.b64encode(img_bytes).decode('ascii') - del img_bytes # Free memory - + image = Path(_check_fname(image, overwrite='read', must_exist=True)) img_format = Path(image).suffix.lower()[1:] # omit leading period _check_option('Image format', value=img_format, - allowed_values=('png', 'gif', 'svg')) - + allowed_values=list(_ALLOWED_IMAGE_FORMATS) + ['gif']) + img_base64 = base64.b64encode(image.read_bytes()).decode('ascii') self._add_image( img=img_base64, title=title, @@ -2156,6 +2197,7 @@ def _render_slider( imgs = [_fig_to_img(fig=fig, image_format=image_format, own_figure=own_figure) for fig in figs] + dom_id = self._get_dom_id() html = _html_slider_element( id=dom_id, diff --git a/mne/report/tests/test_report.py b/mne/report/tests/test_report.py index 0ee4eb3b181..0bd7b1d1d8d 100644 --- a/mne/report/tests/test_report.py +++ b/mne/report/tests/test_report.py @@ -22,7 +22,9 @@ from mne import (Epochs, read_events, read_evokeds, read_cov, pick_channels_cov, create_info) from mne.report import report as report_mod -from mne.report.report import CONTENT_ORDER +from mne.report.report import ( + CONTENT_ORDER, _ALLOWED_IMAGE_FORMATS, _webp_supported, +) from mne.io import read_raw_fif, read_info, RawArray from mne.datasets import testing from mne.report import Report, open_report, _ReportScraper, report @@ -394,7 +396,7 @@ def test_render_add_sections(renderer, tmp_path): fig.savefig(img_fname) report.add_image(image=img_fname, title='evoked response') - with pytest.raises(FileNotFoundError, match='No such file or directory'): + with pytest.raises(FileNotFoundError, match='does not exist'): report.add_image(image='foobar.xxx', title='H') evoked = read_evokeds(evoked_fname, condition='Left Auditory', @@ -452,7 +454,7 @@ def test_add_bem_n_jobs(n_jobs, monkeypatch): use_subjects_dir = None else: use_subjects_dir = subjects_dir - report = Report(subjects_dir=use_subjects_dir) + report = Report(subjects_dir=use_subjects_dir, image_format='png') # implicitly test that subjects_dir is correctly preserved here monkeypatch.setattr(report_mod, '_BEM_VIEWS', ('axial',)) if use_subjects_dir is not None: @@ -996,3 +998,40 @@ def test_tags(tags, str_or_array, wrong_dtype, invalid_chars): r.add_code(code='foo', title='bar', tags=tags) else: r.add_code(code='foo', title='bar', tags=tags) + + +# These are all the ones we claim to support +@pytest.mark.parametrize('image_format', _ALLOWED_IMAGE_FORMATS) +def test_image_format(image_format): + """Test image format support.""" + if image_format == 'webp': + if not _webp_supported(): + with pytest.raises(ValueError, match='matplotlib'): + Report(image_format='webp') + return + r = Report(image_format=image_format) + fig1, _ = _get_example_figures() + r.add_figure(fig1, 'fig1') + assert image_format in r.html[0] + + +def test_gif(tmp_path): + """Test that GIFs can be embedded using add_image.""" + pytest.importorskip('PIL') + from PIL import Image + sequence = [ + Image.fromarray(frame.astype(np.uint8)) + for frame in _get_example_figures() + ] + fname = tmp_path / 'test.gif' + sequence[0].save(str(fname), save_all=True, append_images=sequence[1:]) + assert fname.is_file() + with pytest.raises(ValueError, match='Allowed values'): + Report(image_format='gif') + r = Report() + r.add_image(fname, 'fname') + assert 'image/gif' in r.html[0] + bad_name = fname.with_suffix('.foo') + bad_name.write_bytes(b'') + with pytest.raises(ValueError, match='Allowed values'): + r.add_image(bad_name, 'fname') diff --git a/tutorials/intro/70_report.py b/tutorials/intro/70_report.py index 7c484cb4e8f..c399fcaca3e 100644 --- a/tutorials/intro/70_report.py +++ b/tutorials/intro/70_report.py @@ -101,8 +101,8 @@ # ^^^^^^^^^^^^^ # # Events can be added via :meth:`mne.Report.add_events`. You also need to -# supply the sampling frequency used during the recording; this information -# is used to generate a meaningful time axis. +# supply the sampling frequency used during the recording; this information is +# used to generate a meaningful time axis. events_path = sample_dir / 'sample_audvis_filt-0-40_raw-eve.fif' events = mne.find_events(raw=raw)