
jax.lax.linalg.eigh on GPU and multi-core CPU doesn't parallelize appropriately #10180

Open
YouJiacheng opened this issue Apr 7, 2022 · 7 comments
Labels
bug Something isn't working

Comments

@YouJiacheng
Contributor

YouJiacheng commented Apr 7, 2022

import jax
import jax.numpy as jnp

def timer(f):
    from time import time
    f() # warmup and compile
    t = time()
    for _ in range(3):
        f()
    print((time() - t) / 3)

y = jax.random.uniform(jax.random.PRNGKey(0), (16, 1024, 1024)) / 16
s = jax.block_until_ready(y @ y.transpose(0, 2, 1) + jnp.eye(1024))

from jax.lax.linalg import eigh as jeigh
f = jax.jit(jax.vmap(jeigh))
timer(lambda: jax.block_until_ready(f(s))) # 0.90s for 16 problems

from scipy.linalg import eigh as seigh
import numpy as np
ss = np.array(s[0])
timer(lambda: seigh(ss)) # 0.21s for 1 problem

GPU: V100-PCIE 16G
CPU: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz

jax.lax.linalg.eigh takes 0.90s for 16 problems on one GPU, and 2.44s for 16 problems using all CPU cores (top reports 2340% peak CPU usage).
scipy.linalg.eigh takes 0.21s for 1 problem on a single CPU core (top reports 200% peak CPU usage).
So the GPU delivers less than 4x the throughput, and more than 11x the CPU usage delivers less than 1.4x the throughput, even though the vmapped computation should be embarrassingly parallel.
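
For reference, a hypothetical follow-up measurement (reusing the names defined in the snippet above, not a number reported here) that would separate per-problem cost from batching overhead:

g = jax.jit(jeigh)
timer(lambda: jax.block_until_ready(g(s[0])))  # time a single 1024x1024 problem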

@YouJiacheng YouJiacheng added the bug Something isn't working label Apr 7, 2022
@hawkinsp
Member

hawkinsp commented Apr 7, 2022

Ultimately JAX is at the mercy of the algorithms provided by cusolver here. For small matrices (smaller than 32x32), JAX currently uses the batched Jacobi solver that Nvidia provides. For larger matrices, JAX currently iterates over the batch elements sequentially, so you should expect no speedup from vmap.

There are a number of things one could try here.

One would be to try the batched Jacobi solver at larger sizes (the "if n <= 32:" check, and its HLO-only cousin a few lines above); see also https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-syevjbatch.
Note this code is in jaxlib, although it's in the Python part of jaxlib, so you can just locally edit your copy to play with it.

Another would be for jaxlib to solve multiple eigendecomposition problems in parallel on multiple CUDA streams. That would only be profitable if you aren't fully occupying GPU and CPU.
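
As a quick way to find the file described above for local edits, something like the following might work (a sketch; it assumes the installed jaxlib exposes the module as jaxlib.cusolver, which may differ across jaxlib versions):

import jaxlib.cusolver as cusolver
print(cusolver.__file__)  # the "if n <= 32:" threshold lives in this file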

@YouJiacheng
Contributor Author

Thanks for the speedy reply! IIUC, I can change the jaxlib Python code without building jaxlib myself, and make jaxlib use the batched Jacobi solver for large matrices as well.

@hawkinsp
Member

hawkinsp commented Apr 7, 2022

Yes, you could just edit the (installed) copy of cusolver.py to alter the threshold. Does it help?

@YouJiacheng
Contributor Author

It helps! 0.90s -> 0.55s. Thank you so much! (But it is still much slower than I expected.)
I also wonder why the CPU version of jax.lax.linalg.eigh + vmap doesn't speed up linearly compared to single-core scipy, given that it shows >11x peak CPU usage.

@hawkinsp
Member

hawkinsp commented Apr 7, 2022

You could send a PR altering the threshold, if you like, although we'd probably need to collect a wider range of timings at different sizes and batch sizes.
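
A rough sketch of the kind of timing sweep that could feed such a PR (the matrix sizes and batch sizes below are arbitrary choices, not values agreed on in this thread):

import itertools
import time
import jax
import jax.numpy as jnp
from jax.lax.linalg import eigh

def bench(batch, n):
    x = jax.random.uniform(jax.random.PRNGKey(0), (batch, n, n))
    a = jax.block_until_ready(x @ x.transpose(0, 2, 1) + jnp.eye(n))
    g = jax.jit(jax.vmap(eigh))
    jax.block_until_ready(g(a))   # warm up and compile
    t = time.time()
    jax.block_until_ready(g(a))
    return time.time() - t

for n, batch in itertools.product((64, 128, 256, 512, 1024), (1, 4, 16)):
    print(f"n={n:5d} batch={batch:3d}: {bench(batch, n):.3f}s")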

The CPU version also just calls a LAPACK function in a loop to handle batches. In fact, the LAPACK function we use is the one provided by scipy, so I'd be surprised if you saw any speedup over scipy at all. That said, the algorithm does use parallelism internally, at least for some of the phases. If we aren't getting enough parallelism, we could consider using multiple threads.

We don't have a batched eigh on CPU (as far as I am aware, no one does), although some of the algorithms that work well when vectorized on GPU and TPU might also work well on CPU, particularly for small matrix sizes, e.g., a vectorized Jacobi solver.
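
To illustrate the batch-level threading idea on CPU, here is a sketch using plain scipy in a thread pool (this is not JAX's implementation; it relies on the LAPACK call releasing the GIL, and the worker count is an arbitrary choice):

from concurrent.futures import ThreadPoolExecutor
import numpy as np
from scipy.linalg import eigh

x = np.random.rand(16, 1024, 1024)
mats = x @ x.transpose(0, 2, 1) + np.eye(1024)
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(eigh, mats))  # one (eigenvalues, eigenvectors) pair per matrix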

@YouJiacheng
Contributor Author

In fact, the LAPACK function we use is the one provided by scipy, so I'd be surprised if you saw any speedup over scipy at all. That said, the algorithm does use parallelism internally, at least for some of the phases.

JAX uses multiple cores, while scipy uses only a single core. But JAX with multiple cores gets only a modest speedup, at the cost of preventing the user from manually applying SPMD/data parallelism.
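
A possible user-side workaround (a sketch, not something proposed in this thread) is to expose several XLA host devices and shard the batch over them with pmap; the device count of 8 below is arbitrary:

import os
# Must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.lax.linalg import eigh

x = jax.random.uniform(jax.random.PRNGKey(0), (8, 2, 1024, 1024))
a = x @ x.transpose(0, 1, 3, 2) + jnp.eye(1024)  # (devices, per-device batch, n, n)
out = jax.block_until_ready(jax.pmap(jax.vmap(eigh))(a))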

@YouJiacheng
Contributor Author

YouJiacheng commented Apr 7, 2022

Can we have pytorch-like set_num_threads and set_num_interop_threads to control the parallelism?

import torch

torch.set_num_threads(1)           # intra-op parallelism: 1 thread per op
torch.set_num_interop_threads(24)  # inter-op parallelism: 24 threads for forked tasks

@torch.jit.script
def mt_eigh(x: torch.Tensor):
    # x: (24, 1024, 320, 320); each fork decomposes one (1024, 320, 320) sub-batch
    futs = [torch.jit._fork(torch.linalg.eigh, x[i]) for i in range(24)]
    return [torch.jit._wait(fut) for fut in futs]

I find that this (7.5s for 24*1024 matrices of size 320*320) is about 50x faster than JAX on a 24-core CPU (15.6s for 1024 matrices) and about 40x faster than naively letting pytorch use intra-op parallelism with 24 threads (12.4s for 1024 matrices); that 24-thread intra-op setting is itself 1.8x slower than a single thread (6.9s) and 2.3x slower than 4 threads (5.4s) on the same 1024 matrices.

@YouJiacheng YouJiacheng changed the title jax.lax.linalg.eigh on GPU seems too slow jax.lax.linalg.eigh on GPU and multi-core CPU seems too slow Apr 7, 2022
@YouJiacheng YouJiacheng changed the title jax.lax.linalg.eigh on GPU and multi-core CPU seems too slow jax.lax.linalg.eigh on GPU and multi-core CPU doesn' parallel appropriately. Apr 7, 2022
@YouJiacheng YouJiacheng changed the title jax.lax.linalg.eigh on GPU and multi-core CPU doesn' parallel appropriately. jax.lax.linalg.eigh on GPU and multi-core CPU doesn't parallelize appropriately Apr 7, 2022