Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Feature: Add sub-figure plotting #3343

Open
wants to merge 16 commits into
base: RELEASE_next_minor
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/plotting/README.rst
@@ -0,0 +1,6 @@
.. _subfigure_examples_label:

Making Custom Layouts for Plots
===============================
ericpre marked this conversation as resolved.
Show resolved Hide resolved

Below is a gallery of examples on making simple custom layouts for plotting data.
83 changes: 83 additions & 0 deletions examples/plotting/ROI_insets.py
@@ -0,0 +1,83 @@
"""
==========
ROI Insets
==========

ROI's can be powerful tools to help visualize data. In this case we will define ROI's in hyperspy, sum
the data within the ROI, and then plot the sum as a signal. Using the :class:`matplotlib.figure.SubFigure` class
we can create a custom layout to visualize and interact with the data.

We can connect these ROI's using the :func:`hyperspy.api.interactive` function which allows us to move the ROI's and see the sum of the underlying data.
"""
import matplotlib.pyplot as plt
import hyperspy.api as hs
import numpy as np

rng = np.random.default_rng()

fig = plt.figure(figsize=(5, 3))
gs = fig.add_gridspec(6, 10)
sub1 = fig.add_subfigure(gs[0:6, 0:6])
sub2 = fig.add_subfigure(gs[0:2, 6:8])
sub3 = fig.add_subfigure(gs[2:4, 7:9])
sub4 = fig.add_subfigure(gs[4:6, 6:8])
Comment on lines +19 to +23
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify the grid to make easier to understand (and also when looking at the image, the reason for this layout is not obvious)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericpre that is a good point. I might have to play around with this one slightly and make some fake data. I use something like this to show average diffraction patterns from ROI's in real space.

Here's a gif which is similar to the second example.

ZrCuTiAl-Crystalization



s = hs.signals.Signal2D(rng.random((10, 10, 30, 30)))
r1 = hs.roi.RectangularROI(1, 1, 3, 3)
r2 = hs.roi.RectangularROI(4, 4, 6, 6)
r3 = hs.roi.RectangularROI(3, 7, 5, 9)

navigator = s.sum(axis=(2, 3)).T # create a navigator signal
navigator.plot(fig=sub1, colorbar=False, axes_off=True, title="", plot_indices=False)


s2 = r1.interactive(s, navigation_signal=navigator, color="red")
s3 = r2.interactive(s, navigation_signal=navigator, color="g")
s4 = r3.interactive(s, navigation_signal=navigator, color="y")

s2_int = s2.sum()
s3_int = s3.sum()
s4_int = s4.sum()

s2_int.plot(fig=sub2, colorbar=False, axes_off=True, title="", plot_indices=False)
s3_int.plot(fig=sub3, colorbar=False, axes_off=True, title="", plot_indices=False)
s4_int.plot(fig=sub4, colorbar=False, axes_off=True, title="", plot_indices=False)

# Connect ROIS
for s, s_int, roi in zip([s2, s3, s4], [s2_int, s3_int, s4_int],[r1,r2,r3]):
hs.interactive(
s.sum,
event=roi.events.changed,
recompute_out_event=None,
out=s_int,
)

# Add Borders to the match the color of the ROI

for signal,color, label in zip([s2_int, s3_int, s4_int], ["r", "g", "y"], ["b.", "c.", "d."]):
edge = hs.plot.markers.Squares(
offset_transform="axes",
offsets=(0.5, 0.5),
units="width",
widths=1,
color=color,
linewidth=5,
facecolor="none",
)

signal.add_marker(edge)

label = hs.plot.markers.Texts(
texts=(label,), offsets=[[0.85, 0.85]], offset_transform="axes", sizes=2, color="w"
)
signal.add_marker(label)

# Label the big plot

label = hs.plot.markers.Texts(
texts=("a.",), offsets=[[0.9, 0.9]], offset_transform="axes", sizes=10, color="w"
)
navigator.add_marker(label)

# %%
37 changes: 37 additions & 0 deletions examples/plotting/custom_figure_layout.py
@@ -0,0 +1,37 @@
"""
=======================
Creating Custom Layouts
=======================

Custom layouts for hyperspy figures can be created using the :class:`matplotlib.figure.SubFigure` class. Passing
the ``fig`` argument to the :meth:`~.api.signals.BaseSignal.plot` method of a hyperspy signal object will target
that figure instead of creating a new one. This is useful for creating custom layouts with multiple subplots.
"""

# Creating a simple layout with two subplots

import matplotlib.pyplot as plt
import hyperspy.api as hs
import numpy as np

rng = np.random.default_rng()
s = hs.signals.Signal2D(rng.random((10, 10, 10, 10)))
fig = plt.figure(figsize=(10, 5), layout="constrained")
subfigs = fig.subfigures(1, 2, wspace=0.07)
s.plot(navigator_kwds=dict(fig=subfigs[0]), fig=subfigs[1])

# %%

# Sharing a navigator between two hyperspy signals

s = hs.signals.Signal2D(rng.random((10, 10, 10, 10)))
s2 = hs.signals.Signal2D(rng.random((10, 10, 50, 50)))

fig = plt.figure(figsize=(8, 7), layout="constrained")
head_figures = fig.subfigures(1, 2, wspace=0.07)
signal_figures = head_figures[1].subfigures(2, 1, hspace=0.07)
s.plot(navigator_kwds=dict(fig=head_figures[0], colorbar=None), fig=signal_figures[0])
s2.plot(navigator=None, fig=signal_figures[1], axes_manager=s.axes_manager)

# %%
# sphinx_gallery_thumbnail_number = 2
1 change: 1 addition & 0 deletions hyperspy/conftest.py
Expand Up @@ -54,6 +54,7 @@
hs.preferences.Plot.cmap_navigator = "viridis"
hs.preferences.Plot.cmap_signal = "viridis"
hs.preferences.Plot.pick_tolerance = 5.0
hs.preferences.Plot.use_subfigure = False
# Don't show progressbar since it contains the runtime which
# will make the doctest fail
hs.preferences.General.show_progressbar = False
Expand Down
6 changes: 6 additions & 0 deletions hyperspy/defaults_parser.py
Expand Up @@ -120,6 +120,12 @@ class PlotConfig(t.HasTraits):
widget_plot_style = t.Enum(
["horizontal", "vertical"], label="Widget plot style: (only with ipympl)"
)
use_subfigure = t.CBool(
False,
desc="EXPERIMENTAL. Plot navigator and signal on the same figure. "
"Note that this is slower than using separate figures "
"and it requires matplotlib >=3.9.",
)
cmap_navigator = t.Str(
"gray",
label="Color map navigator",
Expand Down
43 changes: 31 additions & 12 deletions hyperspy/drawing/figure.py
Expand Up @@ -19,6 +19,7 @@
import logging
import textwrap

import matplotlib
import matplotlib.pyplot as plt

from hyperspy.drawing import utils
Expand All @@ -27,7 +28,7 @@
_logger = logging.getLogger(__name__)


class BlittedFigure(object):
class BlittedFigure:
def __init__(self):
self._draw_event_cid = None
self._background = None
Expand All @@ -43,22 +44,30 @@ def __init__(self):
""",
arguments=["obj"],
)
# The matplotlib Figure or SubFigure
# To access the matplotlib figure, use `get_mpl_figure`
self.figure = None
# The matplotlib Axis
self.ax = None
self.title = ""
self.ax_markers = list()

def create_figure(self, **kwargs):
"""Create matplotlib figure
"""
Create matplotlib figure.

Parameters
----------
**kwargs
All keyword arguments are passed to ``plt.figure``.
**kwargs : dict
Keyword arguments are passed to
:func:`hyperspy.drawing.utils.create_figure`.

"""
kwargs.setdefault("_on_figure_window_close", self.close)
self.figure = utils.create_figure(
window_title="Figure " + self.title if self.title else None, **kwargs
window_title="Figure " + self.title if self.title else None,
**kwargs,
)
utils.on_figure_window_close(self.figure, self._on_close)
if self.figure.canvas.supports_blit:
self._draw_event_cid = self.figure.canvas.mpl_connect(
"draw_event", self._on_blit_draw
Expand Down Expand Up @@ -92,6 +101,19 @@ def _update_animated(self):
self._draw_animated()
canvas.blit(self.figure.bbox)

def get_mpl_figure(self):
"""Retuns the matplotlib figure"""
if self.figure is None:
return None
else:
# See https://github.com/matplotlib/matplotlib/pull/28177
figure = self.figure
# matplotlib SubFigure can be nested and we don't support it
if isinstance(figure, matplotlib.figure.SubFigure):
return figure.figure
else:
return figure

def add_marker(self, marker):
marker.ax = self.ax
self.ax_markers.append(marker)
Expand All @@ -106,9 +128,8 @@ def remove_markers(self, render_figure=False):

def _on_close(self):
_logger.debug("Closing `BlittedFigure`.")
if self.figure is None:
_logger.debug("`BlittedFigure` already closed.")
return # Already closed
self.ax = None
self._background = None
for marker in self.ax_markers:
marker.close(render_figure=False)
self.events.closed.trigger(obj=self)
Expand All @@ -117,15 +138,13 @@ def _on_close(self):
if self._draw_event_cid:
self.figure.canvas.mpl_disconnect(self._draw_event_cid)
self._draw_event_cid = None
plt.close(self.figure)
self.figure = None
self.ax = None
self._background = None
_logger.debug("`BlittedFigure` closed.")

def close(self):
_logger.debug("`close` `BlittedFigure` called.")
self._on_close() # Needs to trigger serially for a well defined state
plt.close(self.get_mpl_figure())

@property
def title(self):
Expand Down
9 changes: 7 additions & 2 deletions hyperspy/drawing/image.py
Expand Up @@ -25,6 +25,7 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm, Normalize, PowerNorm, SymLogNorm
from matplotlib.figure import SubFigure
from packaging.version import Version
from rsciio.utils import rgb_tools
from traits.api import Undefined
Expand Down Expand Up @@ -330,7 +331,9 @@ def plot(self, data_function_kwargs={}, **kwargs):
self.data_function_kwargs = data_function_kwargs
self.configure()
if self.figure is None:
self.create_figure()
fig = kwargs.pop("fig", None)
_on_figure_window_close = kwargs.pop("_on_figure_window_close", None)
self.create_figure(fig=fig, _on_figure_window_close=_on_figure_window_close)
self.create_axis()

if not self.axes_manager or self.axes_manager.navigation_size == 0:
Expand Down Expand Up @@ -382,7 +385,7 @@ def _add_colorbar(self):
# Bug extend='min' or extend='both' and power law norm
# Use it when it is fixed in matplotlib
ims = self.ax.images if len(self.ax.images) else self.ax.collections
self._colorbar = plt.colorbar(ims[0], ax=self.ax)
self._colorbar = self.figure.colorbar(ims[0], ax=self.ax)
self.set_quantity_label()
self._colorbar.set_label(self.quantity_label, rotation=-90, va="bottom")
self._colorbar.ax.yaxis.set_animated(self.figure.canvas.supports_blit)
Expand Down Expand Up @@ -569,6 +572,8 @@ def format_coord(x, y):
# `draw_all` is deprecated in matplotlib 3.6.0
if Version(matplotlib.__version__) <= Version("3.6.0"):
self._colorbar.draw_all()
elif isinstance(self.figure, SubFigure):
self.figure.canvas.draw_idle() # draw without rendering not supported for sub-figures
ericpre marked this conversation as resolved.
Show resolved Hide resolved
else:
self.figure.draw_without_rendering()
self._colorbar.solids.set_animated(self.figure.canvas.supports_blit)
Expand Down