Skip to content

Commit

Permalink
Adds kwarg xcorr to fftconvolve.
Browse files Browse the repository at this point in the history
Addresses issue scipy#11020 by providing a way to provide an output
specific to cross-correlation, containing both correlation values
and the lags used to determine time offset.
  • Loading branch information
benjaminr committed Nov 16, 2019
1 parent 89286b9 commit 061e059
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
17 changes: 16 additions & 1 deletion scipy/signal/signaltools.py
Expand Up @@ -429,7 +429,7 @@ def _apply_conv_mode(ret, s1, s2, mode, axes):
" 'same', or 'full'")


def fftconvolve(in1, in2, mode="full", axes=None):
def fftconvolve(in1, in2, mode="full", axes=None, xcorr=False):
"""Convolve two N-dimensional arrays using FFT.
Convolve `in1` and `in2` using the fast Fourier transform method, with
Expand Down Expand Up @@ -465,6 +465,10 @@ def fftconvolve(in1, in2, mode="full", axes=None):
Axes over which to compute the convolution.
The default is over all axes.
xcorr : bool, optional
Returns a tuple (correlation, lags) for cross-correlation. Lags can
be indexed with the np.argmax of the correlation to return the lag.
Returns
-------
out : array
Expand Down Expand Up @@ -519,6 +523,13 @@ def fftconvolve(in1, in2, mode="full", axes=None):
>>> ax_blurred.set_axis_off()
>>> fig.show()
Cross-correlation of a signal with its time-delayed self.
>>> x = np.random.random(1000)
>>> x_behind = np.pad(x, (100,0), 'constant')
>>> (correlation, lags) = signal.fftconvolve(x, x_behind[::-1], mode='full', xcorr=True)
>>> lag_behind_index = np.argmax(correlation)
>>> lag = lags[lag_behind_index]
"""
in1 = np.asarray(in1)
in2 = np.asarray(in2)
Expand All @@ -541,6 +552,10 @@ def fftconvolve(in1, in2, mode="full", axes=None):

ret = _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=True)

if xcorr:
lags = np.arange(-in2.size+1, in1.size)
return (_apply_conv_mode(ret, s1, s2, mode, axes), lags)

return _apply_conv_mode(ret, s1, s2, mode, axes)


Expand Down
14 changes: 13 additions & 1 deletion scipy/signal/tests/test_signaltools.py
Expand Up @@ -763,7 +763,19 @@ def test_many_sizes(self, n):

out = fftconvolve(a, b, 'full', axes=[0])
assert_allclose(out, expected, atol=1e-10)


def test_xcorr(self):
x = np.random.random(1000)

x_behind = np.pad(x, (100,0), 'constant')
(correlation, lags) = signal.fftconvolve(x, x_behind[::-1], mode='full', xcorr=True)
lag_behind_index = np.argmax(correlation)
assert lags[lag_behind_index] == -100

x_ahead = x[100:]
(correlation, lags) = signal.fftconvolve(x, x_ahead[::-1], mode='full', xcorr=True)
lag_ahead_index = np.argmax(correlation)
assert lags[lag_ahead_index] == 100

def fftconvolve_err(*args, **kwargs):
raise RuntimeError('Fell back to fftconvolve')
Expand Down

0 comments on commit 061e059

Please sign in to comment.