Skip to content

Commit

Permalink
Unbreak jax.distributed initialization.
Browse files Browse the repository at this point in the history
A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, we need to only initialize the service from the first process.

Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global.

[XLA:Python] Add gil release guards around distributed system init/shutdown. This allows testing using multiple threads.

PiperOrigin-RevId: 453480351
  • Loading branch information
hawkinsp authored and jax authors committed Jun 7, 2022
1 parent b6f5dff commit 3e699dd
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 58 deletions.
135 changes: 80 additions & 55 deletions jax/_src/distributed.py
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import os
import functools

from typing import Optional
from typing import Any, Optional

from absl import logging
from jax._src import cloud_tpu_init
Expand All @@ -24,8 +25,71 @@
from jax._src.lib import xla_client
from jax._src.lib import xla_extension

jax_service = None
distributed_client = None
class State:
service: Optional[Any] = None
client: Optional[Any] = None

def initialize(self,
coordinator_address: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None):
coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS',
None) or coordinator_address

if cloud_tpu_init.running_in_cloud_tpu_vm:
worker_endpoints = cloud_tpu_init.get_metadata(
'worker-network-endpoints').split(',')
if coordinator_address is None:
coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
if num_processes is None:
num_processes = xla_bridge.process_count()
if process_id is None:
process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))

if num_processes != len(worker_endpoints):
raise RuntimeError('Number of workers does not equal the number of '
'processes. Auto detecting process_id is not possible.'
'Please pass process_id manually.')

if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')
if num_processes is None:
raise ValueError('Number of processes must be defined.')
if process_id is None:
raise ValueError('The process id of the current process must be defined.')

if process_id == 0:
if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
logging.info('Starting JAX distributed service on %s', coordinator_address)
if xla_client._version >= 72:
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes, config.jax_coordination_service)
else:
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes)

if self.client is not None:
raise RuntimeError('distributed.initialize should only be called once.')

if xla_client._version >= 72:
self.client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id, config.jax_coordination_service)
else:
self.client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id)
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
self.client.connect()

def shutdown(self):
if self.client:
self.client.shutdown()
self.client = None
if self.service:
self.service.shutdown()
self.service = None

global_state = State()


def initialize(coordinator_address: Optional[str] = None,
Expand Down Expand Up @@ -71,71 +135,32 @@ def initialize(coordinator_address: Optional[str] = None,
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP
"""

coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS',
None) or coordinator_address

if cloud_tpu_init.running_in_cloud_tpu_vm:
worker_endpoints = cloud_tpu_init.get_metadata(
'worker-network-endpoints').split(',')
if coordinator_address is None:
coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
if num_processes is None:
num_processes = xla_bridge.process_count()
if process_id is None:
process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))

if num_processes != len(worker_endpoints):
raise RuntimeError('Number of workers does not equal the number of '
'processes. Auto detecting process_id is not possible.'
'Please pass process_id manually.')

if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')
if num_processes is None:
raise ValueError('Number of processes must be defined.')
if process_id is None:
raise ValueError('The process id of the current process must be defined.')

if process_id == 0:
global jax_service
if jax_service is not None:
raise RuntimeError('distributed.initialize should only be called once.')

global distributed_client
if distributed_client is not None:
raise RuntimeError('distributed.initialize should only be called once.')

logging.info('Starting JAX distributed service on %s', coordinator_address)
if xla_client._version >= 72:
jax_service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes, config.jax_coordination_service)
distributed_client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id, config.jax_coordination_service)
else:
jax_service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes)
distributed_client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id)
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
distributed_client.connect()

global_state.initialize(coordinator_address, num_processes, process_id)
atexit.register(shutdown)
if xla_client._version >= 65:
factory = functools.partial(
xla_client.make_gpu_client,
distributed_client,
global_state.client,
process_id,
platform_name='cuda')
xla_bridge.register_backend_factory('cuda', factory, priority=300)
factory = functools.partial(
xla_client.make_gpu_client,
distributed_client,
global_state.client,
process_id,
platform_name='rocm')
xla_bridge.register_backend_factory('rocm', factory, priority=300)
else:
factory = functools.partial(
xla_client.make_gpu_client,
distributed_client,
global_state.client,
process_id)
xla_bridge.register_backend_factory('gpu', factory, priority=300)



def shutdown():
"""Shuts down the distributed system.
Does nothing if the distributed system is not running."""
global_state.shutdown()
2 changes: 1 addition & 1 deletion jax/distributed.py
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.distributed import initialize
from jax._src.distributed import (initialize, shutdown)
4 changes: 2 additions & 2 deletions jax/experimental/gda_serialization/serialization.py
Expand Up @@ -224,11 +224,11 @@ def __init__(self, timeout_secs=300):
self._thread = None
self._exception = None

if distributed.distributed_client is None:
if distributed.global_state.client is None:
raise ValueError('Please initialize the distributed system via '
'`jax.distributed.initialize()` at the start of your '
'program.')
self._client = distributed.distributed_client
self._client = distributed.global_state.client
self._final_ckpt_dir = None

def __del__(self):
Expand Down
69 changes: 69 additions & 0 deletions tests/distributed_test.py
@@ -0,0 +1,69 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax.config import config
import jax._src.distributed as distributed
import jax._src.lib
from jax._src import test_util as jtu

try:
import portpicker
except ImportError:
portpicker = None

config.parse_flags_with_absl()


@unittest.skipIf(jax._src.lib.xla_extension_version < 73,
"Test requires jaxlib 0.3.12 or newer.")
@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):

def testInitializeAndShutdown(self):
# Tests the public APIs. Since they use global state, we cannot use
# concurrency to simulate multiple tasks.
port = portpicker.pick_unused_port()
jax.distributed.initialize(coordinator_address=f"localhost:{port}",
num_processes=1,
process_id=0)
jax.distributed.shutdown()


@parameterized.parameters([1, 2, 4])
def testConcurrentInitializeAndShutdown(self, n):
port = portpicker.pick_unused_port()
def task(i):
# We can't call the public APIs directly because they use global state.
state = distributed.State()
state.initialize(coordinator_address=f"localhost:{port}",
num_processes=n,
process_id=i)
state.shutdown()

threads = [threading.Thread(target=task, args=(i,)) for i in range(n)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 3e699dd

Please sign in to comment.