Is jax convolution really slow on Google Colaboratory? #7961

Answered by jakevdp
shailesh1729 asked this question in Q&A
Hi - indeed, it looks like JAX's convolution is quite a bit slower than NumPy's for this case on CPU. I think there are probably two reasons for this:

  1. JAX convolution lowers to XLA's general convolution operation (ConvWithGeneralPadding), which is designed and optimized for the much more complicated batched multi-dimensional convolutions common in convolutional neural networks. I suspect that not much effort has been put into optimizing the simpler case of 1D convolutions.
  2. In general, there has been much more focus on optimizing operations in the XLA GPU backend than in the XLA CPU backend, so you occasionally come across examples where CPU execution is less optimized than we would like.
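
To see this concretely, the two can be compared directly. This is a minimal sketch, not code from the thread; the array sizes, seed, and tolerances are my own choices:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
x = rng.normal(size=1000)   # 1-D signal
k = rng.normal(size=64)     # 1-D kernel

out_np = np.convolve(x, k, mode="full")
# jnp.convolve lowers to XLA's general N-D convolution op under the hood
out_jax = jnp.convolve(jnp.asarray(x), jnp.asarray(k), mode="full")

# Both produce a length len(x) + len(k) - 1 "full" convolution; results
# agree up to float32 precision (JAX defaults to 32-bit arrays).
assert out_jax.shape == (len(x) + len(k) - 1,)
assert np.allclose(out_np, np.asarray(out_jax), rtol=1e-3, atol=1e-3)
```

Timing each line in a Colab CPU runtime (e.g. with `%timeit`) should reproduce the gap the question describes: the results match, but the XLA path pays overhead that NumPy's specialized 1-D routine does not.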

It would be worth openi…

Answer selected by jakevdp