Jax convolution significantly slower than scipy #5227

Open
smartalecH opened this issue Dec 18, 2020 · 8 comments
Labels
- contributions welcome: The JAX team has not prioritized work on this. Community contributions are welcome.
- CPU: Issues related to the CPU compiler/runtime
- enhancement: New feature or request
- NVIDIA GPU: Issues specific to NVIDIA GPUs
- open: Issues intentionally left open, with no schedule for next steps.
- P1 (soon): Assignee is working on this now, among other tasks. (Assignee required)
- XLA

Comments

@smartalecH

I noticed that doing a simple 2D convolution using Jax's scipy backend is significantly slower than using scipy itself:

import numpy as np
import jax.scipy.signal
import scipy.signal

import jax.config
jax.config.update("jax_enable_x64", True)     # match SciPy's default float64
jax.config.update('jax_platform_name', 'cpu')

np.random.seed(1)
x = np.random.rand(300, 300)

%timeit jax.scipy.signal.convolve(x, x, mode='same')

%timeit scipy.signal.convolve(x, x, mode='same')

JAX ends with

2.15 s ± 50.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

and SciPy ends with

12.6 ms ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Is this expected?

shoyer (Collaborator) commented Dec 18, 2020

Indeed, JAX is ~200x slower than SciPy here.

I think the difference is that SciPy supports multiple methods for implementing the convolution, which allows it to automatically switch to an asymptotically faster implementation based on FFTs when the convolution window is large:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve.html
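
For example, the switch can be forced with SciPy's method argument (a quick illustration reusing x from the repro above; exact timings will of course vary by machine):

# 'direct' slides the window explicitly; 'fft' uses the convolution theorem.
# method='auto' (the default) estimates which will be faster and picks it.
%timeit scipy.signal.convolve(x, x, mode='same', method='direct')
%timeit scipy.signal.convolve(x, x, mode='same', method='fft')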

In contrast, JAX wraps XLA's convolution. XLA is optimized for neural nets, which almost exclusively use very small convolution windows (typically 3x3), so it only supports the equivalent of SciPy's "direct" method.

It would be a very welcome improvement to add support for FFT-based convolutions to JAX!
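
For reference, here is a minimal sketch of what such an FFT-based "same"-mode convolution could look like on top of jax.numpy.fft (fft_convolve_same is a hypothetical helper, not an existing JAX API; real-valued 2D inputs assumed, no precision tuning):

import jax
import jax.numpy as jnp

@jax.jit
def fft_convolve_same(x, y):
  # Pad both inputs to the full linear-convolution size N + K - 1,
  # multiply their real FFTs, and invert (the convolution theorem).
  full_shape = (x.shape[0] + y.shape[0] - 1, x.shape[1] + y.shape[1] - 1)
  fx = jnp.fft.rfft2(x, full_shape)
  fy = jnp.fft.rfft2(y, full_shape)
  full = jnp.fft.irfft2(fx * fy, full_shape)
  # Crop the centered region so the output matches the first input's shape,
  # mirroring scipy.signal.convolve(..., mode='same').
  i0 = (full_shape[0] - x.shape[0]) // 2
  i1 = (full_shape[1] - x.shape[1]) // 2
  return full[i0:i0 + x.shape[0], i1:i1 + x.shape[1]]

This brings the asymptotic cost down from O(N^2 K^2) for the direct method to O(N^2 log N), which is exactly why SciPy's automatic switch wins for a 300x300 window.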


Side note: there are two useful tricks for profiling in JAX:

  1. Use block_until_ready() to ensure you wait until the computation is done
  2. Wrap the computation in jit and call it once to ensure it is compiled before timing

In this example, this looks like:

@jax.jit
def convolve(x, y):
  return jax.scipy.signal.convolve(x, y, mode='same')

convolve(x, x).block_until_ready()  # call once to compile before timing
%timeit convolve(x, x).block_until_ready()
That said, neither of these tricks mattered in this case. I still measure ~2 seconds! I think trick (1) may only matter on GPU/TPU, and in this case the computation is so slow that any overhead from extra tracing/compilation is irrelevant.

shoyer added the contributions welcome and enhancement labels on Dec 18, 2020
mattjj (Collaborator) commented Dec 18, 2020

If this is on CPU, and with float64, the XLA convolution kernel might just be a really bad one. That is, I bet the f64 XLA:CPU convolution kernel isn't optimized for much of anything (whereas the f32 CPU ones at least have a chance of calling MKL DNN). This might be a similar story to FFT kernels on CPU.

shoyer (Collaborator) commented Dec 19, 2020

Interestingly, float32 on CPU seems to be significantly worse: 30 seconds per loop!
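
(To reproduce, a sketch reusing x from the original example; exact numbers will vary by machine:)

x32 = np.float32(x)  # same 300x300 input, cast to single precision
%timeit jax.scipy.signal.convolve(x32, x32, mode='same').block_until_ready()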

GPUs are faster (~170 ms/loop in Colab), but still much slower than SciPy.

I think the best fix is adding alternative FFT-based convolutions. This could be done either in XLA or in JAX, but is probably easier to implement in JAX.

@smartalecH (Author)

Thanks for the quick feedback!

I built an FFT convolution package using autograd a while back. It only supports 2D, but it's rather easy to generalize; see the sketch below. The performance was the same as scipy/numpy, as expected (for inputs of comparable size, of course).
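
For what it's worth, the 2D recipe generalizes almost verbatim to N dimensions via the N-dimensional FFTs (a sketch along the same lines as the 2D helper above, not the package's actual code; "full" mode, real inputs assumed):

def fft_convolve_full_nd(x, y):
  # Same convolution-theorem idea, but applied over all axes at once.
  full_shape = tuple(a + b - 1 for a, b in zip(x.shape, y.shape))
  fx = jnp.fft.rfftn(x, full_shape)
  fy = jnp.fft.rfftn(y, full_shape)
  return jnp.fft.irfftn(fx * fy, full_shape)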

I can throw a PR together if that's of interest.

Does JAX default to XLA's FFT for all architectures? For CPU it might be nice to use the FFTW library that comes bundled with NumPy/SciPy. I also noticed that pocketfft was included in the source.

shoyer (Collaborator) commented Dec 19, 2020

Yes, a PR would be greatly appreciated!

JAX uses pocketfft on CPU, which is faster and more accurate than XLA's FFT via Eigen. On GPU and TPU, it uses XLA's FFT (which wraps cuFFT on GPU).

shoyer (Collaborator) commented Dec 19, 2020

See #2952 for discussion on FFT libraries.

zhangqiaorjc added the open and XLA labels on Jan 12, 2021
@peterroelants

> It would be a very welcome improvement to add support for FFT-based convolutions to JAX!

I took a stab at adding FFT-based convolutions in JAX in #6343. I would love to get some feedback.

sudhakarsingh27 added the NVIDIA GPU and P2 (eventual) labels on Aug 10, 2022
hawkinsp added the CPU label and removed the NVIDIA GPU label on Aug 12, 2022
hawkinsp (Collaborator) commented Aug 12, 2022

This is a CPU and a GPU issue.

On a Colab T4 GPU, I get:

%timeit jax.scipy.signal.convolve(x,x,mode='same').block_until_ready()
1 loop, best of 5: 104 ms per loop
%timeit scipy.signal.convolve(x,x,mode='same')
100 loops, best of 5: 11.2 ms per loop

(edited: I forgot .block_until_ready(), without which the timing is invalid.)

hawkinsp added the NVIDIA GPU label on Aug 12, 2022
sudhakarsingh27 added the P1 (soon) label and removed the P2 (eventual) label on Aug 24, 2022