Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Updating plot in estimate_image_shift #753

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 43 additions & 15 deletions hyperspy/_signals/signal2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import numpy.ma as ma
import scipy as sp
import warnings
import logging
from scipy.fftpack import fftn, ifftn

from hyperspy.defaults_parser import preferences
Expand All @@ -31,6 +31,9 @@
from hyperspy.docstrings.plot import BASE_PLOT_DOCSTRING, PLOT2D_DOCSTRING, KWARGS_DOCSTRING


_logger = logging.getLogger(__name__)


def shift_image(im, shift, interpolation_order=1, fill_value=np.nan):
fractional, integral = np.modf(shift)
if fractional.any():
Expand Down Expand Up @@ -119,9 +122,10 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
apply a median filter for noise reduction
hanning : bool
Apply a 2d hanning filter
plot : bool
If True plots the images after applying the filters and
the phase correlation
plot : bool | matplotlib.Figure
If True, plots the images after applying the filters and the phase
correlation. If a figure instance, the images will be plotted to the
given figure.
reference : \'current\' | \'cascade\'
If \'current\' (default) the image at the current
coordinates is taken as reference. If \'cascade\' each image
Expand Down Expand Up @@ -178,15 +182,34 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
max_val = phase_correlation.max()

# Plot on demand
if plot is True:
f, axarr = plt.subplots(1, 3)
axarr[0].imshow(ref)
axarr[1].imshow(image)
axarr[2].imshow(phase_correlation)
axarr[0].set_title('Reference')
axarr[1].set_title('Signal2D')
axarr[2].set_title('Phase correlation')
plt.show()
if plot is True or isinstance(plot, plt.Figure):
if isinstance(plot, plt.Figure):
f = plot
axarr = plot.axes
if len(axarr) < 3:
for i in range(3):
f.add_subplot(1, 3, i)
axarr = plot.axes
else:
f, axarr = plt.subplots(1, 3)
full_plot = len(axarr[0].images) == 0
if full_plot:
axarr[0].set_title('Reference')
axarr[1].set_title('Image')
axarr[2].set_title('Phase correlation')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The title should be "correlation" when normalize_corr is False.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is simply the existing behavior, and as such not really part of this PR :)

axarr[0].imshow(ref)
axarr[1].imshow(image)
d = (np.array(phase_correlation.shape) - 1) // 2
extent = [-d[1], d[1], -d[0], d[0]]
axarr[2].imshow(np.fft.fftshift(phase_correlation),
extent=extent)
plt.show()
else:
axarr[0].images[0].set_data(ref)
axarr[1].images[0].set_data(image)
axarr[2].images[0].set_data(np.fft.fftshift(phase_correlation))
# TODO: Renormalize images
f.canvas.draw()
# Liberate the memory. It is specially necessary if it is a
# memory map
del ref
Expand Down Expand Up @@ -297,9 +320,11 @@ def estimate_shift2D(self,
apply a median filter for noise reduction
hanning : bool
Apply a 2d hanning filter
plot : bool
plot : bool or "reuse"
If True plots the images after applying the filters and
the phase correlation
the phase correlation. If 'reuse', it will also plot the images,
but it will only use one figure, and continously update the images
in that figure as it progresses through the stack.
dtype : str or dtype
Typecode or data-type in which the calculations must be
performed.
Expand Down Expand Up @@ -336,6 +361,9 @@ def estimate_shift2D(self,
shifts = []
nrows = None
images_number = self.axes_manager._max_index + 1
if plot == 'reuse':
# Reuse figure for plots
plot = plt.figure()
if reference == 'stat':
nrows = images_number if chunk_size is None else \
min(images_number, chunk_size)
Expand Down