Skip to content

Commit

Permalink
Move test for backend initialization into jax.distributed.initialize(…
Browse files Browse the repository at this point in the history
…) wrapper.

This allows us to skip the check for tests.

PiperOrigin-RevId: 580168674
  • Loading branch information
hawkinsp authored and jax authors committed Nov 7, 2023
1 parent c5d6df4 commit b85ea68
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ def initialize(self,
process_id: Optional[int] = None,
local_device_ids: Optional[Union[int, Sequence[int]]] = None,
initialization_timeout: int = 300):
if xla_bridge.backends_are_initialized():
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):
Expand Down Expand Up @@ -181,6 +178,9 @@ def initialize(coordinator_address: Optional[str] = None,
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1) # doctest: +SKIP
"""
if xla_bridge.backends_are_initialized():
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout)
atexit.register(shutdown)
Expand Down

0 comments on commit b85ea68

Please sign in to comment.