Is jax convolution really slow on Google Colaboratory? #7961
Hi, here is my sample run with the CPU runtime on Google Colab. The basic convolution seems far too slow compared to NumPy. With the GPU runtime (see below), JAX appears to be only about 30-40% faster than NumPy. I am assuming that NumPy still runs on the CPU and JAX runs on the GPU when the runtime is GPU. Am I doing something wrong? Do we need to configure something so that convolution is comparable to NumPy on CPU and much faster on GPU?
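For reference, here is a minimal benchmark along the lines of what the question describes (the original notebook isn't shown, so the array sizes here are assumptions). Note the `block_until_ready()` calls: JAX dispatches work asynchronously, so timing without them measures dispatch rather than the computation itself.

```python
import timeit

import numpy as np
import jax.numpy as jnp

# Hypothetical sizes -- the original notebook isn't shown.
x = np.random.randn(10_000)
w = np.random.randn(100)
jx, jw = jnp.asarray(x), jnp.asarray(w)

# NumPy baseline.
np_time = timeit.timeit(lambda: np.convolve(x, w), number=100)

# JAX: block_until_ready() forces the asynchronously dispatched
# computation to finish, so we time the work, not just the dispatch.
jnp.convolve(jx, jw).block_until_ready()  # warm-up run
jax_time = timeit.timeit(
    lambda: jnp.convolve(jx, jw).block_until_ready(), number=100
)

print(f"numpy: {np_time:.4f}s  jax: {jax_time:.4f}s")
```

Both calls compute the same full convolution (output length `len(x) + len(w) - 1`), so the timings are comparable.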
Hi - indeed it looks like JAX's convolution is quite a bit slower than numpy's for this case on CPU. I think there are probably two reasons for this:

1. JAX incurs some per-operation dispatch overhead that NumPy does not, and for small CPU-bound operations this overhead can dominate the runtime.
2. JAX's convolution is implemented in terms of the `ConvWithPadding` operation in XLA, which is designed and optimized for the much more complicated batched multi-dimensional convolutions common in convolutional neural networks. I suspect that not much effort has been put into optimizing the simpler case of 1D convolutions.

It would be worth opening a bug about this, because I think we could probably do better in this case.
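In the meantime, one way to reduce the dispatch-overhead part of the gap is to wrap the call in `jax.jit`, which compiles the whole computation into a single XLA executable (the underlying convolution kernel is unchanged, so this is a partial mitigation, not a fix). A sketch:

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical sizes, matching nothing in particular from the thread.
x = jnp.asarray(np.random.randn(10_000), dtype=jnp.float32)
w = jnp.asarray(np.random.randn(100), dtype=jnp.float32)

# jit traces the function once and replays a single compiled XLA
# computation afterwards, removing per-op Python dispatch overhead.
conv = jax.jit(lambda a, b: jnp.convolve(a, b))

conv(x, w).block_until_ready()  # compile once, outside the timing loop
t = timeit.timeit(lambda: conv(x, w).block_until_ready(), number=100)
print(f"jitted jnp.convolve: {t:.4f}s")
```

Keeping the compilation warm-up outside the timed loop matters: the first call includes tracing and XLA compilation, which would otherwise swamp the measurement.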