FFT precision/performance #2952
Please see the related discussion in #2874.
IIUC #2874 is only about a GPU bug. Perhaps this issue is mainly about CPU, both performance and correctness issues. I think XLA:CPU is using Eigen's FFT. Maybe it's slow or something. I'll ping the XLA:CPU folks to see if they know anything about it. On the JAX side, we could possibly do a CustomCall into some other FFT implementation on CPU, like we CustomCall into LAPACK kernels for matrix decompositions.
I confirmed with the XLA:CPU folks that XLA is just calling into Eigen here, and "it's possible but unlikely that XLA is doing something bad here that triggers slowness." (I'd like to double-check just by executing this benchmark that Eigen is being multithreaded properly for FFTs, but I'm not sure when I'll get a chance to do that.) Depending on whether the XLA:CPU folks have the bandwidth to improve this, we might want to look into JAX-side solutions. I'll update this thread again when I learn more.
I am working on implementing a JAX backend for kymatio (kymat.io), a Python package implementing the scattering transform. When we compare the JAX FFT implementation against a closed-form expression of the discrete Fourier transform of a box of ones, we see large discrepancies. Are there plans to address this? PR #3290 does not appear to have solved this issue.
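For reference, the DFT of a box of ones has a simple closed form (a Dirichlet kernel), so the accuracy check described above can be sketched as follows. The sizes below are illustrative, not the ones kymatio uses; swapping `np.fft.fft` for `jax.numpy.fft.fft` would measure JAX's error against the same reference.

```python
import numpy as np

# Closed-form DFT of a "box of ones": x[n] = 1 for n < M, else 0.
# X[k] = exp(-i*pi*k*(M-1)/N) * sin(pi*k*M/N) / sin(pi*k/N), with X[0] = M.
N, M = 1024, 64  # illustrative sizes
x = np.zeros(N)
x[:M] = 1.0

k = np.arange(1, N)  # k = 0 handled separately to avoid 0/0
X = np.empty(N, dtype=np.complex128)
X[0] = M
X[1:] = (np.exp(-1j * np.pi * k * (M - 1) / N)
         * np.sin(np.pi * k * M / N) / np.sin(np.pi * k / N))

err = np.abs(np.fft.fft(x) - X)
print("max abs error:", err.max())
```

In float64, NumPy's FFT agrees with the closed form to roughly machine precision, which makes this a convenient ground truth for comparing other implementations.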
Even with smaller arrays we note large differences.
What version are you using?
Hi, the input was a box of ones as float32s. Testing with the input as float64s, we obtain similar inaccuracies. This is with version 0.1.48.
Out of curiosity, do you see the same results from TensorFlow? JAX uses Eigen for its FFT implementation on CPU, as does TensorFlow, so one hypothesis is that this is simply due to the quality of the Eigen implementation. That might be nice to verify, if you have time. If they did differ, that would be very interesting to know.
It appears that they give the same outputs on both CPU and GPU! Edit: My interpretation was that JAX is supposed to be as similar as possible to NumPy. Is this interpretation wrong?
"Shouldn't JAX be closer to NumPy?" Ultimately they are two different pieces of code and they will not act the same in all circumstances, and it's not a goal to precisely match NumPy everywhere. There are at least three things you could mean:
(a) Perhaps. We don't try to follow NumPy precisely, and in a number of cases we default to float32 to be more GPU-friendly.
(b) I suspect we need to find a higher-quality implementation of FFT on CPU. The obvious candidate is probably Intel's MKL library.
(c) JAX uses completely different FFT implementations on CPU and GPU. On GPU it uses cuFFT (which pretty much everyone uses, as far as I am aware). I would actually expect that you would see high-quality results on GPU. Can you confirm that you were actually running in 64-bit mode on GPU?
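On the 64-bit-mode question: JAX defaults to 32-bit, so float64/complex128 computation has to be enabled explicitly. A minimal sketch of the usual configuration flag (assuming a reasonably recent jax/jaxlib):

```python
import jax
# Enable 64-bit computation; without this, JAX silently uses float32/complex64.
# Must run before arrays are created (or set the JAX_ENABLE_X64 env var).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.fft.fft(jnp.ones(8)).dtype)  # complex128 when x64 is enabled
```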
The other results were on CPU, both 32- and 64-bit. Looks like 64-bit on GPU matches up.
Thank you for looking into this, @hawkinsp. From our perspective the best thing would be to have the JAX FFT be more accurate (comparable to NumPy, SciPy, and pyFFTW) on the CPU, for both float32 and float64. Plugging into MKL, as you suggest, might be a good idea here. If this is the Eigen FFT interface that JAX is using, it looks like switching from the default backend (kissfft) to FFTW is possible by setting a compiler flag.
Licensing might be the trickiest part here. FFTW is GPL and MKL is proprietary. NumPy uses pocketfft these days. Writing a custom call in JAX to use pocketfft on CPU could be a good option -- or perhaps XLA CPU should use pocketfft. |
I can also add that the radio astronomy community would be greatly interested if the JAX FFT on CPU were both accurate and fast. @mattjj Re: is the result the same as with TensorFlow? Yes. With JAX:

```
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
```

With TensorFlow:

```
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
```

```
numpy fft execution time [ms]: 44.88363027572632
jax fft execution time [ms]: 84.56079244613647
tensorflow fft execution time [ms]: 84.12498950958252
```
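One caveat when reproducing timings like these: JAX dispatches work asynchronously, so a fair benchmark has to call `block_until_ready()` on the result and warm up first to exclude compilation/dispatch overhead. A minimal sketch (the array size is an assumption, not the one used above):

```python
import time
import numpy as np
import jax.numpy as jnp

x = np.random.randn(2**20).astype(np.float32)  # size is an assumption
xj = jnp.asarray(x)

# Warm-up call: excludes one-time compile/dispatch costs from the timing.
jnp.fft.fft(xj).block_until_ready()

t0 = time.perf_counter()
np.fft.fft(x)
print("numpy fft execution time [ms]:", (time.perf_counter() - t0) * 1e3)

t0 = time.perf_counter()
jnp.fft.fft(xj).block_until_ready()  # wait for the async result
print("jax fft execution time [ms]:", (time.perf_counter() - t0) * 1e3)
```

Without `block_until_ready()`, the JAX timing would only measure how long it takes to enqueue the computation, not to run it.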
@mattjj Has there been any progress in understanding why JAX's FFT is ~twice as slow as NumPy's on a CPU? I second @Joshuaalbert's comment: other divisions of astrophysics would also be very interested in a fast and accurate JAX FFT.
I think we're pretty clear on what to do here: replace the Eigen FFT on CPU with something else, probably PocketFFT, same as NumPy. We just need someone to actually do it!
@hawkinsp What about MKL's FFT? It's the fastest that I've seen. FFTW is currently what radio astronomers use, due to its popularity. An informative comparison of FFT libraries is here: https://github.com/project-gemmi/benchmarking-fft/
I think optionally using MKL could be viable, but MKL is closed-source software. At the very least, we want to preserve an open source option.
And FFTW is GPL :/ I think the main limiting factor here is just developer bandwidth on the JAX core team, where we have to balance a lot of considerations (code licenses, ensuring it works in OSS as well as internally at Google, etc.). Until we improve this, it might be useful to look at how you can rig up a call into any implementation you want by registering a custom backend-specific kernel with XLA. One example of how to do that is mpi4jax. You could also look at how JAX calls into LAPACK on CPU and cuSolver on GPU, e.g. starting at lapack.pyx for the CPU stuff.
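Short of writing a full XLA custom call, a lighter-weight (and slower, since it leaves the compiled program) workaround available on recent JAX versions is `jax.pure_callback`, which can invoke any host-side FFT, such as NumPy's pocketfft-backed one. A sketch, assuming a 1-D float32 input; `host_fft` and `fft_via_callback` are hypothetical names:

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_fft(x):
    # Runs on the host via NumPy (pocketfft); x arrives as a NumPy array.
    # NumPy computes the FFT in double precision; cast to the declared dtype.
    return np.fft.fft(x).astype(np.complex64)

def fft_via_callback(x):
    out_shape = jax.ShapeDtypeStruct(x.shape, np.complex64)
    return jax.pure_callback(host_fft, out_shape, x)

x = jnp.ones(8, dtype=jnp.float32)
print(fft_via_callback(x))
```

This composes with `jit`, but every call crosses the device/host boundary, so it is a stopgap rather than a substitute for a proper backend kernel.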
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant). For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```

After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```

Fixes #2952
PiperOrigin-RevId: 338530400
This issue should be fixed, but it requires a jaxlib rebuild. You can either build from source or wait for us to make a new jaxlib release. Hope that helps!
The updated FFT has been released as part of jaxlib 0.1.57. Hope that helps!
I find a noticeable difference between the outputs of numpy.fft.fft and jax.numpy.fft.fft. The difference also changes with the device: on the cpu device the error is bigger than on the gpu device. On the other hand, the mean absolute error for the gpu implementation of fft from e.g. PyTorch is around 1e-8, which seems reasonable. I guess that might be some minor bug.

The second issue is the performance of jax.numpy.fft.fft on the cpu device. I am aware that jax is intended for GPU/TPU, but the overhead of JAX for fft on cpu seems weirdly big. Below is the simple code for reproduction.
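The original repro snippet itself is not shown above; a minimal sketch of that kind of comparison, printing the same max/mean/min error metrics quoted later in the thread (input shape and box width are assumptions), might look like:

```python
import time
import numpy as np
import jax.numpy as jnp

N = 2**18                 # assumed size
x = np.zeros(N, dtype=np.float32)
x[: N // 16] = 1.0        # box of ones

ref = np.fft.fft(x.astype(np.float64))        # high-precision reference
jx = np.asarray(jnp.fft.fft(jnp.asarray(x)))  # JAX result (complex64 by default)

err = np.abs(jx - ref)
print("max:", err.max(), "mean:", err.mean(), "min:", err.min())

t0 = time.perf_counter()
np.fft.fft(x)
print("numpy fft execution time [ms]:", (time.perf_counter() - t0) * 1e3)
```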