Memory complexity issue with pmap #8585
I guess the question is, is XLA's …
I guess it's not. Is there any suggestion on how I can avoid memory problems while computing …
Well, again, I feel like this is an issue, but not the issue that I mentioned in the first post. This is a problem with huge matrix multiplication in JAX now.
Well, the work-around is this (at least this is what comes to my mind):

Z1 = lax.map(lambda X_i: np.einsum('j,jk->k', X_i, Y), X)
Z2 = X @ Y
np.alltrue(Z1 == Z2)
# True

This works fine, but shouldn't this be automated if …

Edit: Just realized that even this wouldn't work! JAX will autocompile this to …
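As an aside, a minimal sketch (not from this thread) of making the chunking explicit by mapping over row blocks instead of single rows; chunked_matmul and num_chunks are made-up names, and the number of rows is assumed to be divisible by num_chunks:

import jax.numpy as jnp
from jax import lax

def chunked_matmul(X, Y, num_chunks=8):
    # Split the rows of X into contiguous blocks and multiply block by block.
    # lax.map runs the body sequentially, so only one block-sized product is
    # live at a time inside the loop (subject to whatever fusion XLA decides to do).
    n, _ = X.shape
    X_blocks = X.reshape(num_chunks, n // num_chunks, -1)
    Z_blocks = lax.map(lambda X_block: X_block @ Y, X_blocks)
    return Z_blocks.reshape(n, Y.shape[1])

# chunked_matmul(X, Y) should match X @ Y up to floating-point differences.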
Moreover, I noticed that calling …
Suggestion: …
It's impossible for us to debug your problem without a complete, self-contained Python program that reproduces it. I don't know what is happening in-place and what is happening out-of-place without debugging it, and I can't do that without a way to run the code. I note that JAX does have a batched matrix multiplication operator, and …
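For reference, a minimal sketch of the batched matrix multiplication spellings available in JAX; the shapes and arrays below are purely illustrative assumptions:

import jax.numpy as jnp
from jax import lax, random

# A stack of 4 independent (3, 5) @ (5, 2) products.
key = random.PRNGKey(0)
x = random.normal(key, (4, 3, 5))
y = random.normal(key, (4, 5, 2))

z1 = lax.batch_matmul(x, y)            # explicit batched matmul over the leading axis
z2 = jnp.einsum('bij,bjk->bik', x, y)  # equivalent einsum spelling
z3 = jnp.matmul(x, y)                  # jnp.matmul also broadcasts leading batch dimensions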
@mohamad-amin, did you ever find a solution to this? We're hitting a similar OOM issue. @hawkinsp, I have a colab I can share: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj
@RylanSchaeffer Not yet; I solved my problem in another way, though. I was also using …
Hey!

I'm trying to compute the result of multiple kernel ridge regressions in parallel. I've written the code and created jax expressions of my functions using jax.make_jaxpr. According to the jax expressions, the data and computation should fit into my GPUs (I'm using 4 V100 GPUs with 16GB of RAM each, which amounts to 64GB of GPU RAM) and should be very far from the actual limits of what I have, but surprisingly, it throws an OOM. (I'm using 64-bit precision.)

Basically, what I expect from the jax expressions is that the most expensive items here (memory-wise) should be the 4000 x 2000 x 10 x 10 tensor along with the 20000 x 20000 matrix that are broadcast on each GPU, which amounts to ~9GB of GPU RAM; other than that, I can't see why this code can't fit in the GPU. (P.S.: before entering the pmap, the GPU is in the state shown in the picture below.)

Error:
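As a quick back-of-the-envelope check of the ~9GB figure quoted above (assuming float64, i.e. 8 bytes per element):

bytes_4d   = 4000 * 2000 * 10 * 10 * 8  # 6.4e9 bytes, about 6.4 GB
bytes_gram = 20000 * 20000 * 8          # 3.2e9 bytes, about 3.2 GB
total      = bytes_4d + bytes_gram      # about 9.6 GB per device

so those two buffers alone account for roughly the quoted ~9GB; any temporaries XLA materializes during the computation would add to that per-device footprint.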
My compiled functions:
In the compiled function above, there is an xla_call that calls this compiled function: