jnp.split() has long compile times for large numbers of splits #12999

Closed
jakevdp opened this issue Oct 26, 2022 · 1 comment
Labels: performance (make things lean and fast)

Comments

jakevdp (Collaborator) commented Oct 26, 2022

Reported in https://stackoverflow.com/q/74199437

```python
# %time is an IPython magic; run this in IPython or Jupyter.
import jax.numpy as jnp
x = jnp.ones(5000)
%time _ = jnp.split(x, jnp.arange(0, 5000, 10))
# CPU times: user 7.29 s, sys: 101 ms, total: 7.39 s
# Wall time: 7.46 s
```

Subsequent calls are much faster once the XLA:Slice operations have been cached:

```python
%time _ = jnp.split(x, jnp.arange(0, 5000, 10))
# CPU times: user 61.4 ms, sys: 0 ns, total: 61.4 ms
# Wall time: 70.3 ms
```
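A rough way to see where the first-call cost comes from, outside IPython, is to time static and dynamic slices directly. This is a minimal sketch, assuming eager (op-by-op) execution, where JAX compiles and caches each primitive per distinct set of static parameters: `lax.slice_in_dim` bakes the start index into the compiled op, so every new start index needs a fresh compilation, while `lax.dynamic_slice_in_dim` takes the start as a runtime operand and only the slice size (10) is static. The smaller array, the number of slices, and the helper `time_first_call` are illustrative choices, not part of the original report, and absolute timings will vary by machine and JAX version.

```python
import time

import jax.numpy as jnp
from jax import lax

x = jnp.ones(500)
starts = list(range(0, 500, 10))  # 50 distinct start indices, chunk size 10


def time_first_call(slice_fn):
    """Hypothetical helper: apply slice_fn at every start and time the loop."""
    t0 = time.perf_counter()
    pieces = [slice_fn(s) for s in starts]
    for p in pieces:
        p.block_until_ready()  # wait for all async work before stopping the clock
    return time.perf_counter() - t0, pieces


# Static slice: start/limit are compile-time parameters, so each distinct
# start index needs its own compiled computation on first use.
t_static, static_pieces = time_first_call(lambda s: lax.slice_in_dim(x, s, s + 10))

# Dynamic slice: the start is a runtime operand and only the size (10) is
# static, so the same compiled slice can be reused for every start index.
t_dynamic, dynamic_pieces = time_first_call(
    lambda s: lax.dynamic_slice_in_dim(x, s, 10)
)

print(f"lax.slice_in_dim:         {t_static:.3f} s")
print(f"lax.dynamic_slice_in_dim: {t_dynamic:.3f} s")
```

If the explanation above holds, the static loop should be noticeably slower on its first run, and the gap should shrink on a second run once the static slices come from the cache.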

There are some tradeoffs that make the fix non-trivial (see my answer in the StackOverflow post above), but it might be better to use lax.dynamic_slice in place of lax.slice to prevent this kind of catastrophically bad outcome.
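The `lax.dynamic_slice` idea can be sketched as a stand-in for the index-list form of `jnp.split`. This is a minimal sketch, not JAX's implementation: the helper name `split_dynamic` is hypothetical, it does not handle the integer "number of sections" form of `jnp.split`, and it assumes the split indices are concrete (not traced), sorted, non-negative, and no larger than `x.shape[axis]`. Each piece still needs a static size, so in eager mode pieces of the same length (like the 10-element chunks in the report) can share one compiled slice, while each new length costs one extra compilation. One tradeoff worth knowing: `lax.dynamic_slice` clamps out-of-range start indices instead of raising, so invalid indices would silently return the wrong window rather than erroring.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def split_dynamic(x, indices, axis=0):
    """Hypothetical helper: like jnp.split(x, indices, axis) for a list of
    split points, but built from lax.dynamic_slice_in_dim.

    Assumes `indices` is concrete, sorted, and within [0, x.shape[axis]].
    """
    x = jnp.asarray(x)
    n = x.shape[axis]
    # Split points become Python ints so every piece has a static size.
    bounds = [0, *(int(i) for i in np.asarray(indices)), n]
    pieces = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        size = stop - start
        if size == 0:
            # Empty piece (e.g. a split at index 0): build it directly.
            shape = list(x.shape)
            shape[axis] = 0
            pieces.append(jnp.zeros(shape, dtype=x.dtype))
        else:
            # The start is a runtime operand; only `size` is a static parameter.
            pieces.append(lax.dynamic_slice_in_dim(x, start, size, axis=axis))
    return pieces


x = jnp.ones(5000)
parts = split_dynamic(x, jnp.arange(0, 5000, 10))
print(len(parts))  # 501 pieces: one empty, then 500 of length 10
```

Inside `jax.jit`, static slices may give XLA more room to optimize, so this sketch mainly targets eager calls like the one in the report; it is not benchmarked here.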

jakevdp (Collaborator, Author) commented Nov 3, 2023

Closing as duplicate of #9445

jakevdp closed this as completed Nov 3, 2023