
jnp.split: use dynamic rather than static slices for speed #13096

Closed
wants to merge 1 commit into from

Conversation

@jakevdp (Collaborator) commented Nov 3, 2022

Fixes #12999. See also #9445

This is very similar to the issue & fix in #12219

main branch:

In [1]: import jax.numpy as jnp 
   ...: x = jnp.ones(5000) 
   ...: %time _ = jnp.split(x, jnp.arange(0, 5000, 10))                                                                   
CPU times: user 4.77 s, sys: 118 ms, total: 4.89 s
Wall time: 4.96 s

this branch:

In [1]: import jax.numpy as jnp 
   ...: x = jnp.ones(5000) 
   ...: %time _ = jnp.split(x, jnp.arange(0, 5000, 10))                                                                   
CPU times: user 342 ms, sys: 8.52 ms, total: 351 ms
Wall time: 360 ms

@jakevdp jakevdp requested a review from froystig November 3, 2022 19:51
@jakevdp jakevdp self-assigned this Nov 3, 2022
@froystig (Member) left a comment:

Do we want benchmarks here, like those you added in #12219?

Does this change affect performance for a small number of splits?

@jakevdp (Collaborator, Author) commented Nov 4, 2022

Good question. For repeated calls after the initial (compiled) run, we get something like this:

main branch:

In [2]: %timeit jax.block_until_ready(jnp.split(x, jnp.arange(0, 5000, 10)))                                          
40.5 ms ± 4.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

this branch:

In [2]: %timeit jax.block_until_ready(jnp.split(x, jnp.arange(0, 5000, 10)))                                          
116 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

So a slowdown of ~3x on repeated calls is the tradeoff for saving ~150x on the initial call.

What do you think?

-    return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
+    # Use dynamic rather than static slice to prevent slow execution of large
+    # number of splits; see https://github.com/google/jax/issues/12999
+    return [lax.dynamic_slice(ary, _subval(starts, axis, start), _subval(sizes, axis, end - start))
@jakevdp (Collaborator, Author) commented Nov 4, 2022
Another option here: we could replace this with

return [ary[int(start): int(end)] for start, end in zip(split_indices[:-1], split_indices[1:])]

Due to #12219, this would now use dynamic rather than static slices and result in the same performance characteristics. I kind of like the idea of delegating the performance question to existing code. What do you think?
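As a sketch of what that alternative looks like in isolation (`split_via_getitem` is a hypothetical name for illustration, not the actual `jnp.split` internals, and the boundary handling here is simplified):

```python
import jax.numpy as jnp

def split_via_getitem(ary, indices):
    # Hypothetical sketch of the suggested alternative: build the full list
    # of split boundaries, then slice with ordinary indexing. After #12219,
    # repeated __getitem__ calls like this lower to dynamic slices.
    split_indices = [0, *map(int, indices), ary.shape[0]]
    return [ary[start:end]
            for start, end in zip(split_indices[:-1], split_indices[1:])]

x = jnp.arange(10)
parts = split_via_getitem(x, [3, 7])
# parts have lengths 3, 4, and 3
```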

@froystig (Member) commented:

From conversation with @hawkinsp: we may want to profile in order to understand why the dynamic slice approach is 3x slower, since it isn't clear that it ought to be (?)

@KeAWang commented Dec 8, 2022

Why is dynamic_slice faster than static slice in the first place?

@jakevdp (Collaborator, Author) commented Dec 8, 2022

Why is dynamic_slice faster than static slice in the first place?

For a single call, dynamic_slice is actually slower than static slice: it does not specialize its compiled code on the index values, so XLA must perform some logic at runtime to handle the start index. Static slice, by contrast, is specialized on its static start indices, so every call with different indices incurs a small overhead in XLA at compile time. Accumulated over thousands of calls, that per-call compile-time overhead makes the static version slower overall, while dynamic_slice pays no such cost because it is compiled once for all index values.
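A minimal illustration of the distinction (my own example, not code from this PR): lax.slice bakes its start/stop indices into the traced program as compile-time constants, while lax.dynamic_slice takes the start index as a runtime value and specializes only on the slice size.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

# Static slice: start=2 and stop=5 are compile-time constants, so each
# distinct index pair produces a distinct specialized slice operation.
a = lax.slice(x, start_indices=(2,), limit_indices=(5,))

# Dynamic slice: the start index is an ordinary runtime value; only the
# slice size (3,) must be static, so one compiled op serves every start.
b = lax.dynamic_slice(x, start_indices=(jnp.array(2),), slice_sizes=(3,))

# Both slices contain the elements [2., 3., 4.]
```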

@KeAWang commented Dec 8, 2022

Thanks! That explains the performance issues I've been trying to debug with slice vs dynamic_slice vs numpy slice.

@jakevdp (Collaborator, Author) commented Nov 3, 2023

I can no longer reproduce these performance issues, despite jnp.split still lowering to static slices.

@jakevdp jakevdp closed this Nov 3, 2023
Successfully merging this pull request may close these issues.

jnp.split() has long compile times for large numbers of splits