Skip to content

Commit

Permalink
Merge pull request #2931 from ericpre/fix_overlay_plot_image
Browse files Browse the repository at this point in the history
Fix overlay plot image
  • Loading branch information
ericpre committed Apr 25, 2022
2 parents c1428da + 4bda80f commit 9394850
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 64 deletions.
3 changes: 3 additions & 0 deletions conda_environment_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ dependencies:
# We pin freetype and matplotlib for the image comparison
- freetype=2.10
- matplotlib-base=3.5.1
# temporary pinning until the conda-forge pinning is fixed
# tifffile 2022.2.2 does support for python 3.7
- tifffile <2022.4.22
- cython
- pytest
- pytest-mpl
Expand Down
153 changes: 89 additions & 64 deletions hyperspy/drawing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# along with HyperSpy. If not, see <https://www.gnu.org/licenses/#GPL>.

import copy
from functools import partial
import itertools
import logging
from packaging.version import Version
Expand All @@ -25,14 +26,13 @@

import dask.array as da
import traits.api as t
import matplotlib.pyplot as plt
import numpy as np

import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.backend_bases import key_press_handler
import numpy as np
from functools import partial
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import BASE_COLORS, to_rgba
from matplotlib.colors import LinearSegmentedColormap, BASE_COLORS, to_rgba
import matplotlib.pyplot as plt

from hyperspy.defaults_parser import preferences
from hyperspy.misc.utils import to_numpy
Expand Down Expand Up @@ -429,12 +429,11 @@ def _make_cascade_subplot(spectra, ax, color, linestyle, padding=1, **kwargs):
np.nanmin(spectrum.data))
if spectrum_yrange > max_value:
max_value = spectrum_yrange
for spectrum_index, (spectrum, color, linestyle) in enumerate(
for i, (spectrum, color, linestyle) in enumerate(
zip(spectra, color, linestyle)):
x_axis = spectrum.axes_manager.signal_axes[0]
spectrum = _transpose_if_required(spectrum, 1)
data_to_plot = to_numpy((spectrum.data - spectrum.data.min()) /
float(max_value) + spectrum_index * padding)
data = _parse_array(_transpose_if_required(spectrum, 1))
data_to_plot = (data - data.min()) / float(max_value) + i * padding
ax.plot(x_axis.axis, data_to_plot, color=color, ls=linestyle,
**kwargs)
set_xaxis_lims(ax, x_axis)
Expand Down Expand Up @@ -482,7 +481,7 @@ def plot_images(images,
labelwrap=30,
suptitle=None,
suptitle_fontsize=18,
colorbar='multi',
colorbar='default',
centre_colormap='auto',
scalebar=None,
scalebar_color='white',
Expand All @@ -500,7 +499,7 @@ def plot_images(images,
alphas=1.0,
legend_picking=True,
legend_loc='upper right',
pixel_size_factor=1,
pixel_size_factor=None,
**kwargs):
"""Plot multiple images either as sub-images or overlayed in one figure.
Expand Down Expand Up @@ -551,13 +550,15 @@ def plot_images(images,
this parameter will override the automatically determined title.
suptitle_fontsize : int, optional
Font size to use for super title at top of figure.
colorbar : 'multi', None, 'single', optional
Controls the type of colorbars that are plotted.
If None, no colorbar is plotted.
If 'multi' (default), individual colorbars are plotted for each
(non-RGB) image
colorbar : 'default', 'multi', 'single', None, optional
Controls the type of colorbars that are plotted, incompatible with
``overlay=True``.
If 'default', same as 'multi' when ``overlay=False``, otherwise same
as ``None``.
If 'multi', individual colorbars are plotted for each (non-RGB) image.
If 'single', all (non-RGB) images are plotted on the same scale,
and one colorbar is shown for all
and one colorbar is shown for all.
If None, no colorbar is plotted.
centre_colormap : 'auto', True, False, optional
If True, the centre of the color scheme is set to zero. This is
particularly useful when using diverging color schemes. If 'auto'
Expand Down Expand Up @@ -603,7 +604,7 @@ def plot_images(images,
values will require more overlap in titles before activing the
auto-label code.
fig : mpl figure, optional
If set, the images will be plotted to an existing MPL figure
If set, the images will be plotted to an existing matplotlib figure.
vmin, vmax: scalar, str, None
If str, formatted as 'xth', use this value to calculate the percentage
of pixels that are left out of the lower and upper bounds.
Expand All @@ -625,15 +626,17 @@ def plot_images(images,
Float value or a list of floats corresponding to the alpha value of
each color.
legend_picking: bool, optional
If True (default), a spectrum can be toggled on and off by clicking on
the legended line.
If True (default), an image can be toggled on and off by clicking on
the legended line. For ``overlay=True`` only.
legend_loc : str, int, optional
This parameter controls where the legend is placed on the figure
see the :py:func:`matplotlib.pyplot.legend` docstring for valid values
pixel_size_factor : int or float, optional
Default value is 1. Sets the size of the figure when plotting an overlay image. The higher
the number the larger the figure and therefore a greater number of
pixels are used. This value will be ignored if a Figure is provided.
pixel_size_factor : None, int or float, optional
If ``None`` (default), the size of the figure is taken from the
matplotlib ``rcParams``. Otherwise sets the size of the figure when
plotting an overlay image. The higher the number the larger the figure
and therefore a greater number of pixels are used. This value will be
ignored if a Figure is provided.
**kwargs, optional
Additional keyword arguments passed to :py:func:`matplotlib.pyplot.imshow`.
Expand Down Expand Up @@ -704,6 +707,15 @@ def __check_single_colorbar(cbar):
if sig.axes_manager.navigation_size > 0
else 1)

# Check compatibility of colorbar and overlay arguments
if overlay and colorbar != 'default':
_logger.info(f"`colorbar='{colorbar}'` is incompatible with "
"`overlay=True`. Colorbar is disable.")
colorbar = None
# Setting the default value
elif colorbar == 'default':
colorbar = 'multi'

# If no cmap given, get default colormap from pyplot:
if cmap is None:
cmap = [preferences.Plot.cmap_signal]
Expand Down Expand Up @@ -841,22 +853,52 @@ def __check_single_colorbar(cbar):
else:
raise ValueError("Did not understand input of labels.")

# Start of non-overlay?
# Check if we need to add a scalebar for some of the images
if isinstance(scalebar, (list, tuple)) and all(isinstance(x, int)
for x in scalebar):
scalelist = True
else:
scalelist = False

if scalebar not in [None, False, 'all'] and scalelist is False:
raise ValueError("Did not understand scalebar input. Must be None, "
"'all', or list of ints.")

# Determine appropriate number of images per row
rows = int(np.ceil(n / float(per_row)))
if n < per_row:
per_row = n
if overlay:
# only a single image
per_row = rows = 1
else:
rows = int(np.ceil(n / float(per_row)))
if n < per_row:
per_row = n

# Set overall figure size and define figure (if not pre-existing)
if fig is None:
k = max(plt.rcParams['figure.figsize']) / max(per_row, rows)
if overlay:
shape = images[0].data.shape
dpi = 100
f = plt.figure(figsize=[pixel_size_factor*v/dpi for v in shape],
dpi=dpi)
w, h = plt.rcParams['figure.figsize']
dpi = plt.rcParams['figure.dpi']
if overlay and axes_decor == 'off':
shape = images[0].axes_manager.signal_shape
if pixel_size_factor is None:
# Cap the maximum dimension of figure to
# plt.rcParams['figure.figsize']
aspect_ratio = shape[0] / shape[1]
if aspect_ratio >= w / h:
if label is not None and w / aspect_ratio < 1.0:
# Needs enough space for the labels
w = 1.0 * aspect_ratio
figsize = (w, w / aspect_ratio)
else:
if scalebar is not None and h * aspect_ratio < 2.0:
# Needs enough width for the scalebar
h = 2.0 / aspect_ratio
figsize = (h * aspect_ratio, h)
else:
figsize = [pixel_size_factor*v/dpi for v in shape]
else:
f = plt.figure(figsize=(tuple(k * i for i in (per_row, rows))))
k = max(w, h) / max(per_row, rows)
figsize=[k * i for i in (per_row, rows)]
f = plt.figure(figsize=figsize, dpi=dpi)
else:
f = fig

Expand All @@ -878,12 +920,7 @@ def __check_single_colorbar(cbar):
colorbar = None
warnings.warn("Sorry, colorbar is not implemented for RGB images.")

# Check if we need to add a scalebar for some of the images
if isinstance(scalebar, list) and all(isinstance(x, int)
for x in scalebar):
scalelist = True
else:
scalelist = False


def check_list_length(arg, arg_name):
if isinstance(arg, (list, tuple)):
Expand Down Expand Up @@ -934,16 +971,16 @@ def transparent_single_color_cmap(color):
images[0].axes_manager[0].scale):
raise ValueError("Images are not the same scale and so should"
"not be overlayed.")

_logger.warning('vmin is ignored when overlaying images.')

if vmin is not None:
_logger.warning('`vmin` is ignored when overlaying images.')

import matplotlib.patches as mpatches
factor = plt.rcParams['font.size'] / 100
if not suptitle and axes_decor == 'off':
ax = f.add_axes([0, 0, 1, 1])
elif not suptitle:
ax = f.add_axes([0.1, 0.1, 1, 1])
else:
ax = f.add_axes([0.1, 0.1, 0.9, 0.8])
ax = f.add_subplot()
patches = []

#If no colors are selected use BASE_COLORS
Expand Down Expand Up @@ -988,12 +1025,12 @@ def transparent_single_color_cmap(color):

if label is not None:
plt.legend(handles=patches, loc=legend_loc)
if legend_picking is True:
if legend_picking:
animate_legend(fig=f, ax=ax, plot_type='images')

set_axes_decor(ax, axes_decor)

if scalebar=='all':
if scalebar == 'all':
axes = im.axes_manager.signal_axes
ax.scalebar = ScaleBar(
ax=ax,
Expand All @@ -1003,7 +1040,7 @@ def transparent_single_color_cmap(color):
axes_list.append(ax)

#Below is for non-overlayed images
if not overlay:
else:
# Loop through each image, adding subplot for each one
for i, ims in enumerate(images):
# Get handles for the signal axes and axes_manager
Expand Down Expand Up @@ -1159,20 +1196,6 @@ def transparent_single_color_cmap(color):
f.subplots_adjust(top=0.85)
f.suptitle(suptitle, fontsize=suptitle_fontsize)

# If we want to plot scalebars, loop through the list of axes and add them
if scalebar is None or scalebar is False:
# Do nothing if no scalebars are called for
pass
elif scalebar == 'all':
# scalebars were taken care of in the plotting loop
pass
elif scalelist:
# scalebars were taken care of in the plotting loop
pass
else:
raise ValueError("Did not understand scalebar input. Must be None, "
"'all', or list of ints.")

# Adjust subplot spacing according to user's specification
if padding is not None:
plt.subplots_adjust(**padding)
Expand Down Expand Up @@ -1526,7 +1549,8 @@ def update_line(spectrum, line):
def animate_legend(fig=None, ax=None, plot_type='spectra'):
"""Animate the legend of a figure.
A spectrum can be toggled on and off by clicking on the line in the legend.
A spectrum or image can be toggled on and off by clicking on the line in
the legend.
Parameters
----------
Expand Down Expand Up @@ -1664,3 +1688,4 @@ def picker_kwargs(value, kwargs={}):
kwargs['picker'] = value

return kwargs

Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 76 additions & 0 deletions hyperspy/tests/drawing/test_plot_signal2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# You should have received a copy of the GNU General Public License
# along with HyperSpy. If not, see <https://www.gnu.org/licenses/#GPL>.

import logging

import matplotlib.pyplot as plt
import numpy as np
import pytest
Expand Down Expand Up @@ -589,6 +591,7 @@ def test_plot_autoscale_data_changed(autoscale):
np.testing.assert_allclose(imf._vmin, _vmin)
np.testing.assert_allclose(imf._vmax, _vmax)


@pytest.mark.parametrize("axes_decor", ['all', 'off'])
@pytest.mark.parametrize("label", ['auto', ['b','g']])
@pytest.mark.parametrize("colors", ['auto', ['b','g']])
Expand Down Expand Up @@ -616,3 +619,76 @@ def test_plot_scale_different_sign():
s2.plot()
assert s2._plot.signal_plot.pixel_units is not None
assert s2._plot.signal_plot.scalebar is True


def test_plot_images_overlay_colorbar():
s = hs.signals.Signal2D(np.arange(100).reshape(10, 10))
hs.plot.plot_images([s, s], overlay=True, colorbar='single',
axes_decor='off')


def test_plot_images_overlay_aspect_ratio():
s = hs.signals.Signal2D(np.arange(100).reshape(2, 50))
hs.plot.plot_images([s, s], overlay=True, axes_decor='off')
f = plt.gcf()
np.testing.assert_allclose((f.get_figwidth(), f.get_figheight()), (25.0, 1.0))

s = hs.signals.Signal2D(np.arange(100).reshape(20, 5))
hs.plot.plot_images([s, s], overlay=True, axes_decor='off', scalebar='all')
f = plt.gcf()
np.testing.assert_allclose((f.get_figwidth(), f.get_figheight()), (2.0, 8.0))


def test_plot_images_overlay_figsize():
"""Test figure size for different aspect ratio of image."""
# Set reference figure size
plt.rcParams['figure.figsize'] = [6.4, 4.8]

# aspect_ratio is 1
s = hs.signals.Signal2D(np.random.random((10, 10)))
hs.plot.plot_images([s, s], overlay=True, scalebar='all', axes_decor='off')
f = plt.gcf()
np.testing.assert_allclose((f.get_figwidth(), f.get_figheight()), (4.8, 4.8))

# aspect_ratio is 64 / 48
s = hs.signals.Signal2D(np.random.random((48, 64)))
hs.plot.plot_images([s, s], overlay=True, scalebar='all', axes_decor='off')
f = plt.gcf()
np.testing.assert_allclose((f.get_figwidth(), f.get_figheight()), (6.4, 4.8))

# aspect_ratio is 2
s = hs.signals.Signal2D(np.random.random((10, 20)))
hs.plot.plot_images([s, s], overlay=True, scalebar='all', axes_decor='off')
f = plt.gcf()
np.testing.assert_allclose((f.get_figwidth(), f.get_figheight()), (6.4, 3.2))

# aspect_ratio is 0.5
s = hs.signals.Signal2D(np.random.random((20, 10)))
hs.plot.plot_images([s, s], overlay=True, scalebar='all', axes_decor='off')
f = plt.gcf()
np.testing.assert_allclose((f.get_figwidth(), f.get_figheight()), (2.4, 4.8))


def test_plot_images_overlay_vmin_warning(caplog):
s = hs.signals.Signal2D(np.arange(100).reshape(10, 10))
with caplog.at_level(logging.WARNING):
hs.plot.plot_images([s, s], overlay=True, vmin=0)

assert "`vmin` is ignored when overlaying images." in caplog.text


def test_plot_scalebar_error():
s = hs.signals.Signal2D(np.arange(100).reshape(10, 10))
with pytest.raises(ValueError):
hs.plot.plot_images([s, s], scalebar='unsupported_argument')


def test_plot_scalebar_list():
s = hs.signals.Signal2D(np.arange(100).reshape(10, 10))
ax0, ax1 = hs.plot.plot_images([s, s], scalebar=[0, 1])
assert hasattr(ax0, 'scalebar')
assert hasattr(ax1, 'scalebar')

ax0, ax1 = hs.plot.plot_images([s, s], scalebar=[0])
assert hasattr(ax0, 'scalebar')
assert not hasattr(ax1, 'scalebar')

0 comments on commit 9394850

Please sign in to comment.