Skip to content

Commit

Permalink
Merge pull request #1290 from to266/ENH_subpixel_2Dshifts
Browse files Browse the repository at this point in the history
Add sub-pixel accuracy to 2D alignment
  • Loading branch information
francisco-dlp committed Aug 29, 2018
2 parents eff5099 + bb2321f commit ade6153
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 33 deletions.
107 changes: 78 additions & 29 deletions hyperspy/_signals/signal2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import scipy as sp
import logging
from scipy.fftpack import fftn, ifftn
from skimage.feature.register_translation import _upsampled_dft

from hyperspy.defaults_parser import preferences
from hyperspy.external.progressbar import progressbar
Expand Down Expand Up @@ -96,19 +97,18 @@ def fft_correlation(in1, in2, normalize=False):
size = s1 + s2 - 1
# Use 2**n-sized FFT
fsize = (2 ** np.ceil(np.log2(size))).astype("int")
IN1 = fftn(in1, fsize)
IN1 *= fftn(in2, fsize).conjugate()
fprod = fftn(in1, fsize)
fprod *= fftn(in2, fsize).conjugate()
if normalize is True:
ret = ifftn(np.nan_to_num(IN1 / np.absolute(IN1))).real.copy()
else:
ret = ifftn(IN1).real.copy()
del IN1
return ret
fprod = np.nan_to_num(fprod / np.absolute(fprod))
ret = ifftn(fprod).real.copy()
return ret, fprod


def estimate_image_shift(ref, image, roi=None, sobel=True,
medfilter=True, hanning=True, plot=False,
dtype='float', normalize_corr=False,
sub_pixel_factor=1,
return_maxval=True):
"""Estimate the shift in a image using phase correlation
Expand All @@ -121,6 +121,10 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
Parameters
----------
ref : 2D numpy.ndarray
Reference image
image : 2D numpy.ndarray
Image to register
roi : tuple of ints (top, bottom, left, right)
Define the region of interest
sobel : bool
Expand All @@ -133,16 +137,18 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
If True, plots the images after applying the filters and the phase
correlation. If a figure instance, the images will be plotted to the
given figure.
reference : \'current\' | \'cascade\'
If \'current\' (default) the image at the current
coordinates is taken as reference. If \'cascade\' each image
reference : 'current' | 'cascade'
If 'current' (default) the image at the current
coordinates is taken as reference. If 'cascade' each image
is aligned with the previous one.
dtype : str or dtype
Typecode or data-type in which the calculations must be
performed.
normalize_corr : bool
If True use phase correlation instead of standard correlation
sub_pixel_factor : float
Estimate shifts with a sub-pixel accuracy of 1/sub_pixel_factor parts
of a pixel. Default is 1, i.e. no sub-pixel accuracy.
Returns
-------
Expand All @@ -153,6 +159,7 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
The maximum value of the correlation
"""

ref, image = da.compute(ref, image)
# Make a copy of the images to avoid modifying them
ref = ref.copy().astype(dtype)
Expand All @@ -174,9 +181,8 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
im[:] = sp.signal.medfilt(im)
if sobel is True:
im[:] = sobel_filter(im)

phase_correlation = fft_correlation(ref, image,
normalize=normalize_corr)
phase_correlation, image_product = fft_correlation(
ref, image, normalize=normalize_corr)

# Estimate the shift by getting the coordinates of the maximum
argmax = np.unravel_index(np.argmax(phase_correlation),
Expand All @@ -187,19 +193,46 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
argmax[0] - phase_correlation.shape[0]
shift1 = argmax[1] if argmax[1] < threshold[1] else \
argmax[1] - phase_correlation.shape[1]
max_val = phase_correlation.max()
max_val = phase_correlation.real.max()
shifts = np.array((shift0, shift1))

# The following code is more or less copied from
# skimage.feature.register_feature, to gain access to the maximum value:
if sub_pixel_factor != 1:
# Initial shift estimate in upsampled grid
shifts = np.round(shifts * sub_pixel_factor) / sub_pixel_factor
upsampled_region_size = np.ceil(sub_pixel_factor * 1.5)
# Center of output array at dftshift + 1
dftshift = np.fix(upsampled_region_size / 2.0)
sub_pixel_factor = np.array(sub_pixel_factor, dtype=np.float64)
normalization = (image_product.size * sub_pixel_factor ** 2)
# Matrix multiply DFT around the current shift estimate
sample_region_offset = dftshift - shifts * sub_pixel_factor
correlation = _upsampled_dft(image_product.conj(),
upsampled_region_size,
sub_pixel_factor,
sample_region_offset).conj()
correlation /= normalization
# Locate maximum and map back to original pixel grid
maxima = np.array(np.unravel_index(
np.argmax(np.abs(correlation)),
correlation.shape),
dtype=np.float64)
maxima -= dftshift
shifts = shifts + maxima / sub_pixel_factor
max_val = correlation.real.max()

# Plot on demand
if plot is True or isinstance(plot, plt.Figure):
if isinstance(plot, plt.Figure):
f = plot
fig = plot
axarr = plot.axes
if len(axarr) < 3:
for i in range(3):
f.add_subplot(1, 3, i)
axarr = plot.axes
fig.add_subplot(1, 3, i + 1)
axarr = fig.axes
else:
f, axarr = plt.subplots(1, 3)
fig, axarr = plt.subplots(1, 3)
full_plot = len(axarr[0].images) == 0
if full_plot:
axarr[0].set_title('Reference')
Expand All @@ -217,15 +250,15 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
axarr[1].images[0].set_data(image)
axarr[2].images[0].set_data(np.fft.fftshift(phase_correlation))
# TODO: Renormalize images
f.canvas.draw_idle()
fig.canvas.draw_idle()
# Liberate the memory. It is specially necessary if it is a
# memory map
del ref
del image
if return_maxval:
return -np.array((shift0, shift1)), max_val
return -shifts, max_val
else:
return -np.array((shift0, shift1))
return -shifts


class Signal2D(BaseSignal, CommonSignal2D):
Expand Down Expand Up @@ -301,8 +334,10 @@ def estimate_shift2D(self,
hanning=True,
plot=False,
dtype='float',
show_progressbar=None):
show_progressbar=None,
sub_pixel_factor=1):
"""Estimate the shifts in a image using phase correlation
This method can only estimate the shift by comparing
bidimensional features that should not change position
between frames. To decrease the memory usage, the time of
Expand All @@ -311,6 +346,7 @@ def estimate_shift2D(self,
Parameters
----------
reference : {'current', 'cascade' ,'stat'}
If 'current' (default) the image at the current
coordinates is taken as reference. If 'cascade' each image
Expand Down Expand Up @@ -348,22 +384,28 @@ def estimate_shift2D(self,
show_progressbar : None or bool
If True, display a progress bar. If None the default is set in
`preferences`.
sub_pixel_factor : float
Estimate shifts with a sub-pixel accuracy of 1/sub_pixel_factor
parts of a pixel. Default is 1, i.e. no sub-pixel accuracy.
Returns
-------
list of applied shifts
Notes
-----
The statistical analysis approach to the translation estimation
when using `reference`='stat' roughly follows [1]_ . If you use
it please cite their article.
References
----------
.. [1] Schaffer, Bernhard, Werner Grogger, and Gerald
Kothleitner. “Automated Spatial Drift Correction for EFTEM
Signal2D Series.”
Image Series.”
Ultramicroscopy 102, no. 1 (December 2004): 27–36.
"""
Expand Down Expand Up @@ -402,7 +444,8 @@ def estimate_shift2D(self,
hanning=hanning,
normalize_corr=normalize_corr,
plot=plot,
dtype=dtype)
dtype=dtype,
sub_pixel_factor=sub_pixel_factor)
np.fill_diagonal(pcarray['max_value'], max_value)
pbar_max = nrows * images_number
else:
Expand All @@ -421,7 +464,8 @@ def estimate_shift2D(self,
nshift, max_val = estimate_image_shift(
ref, im, roi=roi, sobel=sobel, medfilter=medfilter,
hanning=hanning, plot=plot,
normalize_corr=normalize_corr, dtype=dtype)
normalize_corr=normalize_corr, dtype=dtype,
sub_pixel_factor=sub_pixel_factor)
if reference == 'cascade':
shift += nshift
ref = im.copy()
Expand All @@ -445,8 +489,8 @@ def estimate_shift2D(self,
hanning=hanning,
normalize_corr=normalize_corr,
plot=plot,
dtype=dtype)

dtype=dtype,
sub_pixel_factor=sub_pixel_factor)
pcarray[i1, i2] = max_value, nshift
del im2
pbar.update(1)
Expand Down Expand Up @@ -489,10 +533,12 @@ def align2D(self, crop=True, fill_value=np.nan, shifts=None, expand=False,
correlation_threshold=None,
chunk_size=30,
interpolation_order=1,
sub_pixel_factor=1,
show_progressbar=None,
parallel=None):
"""Align the images in place using user provided shifts or by
estimating the shifts.
Please, see `estimate_shift2D` docstring for details
on the rest of the parameters not documented in the following
section
Expand Down Expand Up @@ -523,15 +569,17 @@ def align2D(self, crop=True, fill_value=np.nan, shifts=None, expand=False,
Notes
-----
The statistical analysis approach to the translation estimation
when using `reference`='stat' roughly follows [1]_ . If you use
it please cite their article.
References
----------
.. [1] Schaffer, Bernhard, Werner Grogger, and Gerald
Kothleitner. “Automated Spatial Drift Correction for EFTEM
Signal2D Series.”
Image Series.”
Ultramicroscopy 102, no. 1 (December 2004): 27–36.
"""
Expand All @@ -550,6 +598,7 @@ def align2D(self, crop=True, fill_value=np.nan, shifts=None, expand=False,
correlation_threshold=correlation_threshold,
normalize_corr=normalize_corr,
chunk_size=chunk_size,
sub_pixel_factor=sub_pixel_factor,
show_progressbar=show_progressbar)
return_shifts = True
else:
Expand Down
34 changes: 30 additions & 4 deletions hyperspy/tests/signal/test_2D_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,26 @@
# You should have received a copy of the GNU General Public License
# along with HyperSpy. If not, see <http://www.gnu.org/licenses/>.


import sys
from unittest import mock


import numpy.testing as npt
import numpy as np
from scipy.misc import face, ascent
from scipy.ndimage import fourier_shift
import pytest

import hyperspy.api as hs
from hyperspy.decorators import lazifyTestClass


def _generate_parameters():
parameters = []
for normalize_corr in [False, True]:
for reference in ['current', 'cascade', 'stat']:
parameters.append([normalize_corr, reference])
return parameters


@lazifyTestClass
class TestSubPixelAlign:

Expand Down Expand Up @@ -58,7 +64,26 @@ def test_align_subpix(self):
shifts = self.shifts
s.align2D(shifts=shifts)
# Compare by broadcasting
np.testing.assert_allclose(s.data[4], s.data[0], rtol=1)
np.testing.assert_allclose(s.data[4], s.data[0], rtol=0.5)

@pytest.mark.parametrize(("normalize_corr", "reference"),
_generate_parameters())
def test_estimate_subpix(self, normalize_corr, reference):
s = self.signal
shifts = s.estimate_shift2D(sub_pixel_factor=200,
normalize_corr=normalize_corr)
np.testing.assert_allclose(shifts, self.shifts, rtol=0.2, atol=0.2,
verbose=True)

@pytest.mark.parametrize(("plot"), [True, 'reuse'])
def test_estimate_subpix_plot(self, mpl_cleanup, plot):
# To avoid this function plotting many figures and holding the test, we
# make sure the backend is set to `agg` in case it is set to something
# else in the testing environment
import matplotlib.pyplot as plt
plt.switch_backend('agg')
s = self.signal
s.estimate_shift2D(sub_pixel_factor=200, plot=plot)


@lazifyTestClass
Expand Down Expand Up @@ -140,6 +165,7 @@ def test_add_ramp_lazy():
s.add_ramp(-1, -1, -4)
npt.assert_almost_equal(s.data.compute(), 0)


if __name__ == '__main__':
import pytest
pytest.main(__name__)

0 comments on commit ade6153

Please sign in to comment.