Failed to initialize CUDA backend on multi-process distributed environments #12533

Closed
srossi93 opened this issue Sep 27, 2022 · 5 comments
Labels
bug Something isn't working

Comments

@srossi93

Description

The failure occurs when initializing a multi-process, multi-GPU environment with Slurm (though I think the problem is unrelated to Slurm itself).

Take the following simple script, saved as main.py:

import jax
import logging

logging.getLogger().setLevel(logging.DEBUG)

jax.distributed.initialize()

if jax.process_index() == 0:
  print(jax.devices())
  print(jax.device_count())    # total number of accelerator devices in the cluster
  print(jax.local_device_count())    # number of accelerator devices attached to this host

When executed with srun --gres=gpu:2 --ntasks=2 --nodes=1 python main.py, it returns:

INFO:absl:JAX distributed initialized with visible devices: 0
INFO:absl:JAX distributed initialized with visible devices: 1
INFO:absl:Starting JAX distributed service on ainode17:4192
INFO:absl:Connecting to JAX distributed service on ainode17:4192
INFO:absl:Connecting to JAX distributed service on ainode17:4192
DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
DEBUG:absl:Initializing backend 'cuda'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
DEBUG:absl:Initializing backend 'cuda'
2022-09-27 19:23:48.425044: E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:57] INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:345) local_device->device_ordinal() == local_topology.devices_size() 
*** Begin stack trace ***
	
	
	
	
	
	
	
	PyCFunction_Call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyObject_FastCallDict
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyObject_Call
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	PyEval_EvalCode
	
	
	
	PyRun_SimpleFileExFlags
	Py_RunMain
	Py_BytesMain
	__libc_start_main
	_start
*** End stack trace ***

INFO:absl:Unable to initialize backend 'cuda': INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:345) local_device->device_ordinal() == local_topology.devices_size() 
DEBUG:absl:Initializing backend 'rocm'
INFO:absl:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter Host
DEBUG:absl:Initializing backend 'tpu'
INFO:absl:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
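
Because the run above only warns and then silently continues on CPU, a job script may want to stop early instead. The following is a minimal sketch of such a check, not part of JAX: `require_accelerator` is a hypothetical helper, and the set of accepted platform strings is an assumption (the string JAX reports for CUDA devices may differ between versions). It takes any iterable of objects with a `.platform` attribute, such as the list returned by `jax.devices()`.

```python
from types import SimpleNamespace


def require_accelerator(devices, allowed_platforms=("gpu", "cuda", "rocm", "tpu")):
    """Raise if none of `devices` belongs to an accelerator platform.

    `devices` is any iterable of objects exposing a `.platform` string,
    e.g. the result of `jax.devices()`. The default `allowed_platforms`
    tuple is an assumption and may need adjusting for a given JAX version.
    """
    platforms = sorted({d.platform.lower() for d in devices})
    if not any(p in allowed_platforms for p in platforms):
        raise RuntimeError(
            f"No accelerator backend initialized; found platforms: {platforms}"
        )
    return platforms


# Stand-in device objects so the sketch runs without JAX installed.
cpu_only = [SimpleNamespace(platform="cpu")]
two_gpus = [SimpleNamespace(platform="gpu"), SimpleNamespace(platform="gpu")]
```

In the script above, calling `require_accelerator(jax.devices())` right after `jax.distributed.initialize()` would turn the CPU fallback into a hard error.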

The cluster interface (Slurm and TPU pods) was updated recently, in 0.3.18, but that does not appear to be the cause: manually passing coordinator_address, num_processes and process_id to jax.distributed.initialize(...) produces the same failure.
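
For readers unfamiliar with the manual form, here is a rough sketch of how those three arguments could be derived from Slurm's environment. The helper names `slurm_init_kwargs` and `initialize_from_slurm` are hypothetical, the default port is simply the one that appears in the log above, and the node-list handling only covers the single-node case from this report. As noted, passing these values explicitly does not avoid the failure on jaxlib 0.3.15; the sketch only illustrates the call.

```python
import os


def slurm_init_kwargs(env=None, coordinator_host=None, port=4192):
    """Build keyword arguments for jax.distributed.initialize from Slurm variables.

    Assumptions: SLURM_PROCID, SLURM_NTASKS and SLURM_JOB_NODELIST are set by
    srun, and the job runs on a single node so the node list is a plain
    hostname. For multi-node jobs, pass `coordinator_host` explicitly.
    """
    env = os.environ if env is None else env
    host = coordinator_host
    if host is None:
        host = env["SLURM_JOB_NODELIST"]
        if "[" in host or "," in host:
            raise ValueError(
                "Compressed or multi-node node list; pass coordinator_host explicitly"
            )
    return {
        "coordinator_address": f"{host}:{port}",
        "num_processes": int(env["SLURM_NTASKS"]),
        "process_id": int(env["SLURM_PROCID"]),
    }


def initialize_from_slurm(**overrides):
    """Call jax.distributed.initialize with the values derived above."""
    import jax  # imported lazily so slurm_init_kwargs works without JAX

    jax.distributed.initialize(**slurm_init_kwargs(**overrides))
```

Every process must compute the same coordinator address, which is why the host is taken from a variable that is identical across tasks rather than from the local hostname.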

Am I doing something wrong?

What jax/jaxlib version are you using?

jax==0.3.18, jaxlib==0.3.15+cuda11.cudnn82

Which accelerator(s) are you using?

GPUs

Additional system info

No response

NVIDIA GPU info

No response

srossi93 added the bug label Sep 27, 2022
@hawkinsp
Collaborator

The current jaxlib release (0.3.15) is missing support for using subsets of the CUDA devices, which the Slurm support needs. Either we need to make a new jaxlib release (which we are already working on), or you can build jaxlib from source for the moment. We're working on it!

@srossi93
Author

Thanks for the info. Is there a timeline for the next release of jaxlib?

@mjsML
Collaborator

mjsML commented Sep 27, 2022

Should be fixed IIUC.

@hawkinsp
Collaborator

This should be fixed with jax and jaxlib 0.3.20, which we just released. Please try it out!
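
A job script could guard against running on an older jaxlib with a check along these lines. This is a sketch only: `has_subset_device_fix` and `MIN_FIXED_VERSION` are hypothetical names, and the 0.3.20 threshold is taken from the comment above rather than from a changelog, so whether any release between 0.3.15 and 0.3.20 already contains the fix is not established here.

```python
import re

# Threshold taken from the maintainer comment in this thread (an assumption,
# not a documented compatibility guarantee).
MIN_FIXED_VERSION = (0, 3, 20)


def version_tuple(version):
    """Parse '0.3.15+cuda11.cudnn82' into (0, 3, 15), ignoring any suffix."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def has_subset_device_fix(jaxlib_version):
    """Return True if the given jaxlib version is at least MIN_FIXED_VERSION."""
    return version_tuple(jaxlib_version) >= MIN_FIXED_VERSION
```

With jaxlib installed, the installed version string should be available as `jaxlib.version.__version__`, so the check would read `has_subset_device_fix(jaxlib.version.__version__)`.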

@srossi93
Author

Thanks for the update. I was compiling the new version from source, but I guess I’ll try the precompiled release tomorrow instead.
