jax.lax.linalg.eigh on GPU and multi-core CPU doesn't parallelize appropriately. #10180
Ultimately JAX is at the mercy of the algorithms provided by cuSOLVER here. For small matrices (smaller than 32x32), JAX currently uses the batched Jacobi solver that NVIDIA provides. For larger matrices, JAX currently iterates over the batch elements sequentially, so you should expect no speedup from `vmap`. There are a number of things one could try here. One would be to try the batched Jacobi solver at larger sizes (see line 551 at commit 2884215).
Note that this code is in jaxlib, but it's in the Python part of jaxlib, so you can just edit your local copy to experiment with it. Another option would be for jaxlib to solve multiple eigendecomposition problems in parallel on multiple CUDA streams. That would only be profitable if you aren't already fully occupying the GPU and CPU.
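To make the size-threshold behaviour concrete, here is a minimal NumPy sketch of that kind of dispatch. It is not jaxlib's code: `eigh_dispatch` and `jacobi_max_size` are hypothetical names, and a single stacked `np.linalg.eigh` call stands in for one batched Jacobi kernel launch. What it shows is the structure: below the threshold, one call covers the whole batch; above it, a loop issues one solve per matrix, one after another, which is why `vmap` adds no parallelism on that path.

```python
import numpy as np


def eigh_dispatch(a, jacobi_max_size=32):
    """Hypothetical illustration of a size-threshold dispatch for batched eigh.

    a: array of shape (batch, n, n) holding real symmetric matrices.
    """
    batch, n, _ = a.shape
    if n <= jacobi_max_size:
        # Stand-in for a single batched kernel (e.g. a batched Jacobi solver):
        # one call handles every matrix in the batch.
        return np.linalg.eigh(a)
    # Stand-in for the large-matrix path: one solver call per batch element,
    # executed sequentially.
    ws, vs = [], []
    for i in range(batch):
        w, v = np.linalg.eigh(a[i])
        ws.append(w)
        vs.append(v)
    return np.stack(ws), np.stack(vs)
```

Raising `jacobi_max_size` in this sketch corresponds to the experiment suggested above: more problem sizes take the single batched path instead of the loop.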
Thanks for the speedy reply! IIUC, I can change the jaxlib Python code without building jaxlib myself, and make jaxlib use the batched Jacobi solver for large matrices as well.
Yes, you could just edit the (installed) copy of
It helps! 0.90s -> 0.55s. Thank you so much! (But it is still much slower than I expected.)
You could send a PR altering the threshold, if you like, although we'd probably need to collect a wider range of timings at different matrix sizes and batch sizes. The CPU version also just calls a LAPACK function in a loop to handle batches. In fact, the LAPACK function we use is the one provided by SciPy, so I'd be surprised if you saw any speedup over SciPy at all. That said, the algorithm does use parallelism internally, at least for some of its phases. If we aren't getting enough parallelism, we could consider using multiple threads. We don't have a batched eigh on CPU (as far as I am aware, no one does on CPU, although some of the algorithms that work well when vectorized on GPU and TPU might also work well on CPU, particularly for small matrix sizes, e.g., a vectorized Jacobi solver).
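As an illustration of what "a vectorized Jacobi solver" could look like on CPU, here is a sketch of the classic cyclic Jacobi method in NumPy, vectorized over the batch axis: every matrix in the batch gets its (p, q) rotation at the same step, so the work per step is one set of array operations rather than a Python loop over matrices. `batched_jacobi_eigh` is a hypothetical name; this is the textbook algorithm, not the cuSOLVER kernel or anything in jaxlib, and it assumes real symmetric input.

```python
import numpy as np


def batched_jacobi_eigh(a, max_sweeps=15, tol=1e-12):
    """Cyclic Jacobi eigensolver vectorized over a leading batch axis.

    a: (batch, n, n) real symmetric matrices.
    Returns (w, v): ascending eigenvalues, shape (batch, n), and
    eigenvectors as the columns of v, shape (batch, n, n).
    """
    a = np.array(a, dtype=np.float64)  # work on a copy of the input
    batch, n, _ = a.shape
    v = np.tile(np.eye(n), (batch, 1, 1))  # accumulated rotations
    scale = np.linalg.norm(a, axis=(1, 2))  # Frobenius norm per matrix
    for _ in range(max_sweeps):
        # Stop once every matrix is numerically diagonal.
        off = np.sqrt(np.sum(np.tril(a, -1) ** 2, axis=(1, 2)))
        if np.all(off <= tol * scale):
            break
        # One sweep: visit every (p, q) pair, the same pair for all matrices.
        for p in range(n - 1):
            for q in range(p + 1, n):
                x = a[:, q, q] - a[:, p, p]
                y = 2.0 * a[:, p, q]
                # tan(2*theta) = y / x with |theta| <= pi/4; this choice
                # zeroes a[:, p, q] after the rotation below.
                sgn = np.where(x >= 0, 1.0, -1.0)
                theta = 0.5 * np.arctan2(y * sgn, np.abs(x))
                c, s = np.cos(theta), np.sin(theta)
                # One Givens rotation matrix per batch element.
                j = np.tile(np.eye(n), (batch, 1, 1))
                j[:, p, p] = c
                j[:, q, q] = c
                j[:, p, q] = s
                j[:, q, p] = -s
                a = np.swapaxes(j, 1, 2) @ a @ j  # a <- J^T a J
                v = v @ j
    w = np.diagonal(a, axis1=1, axis2=2)
    order = np.argsort(w, axis=1)
    w = np.take_along_axis(w, order, axis=1)
    v = np.take_along_axis(v, order[:, None, :], axis=2)
    return w, v
```

For clarity each rotation is applied as a dense matrix product, which costs O(n^3) per rotation; a practical version would update only rows and columns p and q. The method's appeal for this issue is that the batch dimension vectorizes naturally, which matters most for many small matrices.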
JAX uses multiple cores, while SciPy only uses a single core. But JAX on multiple cores gives only a small speedup, at the cost of preventing the user from manually applying SPMD/data parallelism.
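Can we have something PyTorch-like? For example:

```python
import torch

torch.set_num_threads(1)            # one intra-op thread per task
torch.set_num_interop_threads(24)   # inter-op pool that runs forked tasks

@torch.jit.script
def mt_eigh(x: torch.Tensor):
    # Fork one eigh task per batch element, then wait for all of them.
    futs = [torch.jit.fork(torch.linalg.eigh, x[i]) for i in range(x.size(0))]
    return [torch.jit.wait(fut) for fut in futs]
```

I find that this (7.5s for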
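For comparison, a rough NumPy analog of the same fork/wait structure can be written with the standard library's `concurrent.futures.ThreadPoolExecutor`. This is a sketch, not a JAX API: `mt_eigh_numpy` and `max_workers` are hypothetical names. Whether it actually runs faster depends on whether the underlying LAPACK call releases the GIL and on how many threads the BLAS/LAPACK library itself spawns per call, so treat it as an illustration of the pattern only.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def mt_eigh_numpy(x, max_workers=4):
    """Solve eigh for every x[i] on a thread pool, one task per matrix.

    Mirrors the fork/wait structure of the TorchScript example: submit all
    tasks first, then collect each result in submission order.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(np.linalg.eigh, x[i]) for i in range(x.shape[0])]
        return [f.result() for f in futures]
```

Limiting the per-call library threads to one while running several tasks at once is the same trade-off the TorchScript snippet makes with `set_num_threads(1)` and `set_num_interop_threads(24)`.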
GPU: V100-PCIE 16G
CPU: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
`jax.lax.linalg.eigh` on 1 GPU takes 0.90s for 16 problems; on all CPU cores (`top` reports 2340% peak CPU usage) it takes 2.44s for 16 problems.
`scipy.linalg.eigh` on 1 CPU core (`top` reports 200% peak CPU usage) takes 0.21s for 1 problem.
This means that the GPU gives less than 4x the throughput, and more than 11x the CPU usage gives less than 1.4x the throughput, even though the problem should be embarrassingly parallel given `vmap`.
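The ratios above come from scaling the single-problem SciPy time up to 16 problems. A small check of that arithmetic, using only the timings reported in this issue:

```python
scipy_per_problem = 0.21  # s, scipy.linalg.eigh, 1 problem, 1 CPU core
n_problems = 16
gpu_time = 0.90  # s, jax.lax.linalg.eigh, 16 problems, 1 GPU
cpu_time = 2.44  # s, jax.lax.linalg.eigh, 16 problems, all CPU cores

sequential = scipy_per_problem * n_problems  # 3.36 s for 16 problems
gpu_speedup = sequential / gpu_time  # about 3.73x
cpu_speedup = sequential / cpu_time  # about 1.38x
cpu_usage_ratio = 2340 / 200  # 11.7x the CPU usage reported by top

print(f"GPU: {gpu_speedup:.2f}x, CPU: {cpu_speedup:.2f}x "
      f"for {cpu_usage_ratio:.1f}x the CPU usage")
# prints: GPU: 3.73x, CPU: 1.38x for 11.7x the CPU usage
```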