Description
Hello! Out of curiosity, I ran the following benchmark script to compare STFT performance between JAX and PyTorch:
import torch
import jax
import jax.numpy as jnp
from time import perf_counter
from functools import partial
n_samples = 100000
batch_size = 256
torch_device = 'cuda'
n_fft = 2048
torch_window = torch.hann_window(n_fft, device=torch_device)
jax_window = jax.numpy.hanning(n_fft)
def torch_stft(x):
    stft = torch.stft(x, n_fft=n_fft, hop_length=n_fft//2, window=torch_window, center=True, pad_mode='constant', return_complex=True)
    return stft.abs().mean()

def jax_stft(x):
    _, _, stft = jax.scipy.signal.stft(x, window=jax_window, nperseg=n_fft, noverlap=n_fft//2)
    return jnp.abs(stft)

@jax.jit
def jax_stft_vmap(x):
    return jax.vmap(jax_stft)(x).mean()
n_loops = 1000
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
jax_noise = jax.random.uniform(subkey, (batch_size, n_samples))
jax_stft_vmap(jax_noise).block_until_ready()
start = perf_counter()
for _ in range(n_loops):
    jax_stft_vmap(jax_noise)
end = perf_counter()
print(f"JAX STFT: {1000 * (end - start) / n_loops:.4f} ms/it avg")
torch_noise = torch.rand((batch_size, n_samples), device=torch_device)
start = perf_counter()
for _ in range(n_loops):
    torch_stft(torch_noise)
end = perf_counter()
print(f"PyTorch STFT: {1000 * (end - start) / n_loops:.4f} ms/it avg")
I was surprised by the output: not only does it measure the JAX STFT as significantly slower, it also explicitly warns about an operation taking longer than expected:
→ uv run --with torch --with "jax[cuda12]" main.py
2025-05-08 15:23:34.079270: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng0{} for conv %cudnn-conv.1 = (f32[256,2048,99]{2,1,0}, u8[0]{0}) custom-call(%bitcast.100, %bitcast.106), window={size=2048 stride=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(jax_stft_vmap)/jit(main)/conv_general_dilated" source_file="/homes/ds011/code/main.py" source_line=22}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-05-08 15:23:35.866923: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 2.787862855s
Trying algorithm eng0{} for conv %cudnn-conv.1 = (f32[256,2048,99]{2,1,0}, u8[0]{0}) custom-call(%bitcast.100, %bitcast.106), window={size=2048 stride=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(jax_stft_vmap)/jit(main)/conv_general_dilated" source_file="/homes/ds011/code/main.py" source_line=22}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
JAX STFT: 18.0702 ms/it avg
PyTorch STFT: 2.4223 ms/it avg
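One thing I was not sure about is whether the lack of explicit synchronization in the timing loops distorts these averages, since the JAX call dispatches asynchronously and the PyTorch call also returns before the kernel has finished. A variant of the loops that waits for the results on both sides (just my own sketch, not necessarily the canonical way to benchmark either framework) would look like this:
# Timing variant that waits for the device work to finish before stopping the clock:
# block_until_ready() on the JAX side, torch.cuda.synchronize() on the PyTorch side.
start = perf_counter()
for _ in range(n_loops):
    jax_stft_vmap(jax_noise).block_until_ready()
end = perf_counter()
print(f"JAX STFT (blocking): {1000 * (end - start) / n_loops:.4f} ms/it avg")

torch_stft(torch_noise)        # warm-up call, mirroring the JAX warm-up above
torch.cuda.synchronize()
start = perf_counter()
for _ in range(n_loops):
    torch_stft(torch_noise)
torch.cuda.synchronize()
end = perf_counter()
print(f"PyTorch STFT (synced): {1000 * (end - start) / n_loops:.4f} ms/it avg")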
This is on a Linux machine with a CUDA GPU (see below for specs), but the same slowdown shows up when I run the script on an M4 CPU (not using jax-metal). I am quite new to JAX, so I am not sure: is this an issue with how I set up the script, or with the jax.scipy.signal.stft implementation?
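From the warning above it looks like jax.scipy.signal.stft is being lowered to a cuDNN convolution (conv_general_dilated). As an extra data point, one could compare against a hand-rolled STFT that simply gathers overlapping frames and applies jnp.fft.rfft; the sketch below is my own and almost certainly does not reproduce scipy's boundary padding and scaling, so its output will differ at the edges.
# Rough framing + rfft STFT, reusing n_fft, jax_window and jax_noise from the script above.
@jax.jit
def jax_manual_stft(x):
    hop = n_fft // 2
    n_frames = 1 + (x.shape[-1] - n_fft) // hop
    # (n_frames, n_fft) matrix of sample indices for the overlapping frames
    idx = hop * jnp.arange(n_frames)[:, None] + jnp.arange(n_fft)[None, :]
    frames = x[..., idx] * jax_window              # (batch, n_frames, n_fft)
    return jnp.abs(jnp.fft.rfft(frames, axis=-1)).mean()

jax_manual_stft(jax_noise).block_until_ready()     # compile and warm up before timing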
Thank you for your work on maintaining the library!
System info (python version, jaxlib version, accelerator, etc.)
→ uv run --with torch --with "jax[cuda12]" python
Python 3.12.7 (main, Oct 16 2024, 04:37:19) [Clang 18.1.8 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax: 0.5.3
jaxlib: 0.5.3
numpy: 2.2.5
python: 3.12.7 (main, Oct 16 2024, 04:37:19) [Clang 18.1.8 ]
device info: NVIDIA RTX A5000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='wynne.eecs.qmul.ac.uk', release='6.8.0-54-generic', version='#56-Ubuntu SMP PREEMPT_DYNAMIC Sat Feb 8 00:37:57 UTC 2025', machine='x86_64')
$ nvidia-smi
Thu May 8 15:11:28 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A5000               Off |   00000000:1A:00.0 Off |                  Off |
| 30%   23C    P2             16W /  230W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
<...snip...>
More GPUs are installed on the machine, but CUDA_VISIBLE_DEVICES=0 is set.