Skip to content

Commit

Permalink
Add some functions for spectral analysis.
Browse files Browse the repository at this point in the history
This commit adds "stft", "csd", and "welch" functions in scipy.signal.
  • Loading branch information
yotarok committed Feb 17, 2022
1 parent 924f7aa commit e085370
Show file tree
Hide file tree
Showing 4 changed files with 699 additions and 0 deletions.
346 changes: 346 additions & 0 deletions jax/_src/scipy/signal.py
Expand Up @@ -13,15 +13,21 @@
# limitations under the License.

import scipy.signal as osp_signal
import operator
import warnings

import numpy as np

import jax
import jax.numpy.fft
from jax import lax
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.third_party.scipy import signal_helper
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert


# Note: we do not re-use the code from jax.numpy.convolve here, because the handling
Expand Down Expand Up @@ -146,3 +152,343 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
coef, *_ = linalg.lstsq(A, data[sl])
data = data.at[sl].add(-jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
return jnp.moveaxis(data.reshape(shape), 0, axis)


def _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides):
"""Calculate windowed FFT in the same way the original SciPy does.
"""
if x.dtype.kind == 'i':
x = x.astype(win.dtype)

# Created strided array of data segments
if nperseg == 1 and noverlap == 0:
result = x[..., np.newaxis]
else:
step = nperseg - noverlap
*batch_shape, signal_length = x.shape
batch_shape = tuple(batch_shape)
x = x.reshape((int(np.prod(batch_shape)), signal_length))[..., np.newaxis]
result = jax.lax.conv_general_dilated_patches(
x, (nperseg,), (step,),
'VALID',
dimension_numbers=('NTC', 'OIT', 'NTC'))
result = result.reshape(batch_shape + result.shape[-2:])

# Detrend each data segment individually
result = detrend_func(result)

# Apply window by multiplication
result = win.reshape((1,) * len(batch_shape) + (1, nperseg)) * result

# Perform the fft on last axis. Zero-pads automatically
if sides == 'twosided':
return jax.numpy.fft.fft(result, n=nfft)
else:
return jax.numpy.fft.rfft(result.real, n=nfft)


def odd_ext(x, n, axis=-1):
"""Extends `x` along with `axis` by odd-extension.
This function was previously a part of "scipy.signal.signaltools" but is no
longer exposed.
Args:
x : input array
n : the number of points to be added to the both end
axis: the axis to be extended
"""
if n < 1:
return x
if n > x.shape[axis] - 1:
raise ValueError(
f"The extension length n ({n}) is too big. "
f"It must not exceed x.shape[axis]-1, which is {x.shape[axis] - 1}.")
left_end = lax.slice_in_dim(x, 0, 1, axis=axis)
left_ext = jnp.flip(lax.slice_in_dim(x, 1, n + 1, axis=axis), axis=axis)
right_end = lax.slice_in_dim(x, -1, None, axis=axis)
right_ext = jnp.flip(lax.slice_in_dim(x, -(n + 1), -1, axis=axis), axis=axis)
ext = jnp.concatenate((2 * left_end - left_ext,
x,
2 * right_end - right_ext),
axis=axis)
return ext


def _spectral_helper(x, y,
fs=1.0, window='hann', nperseg=None, noverlap=None,
nfft=None, detrend_type='constant', return_onesided=True,
scaling='density', axis=-1, mode='psd', boundary=None,
padded=False):
"""LAX-backend implementation of `scipy.signal._spectral_helper`.
Unlike the original helper function, `y` can be None for explicitly
indicating auto-spectral (non cross-spectral) computation. In addition to
this, `detrend` argument is renamed to `detrend_type` for avoiding internal
name overlap.
"""
if mode not in ('psd', 'stft'):
raise ValueError(f"Unknown value for mode {mode}, "
"must be one of: ('psd', 'stft')")

def make_pad(mode, **kwargs):
def pad(x, n, axis=-1):
pad_width = [(0, 0) for unused_n in range(x.ndim)]
pad_width[axis] = (n, n)
return jnp.pad(x, pad_width, mode, **kwargs)
return pad

boundary_funcs = {
'even': make_pad('reflect'),
'odd': odd_ext,
'constant': make_pad('edge'),
'zeros': make_pad('constant', constant_values=0.0),
None: lambda x, *args, **kwargs: x
}

# Check/ normalize inputs
if boundary not in boundary_funcs:
raise ValueError(
f"Unknown boundary option '{boundary}', "
f"must be one of: {list(boundary_funcs.keys())}")

axis = jax.core.concrete_or_error(operator.index, axis,
"axis of windowed-FFT")
axis = canonicalize_axis(axis, x.ndim)

if nperseg is not None: # if specified by user
nperseg = jax.core.concrete_or_error(int, nperseg,
"nperseg of windowed-FFT")
if nperseg < 1:
raise ValueError('nperseg must be a positive integer')
# parse window; if array like, then set nperseg = win.shape
win, nperseg = signal_helper._triage_segments(
window, nperseg, input_length=x.shape[axis])

if noverlap is None:
noverlap = nperseg // 2
else:
noverlap = jax.core.concrete_or_error(int, noverlap,
"noverlap of windowed-FFT")
if nfft is None:
nfft = nperseg
else:
nfft = jax.core.concrete_or_error(int, nfft,
"nfft of windowed-FFT")

_check_arraylike("_spectral_helper", x)
x = jnp.asarray(x)

if y is None:
outdtype = jax.dtypes.canonicalize_dtype(np.result_type(x, np.complex64))
else:
_check_arraylike("_spectral_helper", y)
y = jnp.asarray(y)
outdtype = jax.dtypes.canonicalize_dtype(
np.result_type(x, y, np.complex64))
if mode != 'psd':
raise ValueError("two-argument mode is available only when mode=='psd'")
if x.ndim != y.ndim:
raise ValueError(
"two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")

# Check if we can broadcast the outer axes together
try:
outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
tuple_delete(y.shape, axis))
except ValueError as e:
raise ValueError('x and y cannot be broadcast together.') from e

# Special cases for size == 0
if y is None:
if x.size == 0:
return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
else:
if x.size == 0 or y.size == 0:
outshape = tuple_insert(
outershape, min([x.shape[axis], y.shape[axis]]), axis)
emptyout = jnp.zeros(outshape)
return emptyout, emptyout, emptyout

# Move time-axis to the end
if x.ndim > 1:
if axis != -1:
x = jnp.moveaxis(x, axis, -1)
if y is not None and y.ndim > 1:
y = jnp.moveaxis(y, axis, -1)

# Check if x and y are the same length, zero-pad if necessary
if y is not None:
if x.shape[-1] != y.shape[-1]:
if x.shape[-1] < y.shape[-1]:
pad_shape = list(x.shape)
pad_shape[-1] = y.shape[-1] - x.shape[-1]
x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
else:
pad_shape = list(y.shape)
pad_shape[-1] = x.shape[-1] - y.shape[-1]
y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

if nfft < nperseg:
raise ValueError('nfft must be greater than or equal to nperseg.')
if noverlap >= nperseg:
raise ValueError('noverlap must be less than nperseg.')
nstep = nperseg - noverlap

# Apply paddings
if boundary is not None:
ext_func = boundary_funcs[boundary]
x = ext_func(x, nperseg // 2, axis=-1)
if y is not None:
y = ext_func(y, nperseg // 2, axis=-1)

if padded:
# Pad to integer number of windowed segments
# I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg
zeros_shape = list(x.shape[:-1]) + [nadd]
x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
if y is not None:
zeros_shape = list(y.shape[:-1]) + [nadd]
y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

# Handle detrending and window functions
if not detrend_type:
def detrend_func(d):
return d
elif not hasattr(detrend_type, '__call__'):
def detrend_func(d):
return detrend(d, type=detrend_type, axis=-1)
elif axis != -1:
# Wrap this function so that it receives a shape that it could
# reasonably expect to receive.
def detrend_func(d):
d = jnp.moveaxis(d, axis, -1)
d = detrend_type(d)
return jnp.moveaxis(d, -1, axis)
else:
detrend_func = detrend_type

if np.result_type(win, np.complex64) != outdtype:
win = win.astype(outdtype)

# Determine scale
if scaling == 'density':
scale = 1.0 / (fs * (win * win).sum())
elif scaling == 'spectrum':
scale = 1.0 / win.sum()**2
else:
raise ValueError(f'Unknown scaling: {scaling}')
if mode == 'stft':
scale = jnp.sqrt(scale)

# Determine onesided/ two-sided
if return_onesided:
sides = 'onesided'
if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
sides = 'twosided'
warnings.warn('Input data is complex, switching to '
'return_onesided=False')
else:
sides = 'twosided'

if sides == 'twosided':
freqs = jax.numpy.fft.fftfreq(nfft, 1/fs)
elif sides == 'onesided':
freqs = jax.numpy.fft.rfftfreq(nfft, 1/fs)

# Perform the windowed FFTs
result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

if y is not None:
# All the same operations on the y data
result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
sides)
result = jnp.conjugate(result) * result_y
elif mode == 'psd':
result = jnp.conjugate(result) * result

result *= scale

if sides == 'onesided' and mode == 'psd':
end = None if nfft % 2 else -1
result = result.at[..., 1:end].mul(2)

time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
nperseg - noverlap) / fs
if boundary is not None:
time -= (nperseg / 2) / fs

result = result.astype(outdtype)

# All imaginary parts are zero anyways
if y is None and mode != 'stft':
result = result.real

# Move frequency axis back to axis where the data came from
result = jnp.moveaxis(result, -1, axis)

return freqs, time, result


@_wraps(osp_signal.stft)
def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,
detrend=False, return_onesided=True, boundary='zeros', padded=True,
axis=-1):
freqs, time, Zxx = _spectral_helper(x, None, fs, window, nperseg, noverlap,
nfft, detrend, return_onesided,
scaling='spectrum', axis=axis,
mode='stft', boundary=boundary,
padded=padded)

return freqs, time, Zxx


_csd_description = """
The original SciPy function exhibits slightly different behavior between
``csd(x, x)``` and ```csd(x, x.copy())```. The LAX-backend version is designed
to follow the latter behavior. For using the former behavior, call this
function as `csd(x, None)`."""


@_wraps(osp_signal.csd, lax_description=_csd_description)
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
detrend='constant', return_onesided=True, scaling='density',
axis=-1, average='mean'):
freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
detrend, return_onesided, scaling, axis,
mode='psd')
if y is not None:
Pxy = Pxy + 0j # Ensure complex output when x is not y

# Average over windows.
if Pxy.ndim >= 2 and Pxy.size > 0:
if Pxy.shape[-1] > 1:
if average == 'median':
bias = signal_helper._median_bias(Pxy.shape[-1]).astype(Pxy.dtype)
if jnp.iscomplexobj(Pxy):
Pxy = (jnp.median(jnp.real(Pxy), axis=-1)
+ 1j * jnp.median(jnp.imag(Pxy), axis=-1))
else:
Pxy = jnp.median(Pxy, axis=-1)
Pxy /= bias
elif average == 'mean':
Pxy = Pxy.mean(axis=-1)
else:
raise ValueError(f'average must be "median" or "mean", got {average}')
else:
Pxy = jnp.reshape(Pxy, Pxy.shape[:-1])

return freqs, Pxy


@_wraps(osp_signal.welch)
def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
detrend='constant', return_onesided=True, scaling='density',
axis=-1, average='mean'):
freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg,
noverlap=noverlap, nfft=nfft, detrend=detrend,
return_onesided=return_onesided, scaling=scaling,
axis=axis, average=average)

return freqs, Pxx.real

0 comments on commit e085370

Please sign in to comment.