"Slow operation alarm" in jax.scipy.signal.stft #28614

@davidmarttila

Description

Hello! Out of curiosity, I ran the following benchmark script to compare STFT performance between JAX and PyTorch:

import torch
import jax
import jax.numpy as jnp
from time import perf_counter
from functools import partial

n_samples = 100000

batch_size = 256

torch_device = 'cuda'
n_fft = 2048

torch_window = torch.hann_window(n_fft, device=torch_device)
jax_window = jax.numpy.hanning(n_fft)

def torch_stft(x):
    stft = torch.stft(x, n_fft=n_fft, hop_length=n_fft//2, window=torch_window, center=True, pad_mode='constant', return_complex=True)
    return stft.abs().mean()

def jax_stft(x):
    _, _, stft = jax.scipy.signal.stft(x, window=jax_window, nperseg=n_fft, noverlap=n_fft//2)
    return jnp.abs(stft)

@jax.jit
def jax_stft_vmap(x):
    return jax.vmap(jax_stft)(x).mean()


n_loops = 1000

key = jax.random.PRNGKey(0)

key, subkey = jax.random.split(key, 2)
jax_noise = jax.random.uniform(subkey, (batch_size, n_samples))

jax_stft_vmap(jax_noise).block_until_ready()

start = perf_counter()
for _ in range(n_loops):
    jax_stft_vmap(jax_noise)
end = perf_counter()
print(f"JAX STFT: {1000 * (end - start) / n_loops:.4f} ms/it avg")

torch_noise = torch.rand((batch_size, n_samples), device=torch_device)
start = perf_counter()
for _ in range(n_loops):
    torch_stft(torch_noise)
end = perf_counter()
print(f"PyTorch STFT: {1000 * (end - start) / n_loops:.4f} ms/it avg")

I was surprised by the output: it measures the JAX STFT to be significantly slower, and XLA also explicitly warns about an operation taking longer than expected:

→ uv run --with torch --with "jax[cuda12]" main.py
2025-05-08 15:23:34.079270: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng0{} for conv %cudnn-conv.1 = (f32[256,2048,99]{2,1,0}, u8[0]{0}) custom-call(%bitcast.100, %bitcast.106), window={size=2048 stride=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(jax_stft_vmap)/jit(main)/conv_general_dilated" source_file="/homes/ds011/code/main.py" source_line=22}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-05-08 15:23:35.866923: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 2.787862855s
Trying algorithm eng0{} for conv %cudnn-conv.1 = (f32[256,2048,99]{2,1,0}, u8[0]{0}) custom-call(%bitcast.100, %bitcast.106), window={size=2048 stride=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(jax_stft_vmap)/jit(main)/conv_general_dilated" source_file="/homes/ds011/code/main.py" source_line=22}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
JAX STFT: 18.0702 ms/it avg
PyTorch STFT: 2.4223 ms/it avg
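
For reference when reading the log, the trailing dimension 99 in `f32[256,2048,99]` is consistent with the number of STFT frames one would expect for this input, assuming `jax.scipy.signal.stft` follows the SciPy-style defaults `boundary='zeros'` and `padded=True`. The sketch below is only arithmetic: the helper name `stft_frame_count` is hypothetical, and reading the cuDNN convolution as the framing step is an assumption based on `window={size=2048 stride=1024}`.

```python
def stft_frame_count(n_samples, nperseg, noverlap, boundary=True, padded=True):
    """Count STFT frames under SciPy-style boundary extension and padding.

    Hypothetical helper for illustration; it does not call JAX or SciPy.
    """
    hop = nperseg - noverlap
    n = n_samples
    if boundary:
        # boundary extension adds nperseg // 2 samples on each side
        n += 2 * (nperseg // 2)
    if padded:
        # zero-pad so the signal splits into a whole number of frames
        n += (-(n - nperseg) % hop) % nperseg
    return 1 + (n - nperseg) // hop


# Values from the benchmark script: n_samples=100000, n_fft=2048, hop=n_fft//2
n_frames = stft_frame_count(100_000, nperseg=2048, noverlap=1024)
print(n_frames)  # 99
```

Under these assumptions, the convolution shape would read as 256 batch items, 2048 samples per frame, and 99 frames.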

This is on a Linux machine with a CUDA GPU (see below for specs), but the same happens when I run it on an M4 CPU (not using jax-metal). I am quite new to JAX - is this an issue with how I set up the script, or with the jax.scipy.signal.stft implementation?
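
On the script-setup part of the question: both JAX and PyTorch dispatch GPU work asynchronously, and the timed loops above do not wait for the device before reading the clock, so the per-iteration numbers may partly reflect launch overhead rather than completed work. A minimal sketch of a timing helper that synchronizes before stopping the timer follows. The names `time_per_call_ms` and `sync` are hypothetical, and whether this changes the reported gap is not known.

```python
from time import perf_counter


def time_per_call_ms(fn, *args, sync=lambda out: out, n_warmup=3, n_loops=100):
    """Average wall-clock milliseconds per call of fn(*args).

    `sync` is called on the result to wait for any asynchronous work.
    Assumes n_loops >= 1.
    """
    # Warm-up runs absorb one-time costs such as compilation or autotuning.
    for _ in range(n_warmup):
        sync(fn(*args))
    start = perf_counter()
    for _ in range(n_loops):
        out = fn(*args)
    # Wait for the last result so queued device work is included in the timing.
    sync(out)
    return 1000 * (perf_counter() - start) / n_loops


# Pure-Python demo so the sketch runs without a GPU.
print(f"{time_per_call_ms(sum, range(1000)):.4f} ms/it avg")

# With the functions from the script above, the calls might look like:
#   time_per_call_ms(jax_stft_vmap, jax_noise, sync=lambda o: o.block_until_ready())
#   time_per_call_ms(torch_stft, torch_noise, sync=lambda o: torch.cuda.synchronize())
```

Synchronizing once after the loop, rather than on every iteration, keeps the device queue full while still counting all of the queued work.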

Thank you for your work on maintaining the library!

System info (python version, jaxlib version, accelerator, etc.)

→ uv run --with torch --with "jax[cuda12]" python
Python 3.12.7 (main, Oct 16 2024, 04:37:19) [Clang 18.1.8 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.5.3
jaxlib: 0.5.3
numpy:  2.2.5
python: 3.12.7 (main, Oct 16 2024, 04:37:19) [Clang 18.1.8 ]
device info: NVIDIA RTX A5000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='wynne.eecs.qmul.ac.uk', release='6.8.0-54-generic', version='#56-Ubuntu SMP PREEMPT_DYNAMIC Sat Feb  8 00:37:57 UTC 2025', machine='x86_64')


$ nvidia-smi
Thu May  8 15:11:28 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A5000               Off |   00000000:1A:00.0 Off |                  Off |
| 30%   23C    P2             16W /  230W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |

<...snip...>

More GPUs are installed on the machine, but CUDA_VISIBLE_DEVICES=0 is set.
