EXLA FFT much slower than Torchx on CPU #1234
Comments
Yes, we can do something similar. I will look into it.
JAX now uses ducc for FFT per google/jax#12122. ducc is the successor to pocketfft, so we can mirror this as well. The challenge is determining where/how to build the third-party repo into EXLA. My opinion is that we can add the ducc-specific library code upstream to the precompiled XLA library. It will require some changes to the build process (and I might have to patch TensorFlow :/). Implementing the custom operator once we figure out the dependency stuff is easy.
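To see why delegating FFT to an optimized library like pocketfft or ducc matters, here is an illustrative sketch (not EXLA code, and not the custom operator itself). It compares a naive O(n²) DFT against NumPy's `np.fft`, which has been backed by pocketfft since NumPy 1.17; the gap between the two is the kind of gap a custom operator would close.

```python
# Illustrative only: a direct DFT versus NumPy's pocketfft-backed FFT.
# The naive version stands in for an unoptimized lowering; the two must
# agree numerically, but the library FFT is O(n log n) instead of O(n^2).
import numpy as np

def naive_dft(x):
    """Direct O(n^2) DFT via the DFT matrix W[j, k] = exp(-2*pi*i*j*k/n)."""
    n = len(x)
    k = np.arange(n)
    w = np.exp(-2j * np.pi * np.outer(k, k) / n)
    return w @ x

x = np.random.default_rng(0).standard_normal(256)
assert np.allclose(naive_dft(x), np.fft.fft(x))
```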
@seanmor5 I agree with adding the custom FFT code to the precompiled XLA library. Why do you think we'd have to patch TF? Can't we use an XLA::CustomCall to do the job?
Good news: google/jax@390022a. This will be solved in future XLA versions. :)
Here are the results with EXLA main and the new XLA (run on M1):
YAAAAAAAAAS |
Sweet! This should speed up the Bumblebee Whisper model too; last I checked, it was using size-400 FFTs.
From the discussion in elixir-nx/nx_signal#14, it looks like the EXLA FFT implementation is ~7x slower than Torchx when run on CPU.
tf.signal has run into the same limitations (tensorflow/tensorflow#6541), and JAX has worked around it by using PocketFFT: google/jax#2952. Would EXLA be able to do something similar?

Here's the (modified) benchmark:
results:
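The Elixir benchmark and its numbers are not reproduced above. As a rough stand-in for the shape of such a measurement, the following hypothetical Python harness times repeated CPU FFTs of a fixed size using NumPy's pocketfft-backed `np.fft`; the original compares the EXLA and Torchx Nx backends instead, and `bench_fft` is a name invented here for illustration.

```python
# Hypothetical stand-in for the elided Nx benchmark: time repeated FFTs
# of a fixed size on CPU using NumPy (pocketfft-backed). This is not the
# original Elixir code, only an illustration of the methodology.
import time
import numpy as np

def bench_fft(size=400, iters=100):
    x = np.random.default_rng(0).standard_normal(size) + 0j
    start = time.perf_counter()
    for _ in range(iters):
        np.fft.fft(x)
    elapsed = time.perf_counter() - start
    return elapsed / iters  # mean seconds per FFT call

per_call = bench_fft()
print(f"mean time per size-400 FFT: {per_call * 1e6:.1f} us")
```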