The issue occurs when trying to initialize a multi-process, multi-GPU environment with Slurm (though I think the problem is external to that).
Take the following simple script:
```python
import jax
import logging

logging.getLogger().setLevel(logging.DEBUG)

jax.distributed.initialize()

if jax.process_index() == 0:
    print(jax.devices())
    print(jax.device_count())        # total number of accelerator devices in the cluster
    print(jax.local_device_count())  # number of accelerator devices attached to this host
```
executed with `srun --gres=gpu:2 --ntasks=2 --nodes=1 python main.py`, it returns:
```
INFO:absl:JAX distributed initialized with visible devices: 0
INFO:absl:JAX distributed initialized with visible devices: 1
INFO:absl:Starting JAX distributed service on ainode17:4192
INFO:absl:Connecting to JAX distributed service on ainode17:4192
INFO:absl:Connecting to JAX distributed service on ainode17:4192
DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
DEBUG:absl:Initializing backend 'cuda'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
DEBUG:absl:Initializing backend 'cuda'
2022-09-27 19:23:48.425044: E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:57] INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:345) local_device->device_ordinal() == local_topology.devices_size()
*** Begin stack trace ***
        PyCFunction_Call
        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyObject_FastCallDict
        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        PyObject_Call
        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        PyEval_EvalCode
        PyRun_SimpleFileExFlags
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
INFO:absl:Unable to initialize backend 'cuda': INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:345) local_device->device_ordinal() == local_topology.devices_size()
DEBUG:absl:Initializing backend 'rocm'
INFO:absl:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter Host
DEBUG:absl:Initializing backend 'tpu'
INFO:absl:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
```
Recently (in 0.3.18) the cluster interface (Slurm and TPU pods) was updated, but the problem doesn't seem to be caused by that change: manually setting `coordinator_address`, `num_processes`, and `process_id` in `jax.distributed.initialize(...)` has the same effect.
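For reference, here is a sketch of how the manual arguments can be derived from the standard Slurm environment variables. This assumes a single-node allocation (so the node reported by `SLURMD_NODENAME` can act as the coordinator on every rank); the helper function name and the port are arbitrary choices, not part of JAX or Slurm.

```python
import os


def slurm_initialize_args(port=4192):
    """Derive jax.distributed.initialize(...) arguments from Slurm env vars.

    Assumes a single-node allocation, so the local node name can serve as
    the coordinator address for all ranks. The port is arbitrary.
    """
    coordinator = f"{os.environ['SLURMD_NODENAME']}:{port}"
    return dict(
        coordinator_address=coordinator,
        num_processes=int(os.environ["SLURM_NTASKS"]),   # total ranks in the job step
        process_id=int(os.environ["SLURM_PROCID"]),      # this rank's index
    )


# Each rank would then call:
#   jax.distributed.initialize(**slurm_initialize_args())
```

With either the automatic or the manual path, the same `RET_CHECK` failure above is produced.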
Am I doing something wrong?
What jax/jaxlib version are you using?
jax==0.3.18, jaxlib==0.3.15+cuda11.cudnn82
Which accelerator(s) are you using?
GPUs
Additional system info
No response
NVIDIA GPU info
No response
The current jaxlib release (0.3.15) is missing support for using subsets of the CUDA devices, which the Slurm support needs. Either wait for the new jaxlib release (which we are already working on) or build jaxlib from source for the moment.