Very slow JIT compilation due to constant-folding of very large BCOO sparse matrix nonzero entry array #14655

Open · aterenin opened this issue Feb 24, 2023 · 10 comments
Labels: bug (Something isn't working)

aterenin commented Feb 24, 2023

Description

I've got a use case where I'd like to store the nonzero entries of a very large sparse matrix, and then access them later during a machine learning training loop. Unfortunately, using JIT compilation results in constant-folding of this array, making it extremely slow on large problems. Here's an MWE that runs on my laptop and captures the typical behavior:

import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
from jax.experimental.sparse import BCOO

n = 10000000

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices # shape (n,2)
    def product(other):
        matrix = BCOO((jnp.ones(n),nonzeroes), shape=(n,n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

operator = build_sparse_linear_operator()

def fn(x):
    return operator(jnp.ones(n) / x).sum()

fn(1.0) # executes in 0.1s
jax.jit(fn)(1.0) # executes in almost one minute

Calling the function without JIT executes in about a tenth of a second, but calling it with JIT takes almost a minute. On larger problems in the codebase which prompted this MWE, I have had it crash due to running out of memory after about an hour. This produces warnings similar to the following:

Constant folding an instruction is taking > 8s:

  slice.22 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

The problem seems to be that the stored array, nonzeroes, has shape (n, 2), which in this case is very large, yet the JIT compiler tries to constant-fold it anyway. This seems like a bug, unless there are good reasons why arrays with millions of elements should be constant-folded, in which case it would be very helpful to have some way of telling the compiler not to do so here.

What jax/jaxlib version are you using?

v0.4.4

Which accelerator(s) are you using?

N/A

Additional system info

N/A

NVIDIA GPU info

N/A

aterenin added the bug label on Feb 24, 2023
jakevdp (Collaborator) commented Feb 24, 2023

It seems like the compiler isn't making a great choice here with respect to the run-time/compile-time tradeoffs involved in constant folding. If you change your code slightly though, the problematic array will be computed at runtime:

def build_sparse_linear_operator():
    def product(other):
        nonzeroes = sparse.eye(n).indices # shape (n,2)
        matrix = BCOO((jnp.ones(n),nonzeroes), shape=(n,n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

aterenin (Author) commented

Thanks! Unfortunately, while this would avoid the problem here, I can't fix things upstream in that way: the array nonzeroes in my codebase is computed by what is effectively a black-box, CPU-only algorithm outside of JAX. Perhaps a better way to write the MWE would have been to use scipy.sparse.eye or similar instead, along the lines of the sketch below.
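
Something like this, with scipy.sparse.eye standing in for the external computation (a sketch only, not my actual code):

import numpy as np
import scipy.sparse

import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n = 10000000

def build_sparse_linear_operator():
    # Stand-in for the external, CPU-only computation: the indices arrive
    # as a plain numpy array from outside JAX.
    eye = scipy.sparse.eye(n, format="coo")
    nonzeroes = jnp.asarray(np.stack([eye.row, eye.col], axis=1))  # shape (n, 2)
    def product(other):
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product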

Do you know if there are any workarounds I can implement to prevent the compiler from constant-folding the array?

jakevdp (Collaborator) commented Feb 24, 2023

I don't know... it's really an XLA bug, and I'm not sure of a way to change what XLA does here. Maybe you could rewrite your code so that nonzeros is explicitly passed to the jit-compiled outer function? I know that's probably not the answer you're looking for, but I think it would work...
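
For example, roughly something like this, reusing the names from your MWE (an untested sketch):

nonzeroes = sparse.eye(n).indices  # computed once, up front, outside of jit

def fn(x, nonzeroes):
    matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n), indices_sorted=True, unique_indices=True)
    return (matrix @ (jnp.ones(n) / x)).sum()

jax.jit(fn)(1.0, nonzeroes)  # nonzeroes is now a traced argument rather than a baked-in constant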

aterenin (Author) commented Mar 2, 2023

Thanks! Unfortunately, passing nonzeroes around won't work: the array is computed internally by the package and shouldn't be exposed to the user, while the functions it would need to be passed into are called directly by the user.

Should I file a bug report upstream? Would the TensorFlow repository be the right place?

jakevdp (Collaborator) commented Mar 2, 2023

It looks like one possible workaround for now is to use an optimization barrier:

from jax._src.ad_checkpoint import _optimization_barrier

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices # shape (n,2)
    def product(other):
        nz = _optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n,n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

This is still somewhat experimental, so unfortunately there is no public API for this.

jakevdp (Collaborator) commented Mar 2, 2023

That said, it's probably worth filing an XLA bug for this. It should be something that the compiler handles automatically: https://github.com/openxla/xla

aterenin (Author) commented Mar 3, 2023

Thanks, that works! Two comments:

  1. The function _optimization_barrier must be called inside product, not outside of it; for instance, nonzeroes = _optimization_barrier(sparse.eye(n).indices) will not work (see the sketch below).
  2. This XLA bug might be specific to integer arrays: in my upstream codebase I have other precomputed arrays as well, yet wrapping only the BCOO nonzero index arrays in _optimization_barrier is enough to keep JIT from freezing.
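
For concreteness, the two placements from point 1 side by side (the comments are my reading of why one fails):

# Placement that does NOT help: the barrier runs eagerly here, and the resulting
# concrete array is still baked into the jitted function as a constant.
nonzeroes = _optimization_barrier(sparse.eye(n).indices)

# Placement that works: the barrier sits inside product, so it is traced and
# compiled along with the rest of the computation.
def product(other):
    nz = _optimization_barrier(nonzeroes)
    matrix = BCOO((jnp.ones(n), nz), shape=(n, n), indices_sorted=True, unique_indices=True)
    return matrix @ other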

Very much appreciate your help with this!

mjsML (Collaborator) commented Mar 24, 2023

@aterenin, even though it seems like an XLA bug, it would be helpful to mention which hardware you are using (compilers have different backends, so code paths are different).

aterenin (Author) commented

Sure! I have reproduced this issue on both an NVIDIA GPU and an Apple M1.

watkinrt commented

I know this thread is about a year old, but I thought I would note that I've run into similar issues when JIT-compiling large sparse arrays (in my case, for moderately sized finite element simulations, ~2M elements). Beyond slow constant folding, I've also hit a limit where XLA segfaults if the array is too large (generally somewhere above 400,000,000 non-zero elements). In this case, _optimization_barrier has no effect for me, and the only solution is to provide the sparse array indices as an input to my function, roughly as sketched below. I've run into this issue on both Linux and Windows. Due to the size of the arrays, I've only been able to try this on the CPU (my GPUs don't have enough memory for problems this large).
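
Roughly, the pattern that works for me looks like this (a simplified sketch with made-up names, not my actual code):

import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n_dofs = 2_000_000  # placeholder problem size

def apply_operator(rhs, indices, data):
    # indices and data arrive as function arguments, so XLA treats them as
    # runtime inputs rather than compile-time constants and never tries to fold them.
    matrix = BCOO((data, indices), shape=(n_dofs, n_dofs))
    return matrix @ rhs

apply_operator_jit = jax.jit(apply_operator)
# result = apply_operator_jit(rhs, indices, data)  # indices: (nnz, 2) ints, data: (nnz,) floats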
