Add istft to jax.scipy.signal.
yotarok committed Apr 1, 2022
1 parent a68b0f3 commit a7fd751
135 changes: 135 additions & 0 deletions jax/_src/scipy/
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,138 @@ def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
axis=axis, average=average)

return freqs, Pxx.real

def _overlap_and_add(x, step_size):
  """Utility function compatible with tf.signal.overlap_and_add.

  Args:
    x: An array with `(..., frames, frame_length)`-shape.
    step_size: A positive integer denoting the offset between the starts of
      consecutive frames (normally smaller than `frame_length`). It must be a
      concrete value because it determines output shapes.

  Returns:
    An array with `(..., output_size)`-shape containing overlapped signal,
    where `output_size = step_size * (frames - 1) + frame_length`.

  Raises:
    ValueError: if `x` has fewer than two dimensions or `step_size` is not
      positive.
  """
  _check_arraylike("_overlap_and_add", x)
  step_size = jax.core.concrete_or_error(int, step_size,
                                         "step_size for overlap_and_add")
  if x.ndim < 2:
    raise ValueError('Input must have (..., frames, frame_length) shape.')
  if step_size < 1:
    # Guard explicitly: a zero step would otherwise surface as a bare
    # ZeroDivisionError below, and a negative one as silent nonsense.
    raise ValueError('step_size must be a positive integer.')

  *batch_shape, nframes, segment_len = x.shape
  # Collapse all leading batch dimensions into one; the product of an empty
  # batch shape is 1, so unbatched input is handled uniformly.
  flat_batchsize = np.prod(batch_shape, dtype=np.int64)
  x = x.reshape((flat_batchsize, nframes, segment_len))
  output_size = step_size * (nframes - 1) + segment_len
  nstep_per_segment = 1 + (segment_len - 1) // step_size

  # Here, we use shorter notation for axes.
  # B: batch_size, N: nframes, S: nstep_per_segment,
  # T: segment_len divided by S

  # Pad every frame up to a whole number of steps so it splits evenly.
  padded_segment_len = nstep_per_segment * step_size
  x = jnp.pad(x, ((0, 0), (0, 0), (0, padded_segment_len - segment_len)))
  x = x.reshape((flat_batchsize, nframes, nstep_per_segment, step_size))

  # For obtaining shifted signals, this routine reinterprets the flattened
  # array with a shrunken axis. With appropriate truncation/padding, this
  # operation pushes the last padded elements of the previous row to the head
  # of the current row.
  # See implementation of `overlap_and_add` in Tensorflow for details.
  x = x.transpose((0, 2, 1, 3))  # x: (B, S, N, T)
  x = jnp.pad(x, ((0, 0), (0, 0), (0, nframes), (0, 0)))  # x: (B, S, N*2, T)
  shrinked = x.shape[2] - 1
  x = x.reshape((flat_batchsize, -1))
  x = x[:, :(nstep_per_segment * shrinked * step_size)]
  x = x.reshape((flat_batchsize, nstep_per_segment, shrinked * step_size))

  # Finally, sum shifted segments, and truncate results to the output_size.
  x = x.sum(axis=1)[:, :output_size]
  return x.reshape(tuple(batch_shape) + (-1,))

def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2):
  """Perform the inverse short-time Fourier transform (iSTFT).

  Args:
    Zxx: STFT of the signal to reconstruct; must be at least 2-D.
    fs: Sampling frequency, used only to build the returned time vector.
    window: Either a window specification (`str` or `tuple`) passed to
      `scipy.signal.get_window`, or a 1-D array of length `nperseg`.
    nperseg: Number of samples per segment. Defaults to
      `2 * (Zxx.shape[freq_axis] - 1)` for one-sided input and to
      `Zxx.shape[freq_axis]` otherwise.
    noverlap: Number of samples shared by consecutive segments. Defaults to
      `nperseg // 2`. An explicit `0` means "no overlap".
    nfft: FFT length used for the forward transform. Defaults to the same
      value as the default `nperseg` (plus one for odd `nperseg` with
      one-sided input).
    input_onesided: If true, `Zxx` is interpreted as a one-sided spectrum and
      inverted with `irfft`; otherwise `ifft` is used.
    boundary: If true, `nperseg // 2` samples are trimmed from both ends of
      the reconstructed signal.
    time_axis: Axis of `Zxx` indexing the STFT frames.
    freq_axis: Axis of `Zxx` indexing the frequency bins.

  Returns:
    A `(time, x)` tuple with the sample times and the reconstructed signal.

  Raises:
    ValueError: on inconsistent axes, segment sizes or window shape.
  """
  # Input validation
  _check_arraylike("istft", Zxx)
  if Zxx.ndim < 2:
    raise ValueError('Input stft must be at least 2d!')
  freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
  time_axis = canonicalize_axis(time_axis, Zxx.ndim)
  if freq_axis == time_axis:
    raise ValueError('Must specify differing time and frequency axes!')

  # Promote to (at least) a complex dtype that the current JAX config allows.
  Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype(
      np.result_type(Zxx, np.complex64)))

  n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
               else Zxx.shape[freq_axis])

  # Use `is None` rather than truthiness so that an explicit 0 reaches the
  # validation below instead of silently becoming the default.
  nperseg = jax.core.concrete_or_error(
      int, n_default if nperseg is None else nperseg,
      "nperseg: segment length of STFT")
  if nperseg < 1:
    raise ValueError('nperseg must be a positive integer')

  if nfft is None:
    nfft = n_default
    if input_onesided and nperseg == n_default + 1:
      nfft += 1  # Odd nperseg, no FFT padding
  nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
  if nfft < nperseg:
    raise ValueError(
        f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

  # `noverlap=0` is a legitimate request for non-overlapping segments, so only
  # `None` selects the default.
  noverlap = jax.core.concrete_or_error(
      int, nperseg // 2 if noverlap is None else noverlap,
      "noverlap of STFT")
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Rearrange axes if necessary
  if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
    outer_idxs = tuple(
        idx for idx in range(Zxx.ndim) if idx not in {time_axis, freq_axis})
    Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

  # Perform IFFT
  ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
  # xsubs: [..., T, N], N is the number of frames, T is the frame length.
  xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

  # Get window as array
  if isinstance(window, (str, tuple)):
    win = osp_signal.get_window(window, nperseg)
    win = jnp.asarray(win)
  else:
    win = jnp.asarray(window)
    if len(win.shape) != 1:
      raise ValueError('window must be 1-D')
    if win.shape[0] != nperseg:
      raise ValueError('window must have length of {0}'.format(nperseg))
  win = win.astype(xsubs.dtype)

  xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

  # make win broadcastable over xsubs
  win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1,))
  x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
  # Overlap-add the squared window as well: it is the per-sample gain that the
  # windowed overlap-add applied and must be divided out afterwards.
  win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
  norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

  # Remove extension points
  if boundary:
    x = x[..., nperseg//2:-(nperseg//2)]
    norm = norm[..., nperseg//2:-(nperseg//2)]
  # Skip the division where the accumulated window energy is (almost) zero.
  x /= jnp.where(norm > 1e-10, norm, 1.0)

  # Put axes back
  if x.ndim > 1:
    if time_axis != Zxx.ndim - 1:
      # The frequency axis no longer exists in `x`; shift the target index.
      if freq_axis < time_axis:
        time_axis -= 1
      x = jnp.moveaxis(x, -1, time_axis)

  # NOTE(review): the time vector is sized from axis 0 of `x`, not from the
  # time axis; presumably kept for parity with scipy -- confirm before changing.
  time = jnp.arange(x.shape[0]) / fs
  return time, x
1 change: 1 addition & 0 deletions jax/scipy/
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
correlate2d as correlate2d,
detrend as detrend,
csd as csd,
istft as istft,
stft as stft,
welch as welch,
64 changes: 64 additions & 0 deletions tests/
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from functools import partial
import unittest
import warnings

from absl.testing import absltest, parameterized

Expand Down Expand Up @@ -47,6 +48,12 @@
((3, 17, 2), (3, 12, 2), 9, 3, 1),
welch_test_shapes = stft_test_shapes
# Shape configurations exercised by the istft tests.
istft_test_shapes = [
    # (input_shape, nperseg, noverlap, timeaxis, freqaxis)
    ((3, 2, 64, 31), 100, 75, -1, -2),
    ((17, 8, 5), 13, 7, 0, 1),
    ((65, 24), 24, 7, -2, -1),
]

default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
Expand Down Expand Up @@ -376,6 +383,63 @@ def testWelchWithDefaultStepArgsAgainstNumpy(
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
"onesided": onesided, "boundary": boundary,
"timeaxis": timeaxis, "freqaxis": freqaxis}
for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes
for dtype in default_dtypes
for fs in [1.0, 16000.0]
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for onesided in [False, True]
for boundary in [False, True]))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm 5.1
def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                          noverlap, nfft, onesided, boundary,
                          timeaxis, freqaxis):
  """Checks jsp_signal.istft against osp_signal.istft on random input."""
  if not onesided:
    # Normalize the axis before slicing: with freqaxis == -1 the expression
    # shape[freqaxis + 1:] is shape[0:], which would append the whole shape.
    pos_freqaxis = freqaxis % len(shape)
    new_freq_len = (shape[pos_freqaxis] - 1) * 2
    shape = shape[:pos_freqaxis] + (new_freq_len,) + shape[pos_freqaxis + 1:]

  def osp_fun(x, fs):
    # Ignore UserWarning in osp so we can also test over ill-posed cases.
    with warnings.catch_warnings():
      warnings.simplefilter('ignore', UserWarning)
      result = osp_signal.istft(
          x,
          fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
          nfft=nfft, input_onesided=onesided, boundary=boundary,
          time_axis=timeaxis, freq_axis=freqaxis)
    return result

  jsp_fun = partial(jsp_signal.istft,
                    window=window, nperseg=nperseg, noverlap=noverlap,
                    nfft=nfft, input_onesided=onesided, boundary=boundary,
                    time_axis=timeaxis, freq_axis=freqaxis)

  tol = {
      np.float32: 1e-4, np.float64: 1e-6,
      np.complex64: 1e-4, np.complex128: 1e-6
  }
  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL

  rng = jtu.rand_default(self.rng())
  rng_fs = jtu.rand_uniform(self.rng(), 1.0, 16000.0)
  # `np.float` was only an alias of the builtin `float` and has been removed
  # from NumPy; use the builtin directly.
  args_maker = lambda: [rng(shape, dtype), rng_fs((), float)]

  # Here, dtype of output signal is different depending on osp versions,
  # and so depending on the test environment. Thus, dtype check is disabled.
  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol,
                          check_dtypes=False)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

if __name__ == "__main__":

