Linear/Dense performance for PyTorch vs JAX (flax/stax) #8497

Answered by PhilipVinc
Huizerd asked this question in Q&A

For a proper comparison, you should warm up ("pre-heat") the JIT.
Call the jitted function once before starting the timer, so that you don't profile the compilation time.
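
A minimal sketch of this warm-up pattern, assuming a recent JAX with jax.jit and Array.block_until_ready. The hand-written dense function and the layer and batch sizes are illustrative assumptions, not code from the original question. Besides the warm-up call, it blocks on each result, because JAX dispatches work asynchronously and the timer could otherwise stop before the computation has finished.

```python
import time

import jax


# Illustrative linear layer y = x @ w + b (a stand-in for a flax/stax Dense layer).
@jax.jit
def dense(params, x):
    w, b = params
    return x @ w + b


def time_fn(fn, *args, n_iter=100):
    """Average seconds per call of fn(*args), excluding JIT compilation."""
    # Warm-up: the first call traces and compiles, so keep it out of the timing.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_iter):
        # JAX dispatches asynchronously; block so each call is fully measured.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iter


# Hypothetical sizes: 512 input features, 256 output features, batch of 128.
wkey, bkey, xkey = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(wkey, (512, 256)), jax.random.normal(bkey, (256,)))
x = jax.random.normal(xkey, (128, 512))
print(f"{time_fn(dense, params, x) * 1e6:.1f} us per call")
```

The same timing helper can be reused for the PyTorch side of the comparison only in spirit; there the equivalent of blocking is synchronizing the device before reading the clock.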

For a fair comparison, you should also feed jnp.arrays rather than NumPy arrays to JAX, so that the timed calls don't include converting and transferring the inputs.
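A rough sketch of the difference, under the same assumptions as any quick benchmark: the dense function, the array sizes, and the bench helper are hypothetical. The NumPy inputs must be converted to device arrays on every call, while jnp.asarray does that conversion once, outside the timed loop. How large the gap is depends on the backend (it may be small on CPU), so no numbers are claimed here.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def dense(w, b, x):
    # Illustrative linear layer y = x @ w + b.
    return x @ w + b


def bench(fn, *args, n_iter=50):
    fn(*args).block_until_ready()  # warm-up: compile outside the timed region
    start = time.perf_counter()
    for _ in range(n_iter):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iter


rng = np.random.default_rng(0)
w_np = rng.standard_normal((512, 256), dtype=np.float32)
b_np = rng.standard_normal(256, dtype=np.float32)
x_np = rng.standard_normal((128, 512), dtype=np.float32)

# NumPy inputs: every call converts the host arrays to device arrays.
t_numpy = bench(dense, w_np, b_np, x_np)

# jax.Array inputs: convert once, before timing.
w, b, x = (jnp.asarray(a) for a in (w_np, b_np, x_np))
t_jax = bench(dense, w, b, x)

print(f"numpy inputs: {t_numpy * 1e6:.1f} us, jnp inputs: {t_jax * 1e6:.1f} us")
```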

Replies: 1 comment, 10 replies
@sourabh2k15
@Huizerd
@Habush
@mattjj
@Habush

Answer selected by Huizerd
Category: Q&A
6 participants