`jnp.split()` has long compile times for large numbers of splits #12999
Labels: performance (make things lean and fast)
Reported in https://stackoverflow.com/q/74199437
Subsequent calls are much faster once the `XLA:Slice` operations have been cached.

There are some tradeoffs here that make the solution to this non-trivial (see my answer in the above StackOverflow post), but it might be better to use `lax.dynamic_slice` in place of `lax.slice` to prevent this kind of catastrophically-bad outcome.
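To make the suggestion concrete, here is a minimal sketch of the two slicing strategies for an equal-sized split along axis 0. The helpers `split_static` and `split_dynamic` are hypothetical names used for illustration only; this is not how `jnp.split` is actually implemented. The assumption being illustrated is that each distinct set of static bounds passed to `lax.slice` is treated as a separate operation to compile when run outside `jit`, whereas `lax.dynamic_slice` takes the start index as a runtime value, so one compiled operation can plausibly be reused for every chunk.

```python
import jax.numpy as jnp
from jax import lax


def split_static(x, num_sections):
    """Equal split along axis 0 using static slice bounds (hypothetical helper).

    Every chunk has different static start/limit indices, so each call to
    lax.slice_in_dim describes a distinct slice operation.
    """
    size = x.shape[0] // num_sections
    return [
        lax.slice_in_dim(x, i * size, (i + 1) * size, axis=0)
        for i in range(num_sections)
    ]


def split_dynamic(x, num_sections):
    """Equal split along axis 0 using dynamic start indices (hypothetical helper).

    Only the slice size is static; the start index is passed as a value,
    so all chunks share the same operation signature.
    """
    size = x.shape[0] // num_sections
    return [
        lax.dynamic_slice_in_dim(x, i * size, size, axis=0)
        for i in range(num_sections)
    ]


x = jnp.arange(12.0)
parts = split_dynamic(x, 4)  # four chunks of length 3
```

One of the tradeoffs to keep in mind: `lax.dynamic_slice` clamps out-of-range start indices rather than raising an error, so a version built on it would need its own bounds checking to match the behavior of the static version.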