Skip to content

Commit

Permalink
ENH: Add webp support to Report (#11359)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Dec 7, 2022
1 parent 1be103d commit 8f6af80
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 24 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Expand Up @@ -28,6 +28,7 @@ Enhancements
- Add a warning to the docstring of :func:`mne.channels.find_ch_adjacency` to encourage users to validate their outputs (:gh:`11236` by `Felix Klotzsche`_ and `Eric Larson`_)
- Mixed, cortical + discrete source spaces with fixed orientations are now allowed. (:gh:`11241` by `Jevri Hanna`_)
- Add size information to the ``repr`` of :class:`mne.Report` (:gh:`11357` by `Eric Larson`_)
- Add support for ``image_format='webp'`` to :class:`mne.Report` when using Matplotlib 3.6+, which can reduce file sizes by up to 50% compared to ``'png'``. The new default ``image_format='auto'`` will automatically use this format if it's available on the system (:gh:`11359` by `Eric Larson`_)
- Add :func:`mne.beamformer.apply_dics_tfr_epochs` to apply a DICS beamformer to time-frequency resolved epochs (:gh:`11096` by `Alex Rockhill`_)
- Check whether head radius (estimated from channel positions) is correct when reading EEGLAB data with :func:`~mne.io.read_raw_eeglab` and :func:`~mne.read_epochs_eeglab`. If head radius is not within likely values, warn informing about possible units mismatch and the new ``montage_units`` argument (:gh:`11283` by `Mikołaj Magnuski`_).
- Add support for a callable passed in ``combine`` for `mne.time_frequency.AverageTFR.plot` and `mne.time_frequency.AverageTFR.plot_joint` (:gh:`11329` by `Mathieu Scheltienne`_)
Expand Down Expand Up @@ -55,6 +56,7 @@ Bugs
- Fix bug with :class:`mne.Report` with ``replace=True`` where the wrong content was replaced and ``section`` was not respected (:gh:`11318`, :gh:`11346` by `Eric Larson`_)
- Fix bug with unit conversion when setting reference MEG as the channel type in :meth:`mne.io.Raw.set_channel_types` and related methods (:gh:`11344` by `Eric Larson`_)
- Fix bug where reference MEG channels could not be plotted using :func:`mne.viz.plot_epochs_image` (:gh:`11344` by `Eric Larson`_)
- Fix bug where ``image_format='gif'`` was errantly documented as being supported by :class:`mne.Report`, it is now only supported in :meth:`mne.Report.add_image` (:gh:`11347` by `Eric Larson`_)
- Multitaper spectral estimation now uses periodic (rather than symmetric) taper windows. This also necessitated changing the default ``max_iter`` of our cross-spectral density functions from 150 to 250. (:gh:`11293` by `Daniel McCloy`_)
- Fix :meth:`mne.Epochs.plot_image` and :func:`mne.viz.plot_epochs_image` when using EMG signals (:gh:`11322` by `Alex Gramfort`_)

Expand Down
80 changes: 61 additions & 19 deletions mne/report/report.py
Expand Up @@ -26,7 +26,6 @@
import numpy as np

from .. import __version__ as MNE_VERSION
from ..fixes import _compare_version
from .. import (read_evokeds, read_events, read_cov,
read_source_estimate, read_trans, sys_info,
Evoked, SourceEstimate, Covariance, Info, Transform)
Expand All @@ -40,7 +39,8 @@
from ..utils import (logger, verbose, get_subjects_dir, warn, _ensure_int,
fill_doc, _check_option, _validate_type, _safe_input,
_path_like, use_log_level, _check_fname, _pl,
_check_ch_locs, _import_h5io_funcs, _verbose_safe_false)
_check_ch_locs, _import_h5io_funcs, _verbose_safe_false,
check_version)
from ..viz import (plot_events, plot_alignment, plot_cov, plot_projs_topomap,
plot_compare_evokeds, set_3d_view, get_3d_backend,
Figure3D, use_browser_backend)
Expand Down Expand Up @@ -345,21 +345,46 @@ def _fig_to_img(fig, *, image_format='png', own_figure=True):
own_figure = True # close the fig we just created

output = BytesIO()
dpi = fig.get_dpi()
logger.debug(
f'Saving figure with dimension {fig.get_size_inches()} inches with '
f'{fig.get_dpi()} dpi'
f'{{dpi}} dpi'
)

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html
mpl_kwargs = dict()
pil_kwargs = dict()
has_pillow = check_version('PIL')
if has_pillow:
if image_format == 'webp':
pil_kwargs.update(lossless=True, method=6)
elif image_format == 'png':
pil_kwargs.update(optimize=True, compress_level=9)
if pil_kwargs:
# matplotlib modifies the passed dict, which is a bug
mpl_kwargs['pil_kwargs'] = pil_kwargs.copy()
with warnings.catch_warnings():
warnings.filterwarnings(
action='ignore',
message='.*Axes that are not compatible with tight_layout.*',
category=UserWarning
)
fig.savefig(output, format=image_format, dpi=fig.get_dpi())
fig.savefig(output, format=image_format, dpi=dpi, **mpl_kwargs)

if own_figure:
plt.close(fig)

# Remove alpha
if image_format != 'svg' and has_pillow:
from PIL import Image
output.seek(0)
orig = Image.open(output)
if orig.mode == 'RGBA':
background = Image.new('RGBA', orig.size, (255, 255, 255))
new = Image.alpha_composite(background, orig).convert('RGB')
output = BytesIO()
new.save(output, format=image_format, dpi=(dpi, dpi), **pil_kwargs)

output = output.getvalue()
return (output.decode('utf-8') if image_format == 'svg' else
base64.b64encode(output).decode('ascii'))
Expand Down Expand Up @@ -405,9 +430,8 @@ def _get_bem_contour_figs_as_arrays(
A list of NumPy arrays that represent the generated Matplotlib figures.
"""
# Matplotlib <3.2 doesn't work nicely with process-based parallelization
from matplotlib import __version__ as MPL_VERSION
kwargs = dict()
if _compare_version(MPL_VERSION, '<', '3.2'):
if not check_version('matplotilb', '3.2'):
kwargs['prefer'] = 'threads'

parallel, p_fun, n_jobs = parallel_func(
Expand Down Expand Up @@ -598,6 +622,16 @@ def open_report(fname, **params):
mne_logo_path = Path(__file__).parents[1] / 'icons' / 'mne_icon-cropped.png'
mne_logo = base64.b64encode(mne_logo_path.read_bytes()).decode('ascii')

_ALLOWED_IMAGE_FORMATS = ('png', 'svg', 'webp')


def _webp_supported():
good = check_version('matplotlib', '3.6') and check_version('PIL')
if good:
from PIL import features
good = features.check('webp')
return good


def _check_scale(scale):
"""Ensure valid scale value is passed."""
Expand All @@ -608,10 +642,17 @@ def _check_scale(scale):
def _check_image_format(rep, image_format):
"""Ensure fmt is valid."""
if rep is None or image_format is not None:
_check_option('image_format', image_format,
allowed_values=('png', 'svg', 'gif'))
allowed = list(_ALLOWED_IMAGE_FORMATS) + ['auto']
extra = ''
if not _webp_supported():
allowed.pop(allowed.index('webp'))
extra = '("webp" supported on matplotlib 3.6+ with PIL installed)'
_check_option(
'image_format', image_format, allowed_values=allowed, extra=extra)
else:
image_format = rep.image_format
if image_format == 'auto':
image_format = 'webp' if _webp_supported() else 'png'
return image_format


Expand All @@ -632,13 +673,17 @@ class Report:
Name of the file containing the noise covariance.
%(baseline_report)s
Defaults to ``None``, i.e. no baseline correction.
image_format : 'png' | 'svg' | 'gif'
Default image format to use (default is ``'png'``).
image_format : 'png' | 'svg' | 'webp' | 'auto'
Default image format to use (default is ``'auto'``, which will use
``'webp'`` if available and ``'png'`` otherwise).
``'svg'`` uses vector graphics, so fidelity is higher but can increase
file size and browser image rendering time as well.
``'webp'`` format requires matplotlib >= 3.6.
.. versionadded:: 0.15
.. versionchanged:: 1.3
Added support for ``'webp'`` format, removed support for GIF, and
set the default to ``'auto'``.
raw_psd : bool | dict
If True, include PSD plots for raw files. Can be False (default) to
omit, True to plot, or a dict to pass as ``kwargs`` to
Expand Down Expand Up @@ -669,7 +714,6 @@ class Report:
Default image format to use.
.. versionadded:: 0.15
raw_psd : bool | dict
If True, include PSD plots for raw files. Can be False (default) to
omit, True to plot, or a dict to pass as ``kwargs`` to
Expand Down Expand Up @@ -703,7 +747,7 @@ class Report:
@verbose
def __init__(self, info_fname=None, subjects_dir=None,
subject=None, title=None, cov_fname=None, baseline=None,
image_format='png', raw_psd=False, projs=False, *,
image_format='auto', raw_psd=False, projs=False, *,
verbose=None):
self.info_fname = str(info_fname) if info_fname is not None else None
self.cov_fname = str(cov_fname) if cov_fname is not None else None
Expand Down Expand Up @@ -2039,14 +2083,11 @@ def add_image(
.. versionadded:: 0.24.0
"""
tags = _check_tags(tags)
img_bytes = Path(image).expanduser().read_bytes()
img_base64 = base64.b64encode(img_bytes).decode('ascii')
del img_bytes # Free memory

image = Path(_check_fname(image, overwrite='read', must_exist=True))
img_format = Path(image).suffix.lower()[1:] # omit leading period
_check_option('Image format', value=img_format,
allowed_values=('png', 'gif', 'svg'))

allowed_values=list(_ALLOWED_IMAGE_FORMATS) + ['gif'])
img_base64 = base64.b64encode(image.read_bytes()).decode('ascii')
self._add_image(
img=img_base64,
title=title,
Expand Down Expand Up @@ -2156,6 +2197,7 @@ def _render_slider(
imgs = [_fig_to_img(fig=fig, image_format=image_format,
own_figure=own_figure)
for fig in figs]

dom_id = self._get_dom_id()
html = _html_slider_element(
id=dom_id,
Expand Down
45 changes: 42 additions & 3 deletions mne/report/tests/test_report.py
Expand Up @@ -22,7 +22,9 @@
from mne import (Epochs, read_events, read_evokeds, read_cov,
pick_channels_cov, create_info)
from mne.report import report as report_mod
from mne.report.report import CONTENT_ORDER
from mne.report.report import (
CONTENT_ORDER, _ALLOWED_IMAGE_FORMATS, _webp_supported,
)
from mne.io import read_raw_fif, read_info, RawArray
from mne.datasets import testing
from mne.report import Report, open_report, _ReportScraper, report
Expand Down Expand Up @@ -394,7 +396,7 @@ def test_render_add_sections(renderer, tmp_path):
fig.savefig(img_fname)
report.add_image(image=img_fname, title='evoked response')

with pytest.raises(FileNotFoundError, match='No such file or directory'):
with pytest.raises(FileNotFoundError, match='does not exist'):
report.add_image(image='foobar.xxx', title='H')

evoked = read_evokeds(evoked_fname, condition='Left Auditory',
Expand Down Expand Up @@ -452,7 +454,7 @@ def test_add_bem_n_jobs(n_jobs, monkeypatch):
use_subjects_dir = None
else:
use_subjects_dir = subjects_dir
report = Report(subjects_dir=use_subjects_dir)
report = Report(subjects_dir=use_subjects_dir, image_format='png')
# implicitly test that subjects_dir is correctly preserved here
monkeypatch.setattr(report_mod, '_BEM_VIEWS', ('axial',))
if use_subjects_dir is not None:
Expand Down Expand Up @@ -996,3 +998,40 @@ def test_tags(tags, str_or_array, wrong_dtype, invalid_chars):
r.add_code(code='foo', title='bar', tags=tags)
else:
r.add_code(code='foo', title='bar', tags=tags)


# These are all the ones we claim to support
@pytest.mark.parametrize('image_format', _ALLOWED_IMAGE_FORMATS)
def test_image_format(image_format):
"""Test image format support."""
if image_format == 'webp':
if not _webp_supported():
with pytest.raises(ValueError, match='matplotlib'):
Report(image_format='webp')
return
r = Report(image_format=image_format)
fig1, _ = _get_example_figures()
r.add_figure(fig1, 'fig1')
assert image_format in r.html[0]


def test_gif(tmp_path):
"""Test that GIFs can be embedded using add_image."""
pytest.importorskip('PIL')
from PIL import Image
sequence = [
Image.fromarray(frame.astype(np.uint8))
for frame in _get_example_figures()
]
fname = tmp_path / 'test.gif'
sequence[0].save(str(fname), save_all=True, append_images=sequence[1:])
assert fname.is_file()
with pytest.raises(ValueError, match='Allowed values'):
Report(image_format='gif')
r = Report()
r.add_image(fname, 'fname')
assert 'image/gif' in r.html[0]
bad_name = fname.with_suffix('.foo')
bad_name.write_bytes(b'')
with pytest.raises(ValueError, match='Allowed values'):
r.add_image(bad_name, 'fname')
4 changes: 2 additions & 2 deletions tutorials/intro/70_report.py
Expand Up @@ -101,8 +101,8 @@
# ^^^^^^^^^^^^^
#
# Events can be added via :meth:`mne.Report.add_events`. You also need to
# supply the sampling frequency used during the recording; this information
# is used to generate a meaningful time axis.
# supply the sampling frequency used during the recording; this information is
# used to generate a meaningful time axis.

events_path = sample_dir / 'sample_audvis_filt-0-40_raw-eve.fif'
events = mne.find_events(raw=raw)
Expand Down

0 comments on commit 8f6af80

Please sign in to comment.