How to efficiently lookup a value from a BCOO matrix? #13628
Comments
Hi - thanks for the question. In general, the answer for how to make operations faster in JAX is to use `jax.jit`. Here are the results on a Colab CPU runtime:

```python
get_by_indexing_jit = jax.jit(get_by_indexing)
get_by_multiplication_jit = jax.jit(get_by_multiplication)
_ = get_by_indexing_jit(m, indices[0])
_ = get_by_multiplication_jit(m, indices[0])

%timeit get_by_indexing(m, indices[0]).block_until_ready()
# 89.5 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit get_by_multiplication(m, indices[0]).block_until_ready()
# 14 ms ± 505 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit get_by_indexing_jit(m, indices[0]).block_until_ready()
# 511 µs ± 17.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit get_by_multiplication_jit(m, indices[0]).block_until_ready()
# 504 µs ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

With this in mind, I'd suggest using a straightforward indexing operation to access elements of the sparse array, rather than the more complicated matmul-based approach.
Thanks! I noticed this may work for a sparse matrix with a small number of stored elements. For a larger matrix, BCOO(int32[1000000, 1000000, 1000000, 1000000, 1000000], nse=44896321), rerunning the indexing functions returns quite different results:

```python
%timeit get_by_indexing(m, indices[0]).block_until_ready()
# 22 s ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit get_by_multiplication(m, indices[0]).block_until_ready()
# 5.9 s ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit get_by_indexing_jit(m, indices[0]).block_until_ready()
# 14.9 s ± 8.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit get_by_multiplication_jit(m, indices[0]).block_until_ready()
# 1.23 s ± 772 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Is the speed of indexing inversely proportional to the number of stored elements (nse)?
Yes, indexing a sparse data structure involves searching through the index buffer for the requested index, so in general, as the size of the index buffer grows, the indexing operation will slow down.
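A small illustration of the index buffer referred to above (my own example, not from the thread): a BCOO array stores an explicit coordinate for every nonzero entry, so a lookup has to be matched against all `nse` entries.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A 2x3 matrix with three nonzeros -> three entries in the index buffer.
m = sparse.BCOO.fromdense(jnp.array([[0., 2., 0.],
                                     [3., 0., 4.]]))

print(m.indices)  # coordinates of the stored elements, shape (nse, 2)
print(m.data)     # their values, shape (nse,)
print(m.nse)      # 3 stored elements -> lookup cost scales with this
```

This is why the timings above degrade as `nse` grows into the tens of millions: the search is linear in the number of stored elements.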
I'd like to extend this so that I can process a batch of indices. This works, and allows a batch of indices to be processed:

```python
get_index_fn = partial(util.get_by_multiplication_jit, matrix=matrix)
vmap_get_index = jax.vmap(
    get_index_fn,
    in_axes=0
)
```

If I set the batch size too high it crashes due to OOM, which I solved by using `jax.lax.map`:

```python
def process_index(indices):
    return jax.lax.map(vmap_get_index, indices)
```

Now I want to expand the process to run in parallel on 8 GPUs. The solution I made works but takes forever to start. I wonder if this is the correct way to use `pmap` to divide the work across the GPUs?

```python
pmap_process_index = jax.pmap(
    process_index,
    in_axes=0
)
```

Also, for the input shape, I've made it so that the dimensions are (num_devices, number of sequences to map, vmap batch size, n_dim).
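The (num_devices, number of sequences, batch size, n_dim) layout described above can be sketched as follows. This is a self-contained toy (my own construction, using a stand-in reduction for `vmap_get_index`), runnable even on a single-device CPU host:

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 8 on the GPU setup described above

# Toy index array with the layout (num_devices, num_sequences, batch, n_dim).
indices = jnp.arange(n_dev * 3 * 4 * 2).reshape(n_dev, 3, 4, 2)

def process_index(batched):
    # lax.map over the sequence axis; the mapped fn stands in for the
    # vmapped lookup and reduces the trailing coordinate axis.
    return jax.lax.map(lambda b: b.sum(axis=-1), batched)

# pmap splits the leading axis across devices, one shard per device.
out = jax.pmap(process_index, in_axes=0)(indices)
# out has shape (n_dev, 3, 4): one result per device, sequence, and batch item.
```

Note that the leading axis of the input must equal the number of devices, or `pmap` will raise an error.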
Sorry, I'm not sure what your question is. Perhaps it would help to add a fully reproducible example of what you're doing, along with pointing out where the problem occurs.
Description
What is the best way to look up a value that is stored in a BCOO sparse matrix? I tried two methods: (1) indexing directly given a known coordinate, and (2) making another sparse matrix with a value of 1 at the coordinate of interest, multiplying it by the matrix of interest, and summing to get the final value.
Getting value through indexing
Getting value through multiplication
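The code for the two functions was not captured in this transcript. A minimal sketch of what they might look like, assuming a 2-D float BCOO matrix `m` and a coordinate pair `index` (the function names come from the thread; the bodies are my assumptions, and the multiplication variant here uses one-hot vectors and matmuls rather than a one-hot sparse matrix):

```python
import jax.numpy as jnp
from jax.experimental import sparse

def get_by_indexing(m, index):
    # (1) Direct lookup: compare the coordinate against the BCOO index
    # buffer and pull out the matching entry (zero if absent).
    match = jnp.all(m.indices == index, axis=1)
    return jnp.sum(jnp.where(match, m.data, 0))

def get_by_multiplication(m, index):
    # (2) Matmul-based lookup: e_row @ m @ e_col extracts a single entry.
    row = jnp.zeros(m.shape[0], m.dtype).at[index[0]].set(1)
    col = jnp.zeros(m.shape[1], m.dtype).at[index[1]].set(1)
    return row @ (m @ col)

# Toy usage: a 4x4 matrix with a single stored value.
dense = jnp.zeros((4, 4)).at[1, 2].set(7.0)
m = sparse.BCOO.fromdense(dense)
index = jnp.array([1, 2])
```

Both sketches return the stored value at `index` without densifying the matrix.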
Also, just to be sure, I checked that both functions return the same value.
It looks like getting the value by multiplication is faster than indexing. Shouldn't ordinary indexing be faster?
What jax/jaxlib version are you using?
jax v0.3.25, jaxlib v0.3.22
Which accelerator(s) are you using?
CPU
Additional system info
No response
NVIDIA GPU info
No response