Running on multiple A100 gets stuck #8475
tl;dr: I suspect you are not setting …
I took a clean GCP GPU VM and installed the latest jax.
I verified I have 8 GPUs.
It worked.
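A minimal way to do that verification, using the standard `jax.devices()` and `jax.device_count()` APIs:

```python
import jax

# List the accelerators JAX has discovered and count them;
# on this VM the count should be 8.
print(jax.devices())
print(jax.device_count())
```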
Hi @zhangqiaorjc,
Can you please (a) use the current jax and jaxlib (0.1.74) and (b) provide self-contained instructions to reproduce? What precisely did you run?
Hi! (a) I will try this one and get back to you!
Thanks!
@Hatmm are you able to get (a) to work, i.e. using the latest jax and jaxlib?
@zhangqiaorjc (a) did not solve the problem. I tried running on 2 GPUs (devices 0 and 2); as you can see, the program looks idle. I am using cuDNN 8.4.2 and CUDA 11.4 with jax 0.2.25.
After installing jaxlib 0.1.74 (https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.74+cuda11.cudnn82-cp38-none-manylinux2010_x86_64.whl), I get an error instead of idling.
Is this issue resolved? I have found the same problem using V100s with cuda=11.4, jax=0.3.1, jaxlib=0.3.0+cuda11.cudnn82, nvidia driver 470.94. I find the program hangs (with 100% utilization) when using inter-device communication operations (e.g. psum). When I avoid communication, the output is as expected, but only if a regular numpy array is used as input.

```python
from functools import partial
from jax import lax, pmap, vmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def normalize(x):
    return x

print(normalize(jnp.array([10., 100.])))
```

Using a jnp.array returned [10, 0] rather than [10, 100], as is returned by vmap or by using a regular numpy array. Using a regular numpy array does not solve the hanging problem.
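For reference, the kind of collective that triggers the hang can be reproduced with a `lax.psum` variant of the same function (a sketch, not the exact code that was run):

```python
import numpy as np
from functools import partial
from jax import lax, pmap

# A pmapped function that performs an inter-device collective (psum).
# On the affected machines, this is the kind of call that hangs.
@partial(pmap, axis_name='i')
def normalize(x):
    return x / lax.psum(x, 'i')

# One element per device; needs at least 2 visible devices.
print(normalize(np.array([10., 100.])))  # healthy output: [0.0909... 0.9090...]
```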
@tlitfin I can't reproduce your problem. For example, I just created a GCP VM with 4xV100, with CUDA 11.4, driver 470.103.01, and the same jax and jaxlib versions.
The program works as expected for me and prints [10. 100.].
I'm wondering if something isn't working with device-to-device copies on your machine. Are you able to reproduce this problem on a cloud VM (i.e., a setup I could replicate exactly)? What does …
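A minimal way to exercise a device-to-device copy in isolation, using `jax.device_put` with an explicit device (assuming at least two visible GPUs):

```python
import jax
import jax.numpy as jnp

devs = jax.devices()

# Place an array on the first GPU, then copy it to the second.
x = jax.device_put(jnp.arange(4.0), devs[0])
y = jax.device_put(x, devs[1])
y.block_until_ready()  # a broken D2D copy path would stall here
print(y)
```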
@hawkinsp I am so sorry! I started a new session this morning and found that the problem was resolved. I was already using a fresh conda environment, but I must have had an environment variable set from previous debugging that made the issue persist. I am unable to recreate the faulty environment today to isolate the cause, but I suspect it was a conflict between the cuda/jax versions in my environment and the system-wide install. I am sorry again for wasting your time.
Woohoo, sounds fixed! 🎉
As a follow-up, I found that my problem returned and was not a simple environment conflict as I had suspected. The problem seems to depend on communication between specific GPUs allocated by our scheduling software. I attached a screenshot to illustrate. This may be a hardware configuration problem rather than a jax issue, but I am posting the info here for completeness.
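One way to narrow down which links are affected is a brute-force sweep over device pairs (a sketch, assuming the copy itself is what stalls; running it under e.g. `timeout 60 python pairs.py` keeps a hung pair from blocking the sweep, and `nvidia-smi topo -m` shows the interconnect between any pair it flags):

```python
import itertools
import jax
import jax.numpy as jnp

x = jnp.arange(1024.0)

# Copy a buffer across every ordered pair of devices; the pair printed
# last before a hang identifies the suspect link.
for a, b in itertools.permutations(jax.devices(), 2):
    print(f"copy {a} -> {b}", flush=True)
    y = jax.device_put(jax.device_put(x, a), b)
    y.block_until_ready()
print("all pairs OK")
```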
Running the following code on one A100 GPU card works fine. However, when switching to more than one, GPU utilization goes to 100% but power consumption stays as if the GPUs were idling.
I should add that the program can no longer be killed manually.