
Running on multiple A100 gets stuck #8475

Closed · Hatmm opened this issue Nov 6, 2021 · 13 comments

Labels: bug (Something isn't working)

@Hatmm commented Nov 6, 2021

Running the following code on one A100 GPU card works fine. However, when switching to more than one GPU, utilization goes to 100% but power consumption stays as if the cards were idling.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.devices  # e.g. '0,1'; FLAGS comes from the surrounding script
from functools import partial
from jax import lax, pmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def normalize(x):
    return x / lax.psum(x, 'i')

print(normalize(jnp.arange(2.)))


I should add that the program can no longer be killed manually.
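For reference, pmap maps one slice of the input's leading axis to each visible device, so the leading dimension has to match jax.local_device_count(). A minimal sanity check that exercises pmap without any cross-device communication looks like this (a sketch; '0,1' is an illustrative value standing in for FLAGS.devices):

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # illustrative value standing in for FLAGS.devices
import jax
import jax.numpy as jnp

print(jax.devices())          # should list one GpuDevice per visible GPU
n = jax.local_device_count()  # pmap requires the input's leading axis to equal this
print(jax.pmap(lambda x: x)(jnp.arange(float(n))))  # identity pmap, no cross-device communication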

@Hatmm added the bug label Nov 6, 2021
@zhangqiaorjc (Collaborator) commented:

TL;DR: I suspect you are not setting CUDA_VISIBLE_DEVICES correctly. See my successful attempt below.

I took a clean GCP GPU VM and installed the latest jax

pip install --upgrade pip

# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

I verified that I have 8 GPUs:

>>> import jax
>>> jax.devices()
[GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0), GpuDevice(id=4, process_index=0), GpuDevice(id=5, process_index=0), GpuDevice(id=6, process_index=0), GpuDevice(id=7, process_index=0)]
zhangqiaorjc@skyewm-gpu-vm2:~$ cat issue_8475.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
from functools import partial
from jax import lax, pmap
import jax.numpy as jnp
@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))

It worked

zhangqiaorjc@skyewm-gpu-vm2:~$ python3 issue_8475.py 
[0.         0.16666667 0.33333334 0.5       ]

@zhangqiaorjc self-assigned this Nov 17, 2021
@Hatmm (Author) commented Nov 20, 2021

Hi @zhangqiaorjc,
I have tried your code in my environment (jaxlib 0.1.70+cuda110) and I still hit the same bug.
Please find attached a screenshot where I used GPUs 0, 4, 5, and 6. As you can see, GPU utilization is maximal while power consumption is low. The program gets stuck and cannot be killed manually.

[screenshot]

@hawkinsp (Collaborator) commented:

Can you please (a) use the current jax and jaxlib (0.1.74) and (b) provide self-contained instructions to reproduce? What precisely did you run?

@Hatmm (Author) commented Nov 21, 2021

Hi!

(a) I will try that and get back to you!
(b) This is the code I ran, by calling python3 issue_8475.py:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
from functools import partial
from jax import lax, pmap
import jax.numpy as jnp
@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))

Thanks!

@zhangqiaorjc (Collaborator) commented:

@Hatmm, were you able to get (a) to work, i.e. using the latest jax and jaxlib?

@Hatmm (Author) commented Nov 30, 2021

@zhangqiaorjc (a) did not solve the problem. I have tried running on 2 GPUs (devices 0 and 2); as you can see, the program appears to idle. I am using cudnn 8.4.2 and cuda 11.4 with jax 0.2.25.
[screenshots]

@Hatmm (Author) commented Nov 30, 2021

After installing jaxlib 0.1.74 (https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.74+cuda11.cudnn82-cp38-none-manylinux2010_x86_64.whl), I get an error instead of the hang:

RuntimeError: INTERNAL: CudnnLegacyConvRunner cached across multiple StreamExecutors.: while running replica 1 and partition 0 of a replicated computation (other replicas may have failed as well).

I believe this issue is related to https://github.com/google/jax/issues/8654.

@tomhennigan (Collaborator) commented:

The CudnnLegacyConvRunner issue was resolved in XLA and will be available in the next release (see #8654 (comment)).

@tlitfin commented Mar 7, 2022

Is this issue resolved? I have found the same problem using V100s with cuda=11.4, jax=0.3.1, jaxlib=0.3.0+cuda11.cudnn82, and NVIDIA driver 470.94.

I find that the program hangs (at 100% utilization) when using inter-device communication operations (e.g. psum). When I avoid communication, the output is as expected, but only if a regular numpy array is used as input.

from functools import partial
from jax import lax, pmap, vmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def normalize(x):
  return x

print(normalize(jnp.array([10.,100.])))

Using a jnp.array returned [10, 0] rather than [10, 100], which is what vmap (or a regular numpy array input) returns.

Using a regular numpy array does not solve the hanging problem, however.
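Concretely, the comparison above looks roughly like this on a 2-GPU setup (a sketch restating the behaviour just described; the identity function name is illustrative):

import numpy as np
from functools import partial
from jax import pmap, vmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def identity(x):
    return x  # no cross-device communication

print(identity(np.array([10., 100.])))             # numpy input: [ 10. 100.] as expected
print(identity(jnp.array([10., 100.])))            # jnp input: [10. 0.] on the affected setup
print(vmap(lambda x: x)(jnp.array([10., 100.])))   # vmap baseline: [ 10. 100.]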

@hawkinsp (Collaborator) commented Mar 7, 2022

@tlitfin I can't reproduce your problem. For example, I just created a GCP VM with 4xV100, CUDA 11.4, driver 470.103.01, and the same jax and jaxlib versions.

$ nvidia-smi
Mon Mar  7 13:23:10 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    32W / 300W |     80MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   37C    P0    43W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:06.0 Off |                    0 |
| N/A   37C    P0    33W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:07.0 Off |                    0 |
| N/A   36C    P0    42W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

The program works as expected for me and prints:

$ python t.py
[ 10. 100.]

I'm wondering if something isn't working with device->device copies on your machine.

Are you able to reproduce this problem on a cloud VM (i.e., a setup I could replicate exactly)?

What does nvidia-smi topo -m print?
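One quick way to exercise device-to-device copies directly, independent of pmap, is a small jax.device_put loop (a minimal sketch; the device indices are whatever jax.devices() reports):

import jax
import jax.numpy as jnp

devices = jax.devices()
x = jax.device_put(jnp.arange(4.), devices[0])  # place the array on the first GPU
for d in devices[1:]:
    y = jax.device_put(x, d)                    # copy device 0 -> device d
    print(d, y)                                 # a hang or wrong values here would implicate D2D copies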

@tlitfin commented Mar 8, 2022

@hawkinsp I am so sorry! I started a new session this morning and found that the problem was resolved. I was already using a fresh conda environment, but I must have had an environment variable set from previous debugging that made the issue persist. I am unable to re-create the faulty environment today to isolate the cause, but I suspect it was a conflict between the cuda/jax versions in my environment and the system-wide install. Sorry again for wasting your time.

@mattjj (Collaborator) commented Mar 8, 2022

Woohoo, sounds fixed! 🎉

@mattjj closed this as completed Mar 8, 2022
@tlitfin commented Mar 15, 2022

As a follow-up, I found that my problem returned and was not a simple environment conflict as I had suspected. The problem seems to depend on communication between specific GPUs allocated by our scheduling software. I attached a screenshot to illustrate.

[screenshot]

This may be a hardware configuration problem rather than a jax issue, but I am posting the info here for completeness.
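One way to narrow down which pairs are affected (a sketch, assuming 8 physical GPUs; the embedded script is a two-device variant of the psum repro from earlier in this thread) is to run the repro for every GPU pair with a timeout and record which pairs hang:

import itertools
import os
import subprocess

# Two-device version of the normalize/psum repro from this thread.
REPRO = (
    "from functools import partial\n"
    "from jax import lax, pmap\n"
    "import jax.numpy as jnp\n"
    "f = pmap(lambda x: x / lax.psum(x, 'i'), axis_name='i')\n"
    "print(f(jnp.arange(2.)))\n"
)

for a, b in itertools.combinations(range(8), 2):
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=f"{a},{b}")
    try:
        proc = subprocess.run(["python3", "-c", REPRO], env=env, timeout=120)
        print(a, b, "ok" if proc.returncode == 0 else f"exit code {proc.returncode}")
    except subprocess.TimeoutExpired:
        print(a, b, "HANG (stuck process may need manual cleanup)")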
