Jax convolution significantly slower than scipy #5227
Indeed, JAX is ~200x slower than SciPy here. I think the difference is that SciPy supports multiple methods for implementing the convolution, which allows it to automatically switch to an asymptotically faster FFT-based implementation when the convolution window is large. In contrast, JAX wraps XLA's convolution. XLA is optimized for neural nets, which almost exclusively use very small convolution windows (typically 3x3), so it only supports the equivalent of SciPy's "direct" method. It would be a very welcome improvement to add support for FFT-based convolutions to JAX!
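For reference, SciPy's automatic method selection can be observed directly. A small sketch (the array size here is chosen for illustration):

```python
import numpy as np
from scipy import signal

x = np.random.randn(128, 128)

# SciPy estimates which implementation will be faster for these inputs;
# with a window this large it picks the FFT method.
print(signal.choose_conv_method(x, x, mode='same'))  # -> 'fft'

# Either method can also be forced explicitly. For an n x n input and a
# k x k window, 'direct' costs O(n^2 * k^2) while 'fft' costs O(n^2 log n).
fast = signal.convolve(x, x, mode='same', method='fft')
slow = signal.convolve(x, x, mode='same', method='direct')
np.testing.assert_allclose(fast, slow, atol=1e-6)
```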
Side note: there are two useful tricks for profiling in JAX: (1) wrap the function in jax.jit, so repeated tracing overhead doesn't dominate the measurement, and (2) call block_until_ready() on outputs, so JAX's asynchronous dispatch doesn't stop the clock early. In this example, this looks like:

```python
import jax
import jax.scipy.signal

@jax.jit
def convolve(x, y):
    return jax.scipy.signal.convolve(x, y, mode='same')

convolve(x, x).block_until_ready()  # warm-up call: triggers tracing and compilation
%timeit convolve(x, x).block_until_ready()
```

That said, none of these tricks mattered in this case: I still measure ~2 seconds! I think trick (1) may only matter on GPU/TPU, and in this case the computation is so slow that any overhead due to extra tracing/compilation is irrelevant.
If this is on CPU, and with float64, the XLA convolution kernel might just be a really bad one. That is, I bet the f64 XLA:CPU convolution kernel isn't optimized for much of anything (whereas the f32 CPU ones at least have a chance of calling MKL-DNN). This might be a similar story to FFT kernels on CPU.
Interestingly, float32 on CPU seems to be significantly worse: 30 seconds per loop! GPUs are faster (~170 ms/loop in Colab), but still much slower than SciPy. I think the best fix is adding the alternative FFT-based convolutions. This could be done either in XLA or in JAX, but is probably easier to implement in JAX.
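A minimal sketch for reproducing the dtype comparison on CPU (the array size and the use of IPython's %timeit are assumptions, not from the report above):

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal

jax.config.update("jax_enable_x64", True)  # allow float64 arrays

x64 = jnp.asarray(np.random.randn(256, 256))  # float64 once x64 mode is on
x32 = x64.astype(jnp.float32)

conv = jax.jit(lambda a: jax.scipy.signal.convolve(a, a, mode='same'))

conv(x64).block_until_ready()  # compile once per dtype before timing
conv(x32).block_until_ready()

%timeit conv(x64).block_until_ready()
%timeit conv(x32).block_until_ready()
```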
Thanks for the quick feedback! I built an FFT convolution package using autograd a while back. It only supports 2D, but it's rather easy to generalize. The performance was the same as scipy/numpy, as expected (for larger arrays of similar size, of course). I can throw a PR together if that's of interest. Does JAX default to XLA's FFT for all architectures? For CPU it might be nice to use the FFTW library that comes bundled with NumPy/SciPy. I also noticed that pocketfft is included in the source.
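To make that concrete, a rough sketch of an FFT-based 'same'-mode 2D convolution in JAX, following the standard pad, transform, multiply, inverse-transform, crop recipe (illustrative only, not the package's or any PR's actual code):

```python
import jax
import jax.numpy as jnp

def fft_convolve_same(x, y):
    """2D 'same'-mode convolution via FFT, for real inputs.

    Costs O(n^2 log n) instead of the direct method's O(n^2 * k^2).
    """
    # Pad to the full linear-convolution size n + k - 1 along each axis, so
    # the circular convolution computed by the FFT equals the linear one.
    full_shape = tuple(nx + ny - 1 for nx, ny in zip(x.shape, y.shape))
    fx = jnp.fft.rfftn(x, full_shape)
    fy = jnp.fft.rfftn(y, full_shape)
    full = jnp.fft.irfftn(fx * fy, full_shape)
    # Crop the centered region so the output has x's shape ('same' mode).
    starts = tuple((ny - 1) // 2 for ny in y.shape)
    return jax.lax.dynamic_slice(full, starts, x.shape)
```

Up to floating-point error this matches scipy.signal.fftconvolve(x, y, mode='same'), and because it is built from jnp.fft it composes with jit and autodiff like any other JAX code.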
Yes, a PR would be greatly appreciated! JAX uses pocketfft on CPU, which is faster and more accurate than XLA's FFT via Eigen. On GPU and TPU, it uses XLA's FFT (which wraps cuFFT on GPU).
See #2952 for discussion on FFT libraries.
I took a stab at adding FFT-based convolutions in JAX at #6343. I would love to get some feedback.
This is a CPU and a GPU issue; on a Colab T4 GPU the JAX convolution is still far slower than SciPy. (edited: I had forgotten block_until_ready().)
I noticed that doing a simple 2D convolution using JAX's scipy backend (jax.scipy.signal.convolve) is significantly slower than using scipy itself:
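Roughly, the benchmark looks like this (the 256x256 array size is an assumption, and block_until_ready() is added per the benchmarking advice above):

```python
import numpy as np
import scipy.signal
import jax.numpy as jnp
import jax.scipy.signal

x = np.random.randn(256, 256)
xj = jnp.asarray(x)

%timeit scipy.signal.convolve(x, x, mode='same')  # SciPy auto-selects its FFT method
%timeit jax.scipy.signal.convolve(xj, xj, mode='same').block_until_ready()  # XLA direct method
```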
JAX ends at roughly 2 seconds per loop, while scipy ends roughly 200x faster.

Is this expected?