Very slow JIT compilation due to constant-folding of very large BCOO sparse matrix nonzero entry array #14655
It seems like the compiler isn't making a great choice here with respect to the run-time/compile-time tradeoffs involved in constant folding. If you change your code slightly, though, the problematic array will be computed at runtime:

```python
def build_sparse_linear_operator():
    def product(other):
        nonzeroes = sparse.eye(n).indices  # shape (n, 2)
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product
```
Thanks! Unfortunately, while this would avoid the problem here, I can't fix things upstream in that way, since the array is constructed elsewhere. Do you know if there are any workarounds I can implement to prevent the compiler from constant-folding the array?
I don't know... it's really an XLA bug, and I'm not sure of a way to change what XLA does here. Maybe you could rewrite your code so that the array is passed around explicitly rather than captured?
Thanks! Unfortunately, passing the array around isn't workable in my case. Should I file a bug report upstream? Would the TensorFlow repository be the right place?
It looks like one possible workaround for now is to use an optimization barrier:

```python
from jax._src.ad_checkpoint import _optimization_barrier

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)
    def product(other):
        nz = _optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product
```

This is still somewhat experimental, so unfortunately there is no public API for this.
That said, it's probably worth filing an XLA bug for this; it should be something that the compiler handles automatically: https://github.com/openxla/xla
Thanks, that works! Two comments:
Very much appreciate your help with this!
@aterenin, even though it seems like an XLA bug, it would be helpful to mention which hardware you are using (compilers have different backends, so code paths differ).
Sure! I have reproduced this issue on both an Nvidia GPU and an Apple M1.
I know this thread is about a year old, but I thought I would note that I've run into similar issues with JIT-compiling large sparse arrays (in my case, for moderate-sized finite element simulations, ~2M elements). Beyond slow constant folding, I've also hit the limit where XLA segfaults if the array is too large (generally somewhere above ~400,000,000 nonzero elements). In this case, `_optimization_barrier` has no effect for me, and the only solution is to provide the sparse array indices as an input to my function. I've run into this issue on both Linux and Windows. Due to the size of the arrays, I've only been able to try this on the CPU (as my GPUs don't have enough memory for problems this large).
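For reference, the "indices as an input" pattern mentioned above might look roughly like this (a sketch, not code from the thread; the small `n` and function names are assumptions). Because the indices arrive as a traced argument rather than a captured closure constant, JIT treats them as an abstract value and XLA has nothing to constant-fold:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse import BCOO

n = 1000  # illustrative; the real problems have hundreds of millions of nonzeros

@jax.jit
def product(nonzeroes, other):
    # `nonzeroes` is a traced argument here, not an embedded constant
    matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                  indices_sorted=True, unique_indices=True)
    return matrix @ other

nonzeroes = sparse.eye(n).indices  # computed once, outside the jitted function
result = product(nonzeroes, jnp.ones(n))
```

Since the matrix here is the identity, multiplying by a vector of ones returns a vector of ones.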
Description
I've got a use case where I'd like to store the nonzero entries of a very large sparse matrix, and then access them later during a machine learning training loop. Unfortunately, using JIT compilation results in constant-folding of this array, making it extremely slow on large problems. Here's an MWE that runs on my laptop and captures the typical behavior:
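(The original MWE code block did not survive extraction. Based on the modified versions quoted earlier in the thread, it presumably looked something like the following sketch, where `n` and the surrounding harness are assumptions. The indices are computed eagerly and captured by the closure, so they become a large compile-time constant:)

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse import BCOO

n = 1000  # the problematic sizes are in the millions

def build_sparse_linear_operator():
    # Computed eagerly; the closure captures this array as a constant,
    # which XLA then tries to constant-fold at compile time.
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)
    def product(other):
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

product = jax.jit(build_sparse_linear_operator())
result = product(jnp.ones(n))
```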
Calling the function without JIT executes in about a tenth of a second, but calling it with JIT takes almost a minute. On larger problems in the codebase which prompted this MWE, I have had it crash due to running out of memory after about an hour. This produces warnings similar to the following:
The problem seems to be that the stored array, `nonzeroes`, has shape `(n, 2)`, which in this case is very large, yet the JIT compiler tries to constant-fold it. This seems like a bug, unless there are good reasons why arrays with millions of elements should be constant-folded, in which case it would be very helpful to have some way of telling the compiler not to do so in this case.

What jax/jaxlib version are you using?
v0.4.4
Which accelerator(s) are you using?
N/A
Additional system info
N/A
NVIDIA GPU info
N/A