Skip to content

Commit

Permalink
Merge pull request #3192 from pietsjoh/pcolormesh
Browse files Browse the repository at this point in the history
Use pcolormesh for image plots involving non-uniform axes
  • Loading branch information
ericpre committed Sep 8, 2023
2 parents b22f2a5 + 89da454 commit 07f8177
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 17 deletions.
32 changes: 21 additions & 11 deletions hyperspy/drawing/image.py
Expand Up @@ -377,7 +377,8 @@ def plot(self, data_function_kwargs={}, **kwargs):
def _add_colorbar(self):
# Bug extend='min' or extend='both' and power law norm
# Use it when it is fixed in matplotlib
self._colorbar = plt.colorbar(self.ax.images[0], ax=self.ax)
ims = self.ax.images if len(self.ax.images) else self.ax.collections
self._colorbar = plt.colorbar(ims[0], ax=self.ax)
self.set_quantity_label()
self._colorbar.set_label(
self.quantity_label, rotation=-90, va='bottom')
Expand Down Expand Up @@ -432,7 +433,8 @@ def update(self, data_changed=True, auto_contrast=None, vmin=None,
data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
data = self._current_data = data
self._is_rgb = True
ims = self.ax.images

ims = self.ax.images if len(self.ax.images) else self.ax.collections

# Turn on centre_colormap if a diverging colormap is used.
if not self._is_rgb and self.centre_colormap == "auto":
Expand Down Expand Up @@ -532,8 +534,11 @@ def format_coord(x, y):
if self.no_nans:
data = np.nan_to_num(data)

if ims: # the images has already been drawn previously
ims[0].set_data(data)
if ims: # the images have already been drawn previously
if len(self.ax.images): # imshow
ims[0].set_data(data)
else: # pcolormesh
ims[0].set_array(data.ravel())
# update extent:
if 'x' in self.autoscale:
self._extent[0] = self.xaxis.axis[0] - self.xaxis.scale / 2
Expand Down Expand Up @@ -563,19 +568,24 @@ def format_coord(x, y):
ims[0].changed()
self.render_figure()
else: # no signal have been drawn yet
new_args = {'extent': self._extent,
'aspect': self._aspect,
'animated': self.figure.canvas.supports_blit,
}
new_args = {"animated": self.figure.canvas.supports_blit}
if not self._is_rgb:
if norm is None:
new_args.update({'vmin': vmin, 'vmax':vmax})
else:
new_args['norm'] = norm
new_args.update(kwargs)
self.ax.imshow(data, **new_args)

if self.axes_ticks == 'off':
if self.xaxis.is_uniform and self.yaxis.is_uniform:
# pcolormesh doesn't have extent and aspect as arguments
# aspect is set earlier via self.ax.set_aspect() anyways
new_args.update({"extent": self._extent, "aspect": self._aspect})
self.ax.imshow(data, **new_args)
else:
self.ax.pcolormesh(
self.xaxis.axis, self.yaxis.axis, data, **new_args
)
self.ax.invert_yaxis()
if self.axes_ticks == "off":
self.ax.set_axis_off()

def _update(self):
Expand Down
6 changes: 1 addition & 5 deletions hyperspy/signal.py
Expand Up @@ -2852,11 +2852,7 @@ def get_dynamic_explorer_wrapper(*args, **kwargs):
navigator = "slider"
elif (self.axes_manager.navigation_dimension == 1 and
self.axes_manager.signal_dimension == 1):
if (self.axes_manager.navigation_axes[0].is_uniform and
self.axes_manager.signal_axes[0].is_uniform):
navigator = "data"
else:
navigator = "spectrum"
navigator = "data"
elif self.axes_manager.navigation_dimension > 0:
if self.axes_manager.signal_dimension == 0:
navigator = self.deepcopy()
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
30 changes: 29 additions & 1 deletion hyperspy/tests/drawing/test_plot_signal1d.py
Expand Up @@ -211,7 +211,7 @@ def test_plot_spectra_auto_update(self):
return ax.get_figure()


class TestPlotNonLinearAxis:
class TestPlotNonUniformAxis:

def setup_method(self, method):
dict0 = {'size': 10, 'name': 'Axis0', 'units': 'A', 'scale': 0.2,
Expand All @@ -229,6 +229,34 @@ def test_plot_non_uniform_sig(self):
self.s.plot()
return self.s._plot.signal_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_non_uniform_sig_update(self):
s2 = self.s
s2.plot()
s2.axes_manager[0].index += 1
return s2._plot.signal_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_uniform_nav(self):
self.s.plot()
return self.s._plot.navigator_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_uniform_nav_update(self):
s2 = self.s
s2.plot()
s2.axes_manager[0].index += 1
return self.s._plot.navigator_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_uniform_sig(self):
self.s.plot()
return self.s._plot.signal_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_non_uniform_nav(self):
Expand Down
22 changes: 22 additions & 0 deletions hyperspy/tests/drawing/test_plot_signal2d.py
Expand Up @@ -259,6 +259,28 @@ def test_plot_non_uniform_nav(self):
self.s.plot()
return self.s._plot.navigator_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_non_uniform_2s1n_sig(self):
self.s.plot()
return self.s._plot.signal_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_non_uniform_2s1n_update_sig(self):
s2 = self.s
s2.axes_manager[0].index += 1
s2.plot()
return s2._plot.signal_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_non_uniform_2s1n_update_nav(self):
s2 = self.s
s2.axes_manager[0].index += 1
s2.plot()
return s2._plot.navigator_plot.figure

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
tolerance=default_tol, style=style_pytest_mpl)
def test_plot_non_uniform_sig(self):
Expand Down
3 changes: 3 additions & 0 deletions upcoming_changes/3192.new.rst
@@ -0,0 +1,3 @@
Switch to pcolormesh for image plots involving non-uniform axes.
The following cases are covered: 2D-signal with arbitrary navigation-dimension, 1D-navigation and 1D-signal (linescan).
Not covered are 2D-navigation images (still uses sliders).

0 comments on commit 07f8177

Please sign in to comment.