
How to efficiently lookup a value from a BCOO matrix? #13628

Closed
lintangsutawika opened this issue Dec 13, 2022 · 5 comments
Assignees
Labels
needs info More information is required to diagnose & prioritize the issue.

Comments


lintangsutawika commented Dec 13, 2022

Description

What is the best way to look up a value stored in a BCOO sparse matrix? I tried two methods: (1) indexing directly with a known coordinate, and (2) building another sparse matrix with a value of 1 at the coordinate of interest, multiplying it elementwise with the matrix of interest, and summing to get the final value.

import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n_dim = 5
n_array = 1_000_000
num_element = 10
data = jax.random.randint(jax.random.PRNGKey(0), (num_element,), 0, n_array)
indices = jax.random.randint(jax.random.PRNGKey(0), (num_element, n_dim), 0, n_array)

m = BCOO((data, indices), shape=tuple([n_array]*n_dim))

def get_by_indexing(m, indices):
    i0, i1, i2, i3, i4 = indices
    return m[i0, i1, i2, i3, i4].todense()

def get_by_multiplication(m, indices):
    identity = jnp.array([1])
    x = BCOO((identity, jnp.expand_dims(indices, 0)), shape=m.shape)
    return jax.experimental.sparse.bcoo_multiply_sparse(m, x).sum()

Getting value through indexing

In   [2]: %timeit get_by_indexing(m, indices[0])
55.3 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Getting value through multiplication

In   [3]: %timeit get_by_multiplication(m, indices[0])
8 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Also, just to confirm that both functions return the same value:

In   [4]: get_by_indexing(m, indices[0]) == get_by_multiplication(m, indices[0])
Out[4]: DeviceArray(True, dtype=bool)

It looks like getting the value by multiplication is faster than indexing. Shouldn't ordinary indexing be faster?

What jax/jaxlib version are you using?

jax v0.3.25, jaxlib v0.3.22

Which accelerator(s) are you using?

CPU

Additional system info

No response

NVIDIA GPU info

No response

@lintangsutawika lintangsutawika added the bug Something isn't working label Dec 13, 2022

jakevdp commented Dec 13, 2022

Hi - thanks for the question. In general, the answer for how to make operations faster in JAX is to use jit compilation. This is especially true for sparse operations, which are generally implemented not in terms of a single efficient XLA op, but rather a sequence of XLA operations on the underlying dense buffers. If you jit-compile your two functions, you'll see that the timings are much faster, and that the two methods are comparable.

Here are the results on a Colab CPU runtime:

get_by_indexing_jit = jax.jit(get_by_indexing)
get_by_multiplication_jit = jax.jit(get_by_multiplication)

_ = get_by_indexing_jit(m, indices[0])
_ = get_by_multiplication_jit(m, indices[0])

%timeit get_by_indexing(m, indices[0]).block_until_ready()
# 89.5 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit get_by_multiplication(m, indices[0]).block_until_ready()
# 14 ms ± 505 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit get_by_indexing_jit(m, indices[0]).block_until_ready()
# 511 µs ± 17.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit get_by_multiplication_jit(m, indices[0]).block_until_ready()
# 504 µs ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

With this in mind, I'd suggest using a straightforward indexing operation to access elements of the sparse array, rather than using a more complicated matmul-based approach.

@jakevdp jakevdp self-assigned this Dec 13, 2022
@jakevdp jakevdp added question Questions for the JAX team and removed bug Something isn't working labels Dec 13, 2022
lintangsutawika (Author) commented

Thanks!

I noticed this may work for sparse matrices with a small nse, but I intend to use a large sparse matrix with a large nse:

BCOO(int32[1000000, 1000000, 1000000, 1000000, 1000000], nse=44896321)

Rerunning the indexing functions returns quite different results:

%timeit get_by_indexing(m, indices[0]).block_until_ready()
# 22 s ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit get_by_multiplication(m, indices[0]).block_until_ready()
# 5.9 s ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit get_by_indexing_jit(m, indices[0]).block_until_ready()
# 14.9 s ± 8.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit get_by_multiplication_jit(m, indices[0]).block_until_ready()
# 1.23 s ± 772 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Is the speed of indexing inversely proportional to the nse of the sparse matrix?


jakevdp commented Dec 14, 2022

Yes, indexing a sparse data structure involves searching through the index buffer for the requested index, so in general as the size of the index buffer grows, the indexing operation will slow down.

JAX's BCOO objects are particularly unsuited for large nse compared to other implementations, because XLA does not offer any efficient binary search primitive, and so these index searches scale linearly with the size of the buffers (i.e. nse) whereas a binary-search-based approach could in theory scale as log(nse).
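To illustrate the point: a minimal sketch (using plain `jnp` operations, not JAX's actual internal implementation) of how a COO-style lookup reduces to a linear scan over all nse index rows:

```python
import jax.numpy as jnp

def coo_lookup(data, indices, coord):
    # Compare the query coordinate against every stored index row:
    # an O(nse) scan, since XLA offers no binary-search primitive.
    hits = jnp.all(indices == jnp.asarray(coord), axis=1)
    # Sum the (at most one) matching entry; zero if the coordinate is absent.
    return jnp.sum(jnp.where(hits, data, 0))

data = jnp.array([3, 7, 5])
indices = jnp.array([[0, 1], [2, 2], [4, 0]])
print(coo_lookup(data, indices, (2, 2)))  # 7
print(coo_lookup(data, indices, (1, 1)))  # 0 (coordinate not stored)
```

The equality comparison touches every stored index, which is why the lookup cost grows linearly with nse.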


lintangsutawika commented Dec 21, 2022

I'd like to extend this so that I can process a batch of indices.

This works, and allows a batch of indices to be processed:

from functools import partial

get_index_fn = partial(util.get_by_multiplication_jit, matrix=matrix)
vmap_get_index = jax.vmap(get_index_fn, in_axes=0)

If I set the batch size too high it crashes due to OOM, which I solved by using jax.lax.map:

def process_index(indices):
    return jax.lax.map(vmap_get_index, indices)

Now I want to expand the process to run in parallel on 8 GPUs.

The solution I made works but takes forever to start. I wonder if this is the correct way to use pmap to divide the work across the GPUs?

pmap_process_index = jax.pmap(process_index, in_axes=0)

Also, for the input shape, I've made it so that the dimensions are (num_devices, number of sequences to map, vmap batch size, n_dim).
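For reference, the nesting described above can be sketched with a hypothetical stand-in for the per-index lookup (`f` below is a placeholder, not the multiplication-based lookup from earlier), just to check that the (num_devices, num_chunks, batch_size, n_dim) shape convention composes:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the per-coordinate lookup function.
def f(idx):
    return jnp.sum(idx)

vmap_f = jax.vmap(f)            # vectorize over the batch axis
def process_chunks(chunks):     # iterate sequentially over chunks to bound memory
    return jax.lax.map(vmap_f, chunks)

num_devices = jax.local_device_count()
num_chunks, batch_size, n_dim = 3, 4, 5
x = jnp.ones((num_devices, num_chunks, batch_size, n_dim), dtype=jnp.int32)

# pmap splits the leading axis across devices; each device runs its own
# lax.map over chunks, and each chunk is processed in a vmapped batch.
out = jax.pmap(process_chunks)(x)
print(out.shape)  # (num_devices, num_chunks, batch_size)
```

With one chunk stream per device, each output element here is the sum over n_dim ones, i.e. 5.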


jakevdp commented Dec 27, 2022

Sorry, I'm not sure what your question is. Perhaps it would help to add a fully reproducible example of what you're doing, and point out where the problem occurs.

@jakevdp jakevdp added needs info More information is required to diagnose & prioritize the issue. and removed question Questions for the JAX team labels Dec 27, 2022