Skip to content

Commit

Permalink
Merge 6cb48fd into aa30e9d
Browse files Browse the repository at this point in the history
  • Loading branch information
vidartf committed Aug 5, 2016
2 parents aa30e9d + 6cb48fd commit e080508
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 29 deletions.
119 changes: 90 additions & 29 deletions hyperspy/_signals/signal2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,20 @@
import numpy as np
import numpy.ma as ma
import scipy as sp
import warnings
import logging
from scipy.fftpack import fftn, ifftn
from skimage.feature.register_translation import _upsampled_dft

from hyperspy.defaults_parser import preferences
from hyperspy.external.progressbar import progressbar
from hyperspy.misc.math_tools import symmetrize, antisymmetrize
from hyperspy.signal import BaseSignal
from hyperspy._signals.common_signal2d import CommonSignal2D
from hyperspy.docstrings.plot import BASE_PLOT_DOCSTRING, PLOT2D_DOCSTRING, KWARGS_DOCSTRING
from hyperspy.docstrings.plot import (
BASE_PLOT_DOCSTRING, PLOT2D_DOCSTRING, KWARGS_DOCSTRING)


_logger = logging.getLogger(__name__)


def shift_image(im, shift, interpolation_order=1, fill_value=np.nan):
Expand Down Expand Up @@ -87,19 +92,19 @@ def fft_correlation(in1, in2, normalize=False):
size = s1 + s2 - 1
# Use 2**n-sized FFT
fsize = 2 ** np.ceil(np.log2(size))
IN1 = fftn(in1, fsize)
IN1 *= fftn(in2, fsize).conjugate()
fprod = fftn(in1, fsize)
fprod *= fftn(in2, fsize).conjugate()
if normalize is True:
ret = ifftn(np.nan_to_num(IN1 / np.absolute(IN1))).real.copy()
ret = ifftn(np.nan_to_num(fprod / np.absolute(fprod))).real.copy()
else:
ret = ifftn(IN1).real.copy()
del IN1
return ret
ret = ifftn(fprod).real.copy()
return ret, fprod


def estimate_image_shift(ref, image, roi=None, sobel=True,
medfilter=True, hanning=True, plot=False,
dtype='float', normalize_corr=False,):
dtype='float', normalize_corr=False,
sub_pixel_factor=1):
"""Estimate the shift in a image using phase correlation
This method can only estimate the shift by comparing
Expand All @@ -111,6 +116,10 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
Parameters
----------
ref : 2D numpy.ndarray
Reference image
image : 2D numpy.ndarray
Image to register
roi : tuple of ints (top, bottom, left, right)
Define the region of interest
sobel : bool
Expand All @@ -122,16 +131,14 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
plot : bool
If True plots the images after applying the filters and
the phase correlation
reference : \'current\' | \'cascade\'
If \'current\' (default) the image at the current
coordinates is taken as reference. If \'cascade\' each image
is aligned with the previous one.
dtype : str or dtype
Typecode or data-type in which the calculations must be
performed.
normalize_corr : bool
If True use phase correlation instead of standard correlation
sub_pixel_factor : float
Estimate shifts with a sub-pixel accuracy of 1/sub_pixel_factor parts
of a pixel. Default is 1, i.e. no sub-pixel accuracy.
Returns
-------
Expand All @@ -142,6 +149,7 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
The maximum value of the correlation
"""

# Make a copy of the images to avoid modifying them
ref = ref.copy().astype(dtype)
image = image.copy().astype(dtype)
Expand All @@ -162,9 +170,8 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
im[:] = sp.signal.medfilt(im)
if sobel is True:
im[:] = sobel_filter(im)

phase_correlation = fft_correlation(ref, image,
normalize=normalize_corr)
phase_correlation, image_product = fft_correlation(
ref, image, normalize=normalize_corr)

# Estimate the shift by getting the coordinates of the maximum
argmax = np.unravel_index(np.argmax(phase_correlation),
Expand All @@ -176,6 +183,33 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
shift1 = argmax[1] if argmax[1] < threshold[1] else \
argmax[1] - phase_correlation.shape[1]
max_val = phase_correlation.max()
shifts = np.array((shift0, shift1))

# The following code is more or less copied from
# skimage.feature.register_feature, to gain access to the maximum value:
if sub_pixel_factor != 1:
# Initial shift estimate in upsampled grid
shifts = np.round(shifts * sub_pixel_factor) / sub_pixel_factor
upsampled_region_size = np.ceil(sub_pixel_factor * 1.5)
# Center of output array at dftshift + 1
dftshift = np.fix(upsampled_region_size / 2.0)
sub_pixel_factor = np.array(sub_pixel_factor, dtype=np.float64)
normalization = (image_product.size * sub_pixel_factor ** 2)
# Matrix multiply DFT around the current shift estimate
sample_region_offset = dftshift - shifts*sub_pixel_factor
cross_correlation = _upsampled_dft(image_product.conj(),
upsampled_region_size,
sub_pixel_factor,
sample_region_offset).conj()
cross_correlation /= normalization
# Locate maximum and map back to original pixel grid
maxima = np.array(np.unravel_index(
np.argmax(np.abs(cross_correlation)),
cross_correlation.shape),
dtype=np.float64)
maxima -= dftshift
shifts = shifts + maxima / sub_pixel_factor
max_val = cross_correlation.max()

# Plot on demand
if plot is True:
Expand All @@ -184,15 +218,15 @@ def estimate_image_shift(ref, image, roi=None, sobel=True,
axarr[1].imshow(image)
axarr[2].imshow(phase_correlation)
axarr[0].set_title('Reference')
axarr[1].set_title('Signal2D')
axarr[1].set_title('Signal')
axarr[2].set_title('Phase correlation')
plt.show()
# Liberate the memory. It is specially necessary if it is a
# memory map
del ref
del image

return -np.array((shift0, shift1)), max_val
return -shifts, max_val


class Signal2D(BaseSignal, CommonSignal2D):
Expand Down Expand Up @@ -242,8 +276,8 @@ def create_model(self, dictionary=None):
Parameters
__________
dictionary : {None, dict}, optional
A dictionary to be used to recreate a model. Usually generated using
:meth:`hyperspy.model.as_dictionary`
A dictionary to be used to recreate a model. Usually generated
using :meth:`hyperspy.model.as_dictionary`
Returns
-------
A Model class
Expand All @@ -262,15 +296,19 @@ def estimate_shift2D(self,
hanning=True,
plot=False,
dtype='float',
show_progressbar=None):
show_progressbar=None,
sub_pixel_factor=1):
"""Estimate the shifts in a image using phase correlation
This method can only estimate the shift by comparing
bidimensional features that should not change position
between frames. To decrease the memory usage, the time of
computation and the accuracy of the results it is convenient
to select a region of interest by setting the roi keyword.
Parameters
----------
reference : {'current', 'cascade' ,'stat'}
If 'current' (default) the image at the current
coordinates is taken as reference. If 'cascade' each image
Expand Down Expand Up @@ -306,20 +344,30 @@ def estimate_shift2D(self,
show_progressbar : None or bool
If True, display a progress bar. If None the default is set in
`preferences`.
sub_pixel_factor : float
Estimate shifts with a sub-pixel accuracy of 1/sub_pixel_factor
parts of a pixel. Default is 1, i.e. no sub-pixel accuracy.
Returns
-------
list of applied shifts
Notes
-----
The statistical analysis approach to the translation estimation
when using `reference`='stat' roughly follows [1]_ . If you use
it please cite their article.
References
----------
.. [1] Schaffer, Bernhard, Werner Grogger, and Gerald
Kothleitner. “Automated Spatial Drift Correction for EFTEM
Signal2D Series.”
Image Series.”
Ultramicroscopy 102, no. 1 (December 2004): 27–36.
"""
if show_progressbar is None:
show_progressbar = preferences.General.show_progressbar
Expand Down Expand Up @@ -353,7 +401,8 @@ def estimate_shift2D(self,
hanning=hanning,
normalize_corr=normalize_corr,
plot=plot,
dtype=dtype)
dtype=dtype,
sub_pixel_factor=sub_pixel_factor)
np.fill_diagonal(pcarray['max_value'], max_value)
pbar_max = nrows * images_number
else:
Expand All @@ -372,7 +421,8 @@ def estimate_shift2D(self,
nshift, max_val = estimate_image_shift(
ref, im, roi=roi, sobel=sobel, medfilter=medfilter,
hanning=hanning, plot=plot,
normalize_corr=normalize_corr, dtype=dtype)
normalize_corr=normalize_corr, dtype=dtype,
sub_pixel_factor=sub_pixel_factor)
if reference == 'cascade':
shift += nshift
ref = im.copy()
Expand All @@ -396,7 +446,8 @@ def estimate_shift2D(self,
hanning=hanning,
normalize_corr=normalize_corr,
plot=plot,
dtype=dtype)
dtype=dtype,
sub_pixel_factor=sub_pixel_factor)

pcarray[i1, i2] = max_value, nshift
del im2
Expand Down Expand Up @@ -439,12 +490,15 @@ def align2D(self, crop=True, fill_value=np.nan, shifts=None, expand=False,
dtype='float',
correlation_threshold=None,
chunk_size=30,
interpolation_order=1):
interpolation_order=1,
sub_pixel_factor=1):
"""Align the images in place using user provided shifts or by
estimating the shifts.
Please, see `estimate_shift2D` docstring for details
on the rest of the parameters not documented in the following
section
Parameters
----------
crop : bool
Expand All @@ -462,21 +516,27 @@ def align2D(self, crop=True, fill_value=np.nan, shifts=None, expand=False,
interpolation_order: int, default 1.
The order of the spline interpolation. Default is 1, linear
interpolation.
Returns
-------
shifts : np.array
The shifts are returned only if `shifts` is None
Notes
-----
The statistical analysis approach to the translation estimation
when using `reference`='stat' roughly follows [1]_ . If you use
it please cite their article.
References
----------
.. [1] Schaffer, Bernhard, Werner Grogger, and Gerald
Kothleitner. “Automated Spatial Drift Correction for EFTEM
Signal2D Series.”
Image Series.”
Ultramicroscopy 102, no. 1 (December 2004): 27–36.
"""
self._check_signal_dimension_equals_two()
if shifts is None:
Expand All @@ -490,7 +550,8 @@ def align2D(self, crop=True, fill_value=np.nan, shifts=None, expand=False,
dtype=dtype,
correlation_threshold=correlation_threshold,
normalize_corr=normalize_corr,
chunk_size=chunk_size)
chunk_size=chunk_size,
sub_pixel_factor=sub_pixel_factor)
return_shifts = True
else:
return_shifts = False
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
install_req = ['scipy',
'ipython>=2.0',
'matplotlib>=1.2',
'scikit-image',
'numpy>=1.10',
'traits>=4.5.0',
'traitsui>=5.0',
Expand Down

0 comments on commit e080508

Please sign in to comment.