jnp.split: use dynamic rather than static slices for speed #13096
Conversation
Do we want benchmarks here, like those you added in #12219?
Does this change affect performance for a small number of splits?
Good question - for repeated operations after the initial run, we get something like this:

main branch:

    In [2]: %timeit jax.block_until_ready(jnp.split(x, jnp.arange(0, 5000, 10)))
    40.5 ms ± 4.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

this branch:

    In [2]: %timeit jax.block_until_ready(jnp.split(x, jnp.arange(0, 5000, 10)))
    116 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

So a slowdown of ~3x on repeated calls is the tradeoff for saving ~150x on the initial call. What do you think?
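For readers outside the thread, a minimal pure-Python sketch (no JAX needed; `x` and the index spacing mirror the benchmark above, but the variable names here are illustrative assumptions) of what `jnp.split(x, jnp.arange(0, 5000, 10))` computes: 500 split points produce 501 pieces, the first of them empty because the first index is 0.

```python
# Pure-Python sketch of split semantics: prepend 0 and append len(x) to the
# split indices, then take consecutive [start, end) slices.
x = list(range(5000))
indices = list(range(0, 5000, 10))      # 500 split points, as in the benchmark
bounds = [0] + indices + [len(x)]
pieces = [x[s:e] for s, e in zip(bounds[:-1], bounds[1:])]
# 501 pieces: an empty leading piece, then 10-element chunks
```

Each of those ~500 slices is a separate op in the traced program, which is why the number of splits dominates the initial (trace/compile) cost.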
-return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
+# Use dynamic rather than static slice to prevent slow execution of large
+# number of splits; see https://github.com/google/jax/issues/12999
+return [lax.dynamic_slice(ary, _subval(starts, axis, start), _subval(sizes, axis, end - start))
Another option here: we could replace this with
return [ary[int(start): int(end)] for start, end in zip(split_indices[:-1], split_indices[1:])]
Due to #12219, this would now use dynamic rather than static slices and result in the same performance characteristics. I kind of like the idea of delegating the performance question to existing code. What do you think?
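The suggested one-liner pairs consecutive split indices via `zip`. A small self-contained sketch with plain Python lists (the example values are illustrative; for the pairing to cover the whole array, `split_indices` must include 0 and the array length):

```python
# Hypothetical example of the zip-pairing pattern from the suggestion above.
ary = list(range(10))
split_indices = [0, 3, 7, 10]
pieces = [ary[int(start): int(end)]
          for start, end in zip(split_indices[:-1], split_indices[1:])]
# pieces -> [[0, 1, 2], [3, 4, 5, 6], [7, 8, 9]]
```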
From conversation with @hawkinsp: we may want to profile in order to understand why the dynamic slice approach is 3x slower, since it isn't clear that it ought to be.
Thanks! That explains the performance issues I've been trying to debug with slice vs dynamic_slice vs numpy slice.
I can no longer reproduce these performance issues, despite this still lowering to static slice.
Fixes #12999. See also #9445.
This is very similar to the issue & fix in #12219.
main branch:
this branch: