Skip to content

Commit

Permalink
Add jax.distributed.initialize for multi-host GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangqiaorjc committed Oct 26, 2021
1 parent 821fcaa commit 0be30fb
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...main).

* New features:
* (Experimental) `jax.distributed.initialize` exposes the multi-host GPU backend.
* Breaking changes
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
Expand Down
1 change: 1 addition & 0 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
# jax and rely on the names imported above.
from . import abstract_arrays as abstract_arrays
from . import api_util as api_util
from . import distributed as distributed
from . import dtypes as dtypes
from . import errors as errors
from . import image as image
Expand Down
59 changes: 59 additions & 0 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from absl import logging
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension

# Handle to the distributed runtime service. Only set in the process that
# called ``initialize`` with ``process_id == 0``; stays None everywhere else.
_service = None


def initialize(coordinator_address: str, num_processes: int,
               process_id: int) -> None:
  """Initialize distributed system for topology discovery.

  Currently, calling ``initialize`` sets up the multi-host GPU backend, and
  is not required for CPU or TPU backends.

  The process with ``process_id == 0`` starts the distributed runtime service
  at ``coordinator_address``. Every process (including process 0) then
  connects a client to that address and registers a ``'gpu'`` backend factory
  built from that client.

  Args:
    coordinator_address: Address of the coordinator, in the form used by the
      example below (``'10.0.0.1:1234'``).
    num_processes: Number of processes.
    process_id: Id of the current process.

  Raises:
    RuntimeError: If called with ``process_id == 0`` after the service has
      already been started in this process.

  Example:

  Suppose there are two GPU hosts, and host 0 is the designated coordinator
  with address '10.0.0.1:1234', to initialize the GPU cluster, run the
  following commands before anything else.

  On host 0

  >>> jax.distributed.initialize('10.0.0.1:1234', 2, 0)  # doctest: +SKIP

  On host 1

  >>> jax.distributed.initialize('10.0.0.1:1234', 2, 1)  # doctest: +SKIP
  """
  global _service
  if process_id == 0:
    # Raise explicitly instead of asserting: asserts are removed under
    # ``python -O``, which would let a second call start a duplicate service.
    if _service is not None:
      raise RuntimeError('initialize should be called once only')
    logging.info('Starting JAX distributed service on %s', coordinator_address)
    _service = xla_extension.get_distributed_runtime_service(
        coordinator_address, num_processes)

  client = xla_extension.get_distributed_runtime_client(
      coordinator_address, process_id)
  logging.info('Connecting to JAX distributed service on %s', coordinator_address)
  client.connect()

  # Bind the connected client and process id now; the remaining arguments, if
  # any, are supplied by whoever invokes the registered factory.
  factory = functools.partial(xla_client.make_gpu_client, client, process_id)
  xla_bridge.register_backend_factory('gpu', factory, priority=300)
12 changes: 7 additions & 5 deletions jax/_src/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,15 @@ def _log_warning():
# example, there could be multiple backends that provide the same kind of
# device.
_backend_factories = {}  # backend name -> (factory, priority)
# NOTE(review): presumably the backend chosen when none is requested; its use
# is not visible in this hunk — confirm.
_default_backend = None
# Backends keyed by name. register_backend_factory refuses to register a name
# that is already present here.
_backends : Dict[str, Any] = {}
# NOTE(review): presumably an error message per backend name — confirm.
_backends_errors : Dict[str, str] = {}
# Held by register_backend_factory while it inspects `_backends`.
_backend_lock = threading.Lock()

def register_backend_factory(name, factory, *, priority=0):
  """Record ``factory`` as the factory for the backend called ``name``.

  Args:
    name: Backend name used as the registry key.
    factory: Object stored for later use as the backend's factory.
    priority: Value stored alongside the factory (keyword-only, default 0).

  Raises:
    RuntimeError: If ``name`` is already present in ``_backends``.
  """
  # Only the membership check needs the lock; it is released before raising.
  with _backend_lock:
    already_initialized = name in _backends
  if already_initialized:
    raise RuntimeError(f"Backend {name} already initialized")
  registration = (factory, priority)
  _backend_factories[name] = registration


Expand All @@ -187,11 +194,6 @@ def register_backend_factory(name, factory, *, priority=0):
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)

_default_backend = None
_backends : Dict[str, Any] = {}
_backends_errors : Dict[str, str] = {}
_backend_lock = threading.Lock()


def backends():
global _backends
Expand Down
16 changes: 16 additions & 0 deletions jax/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.distributed import initialize

0 comments on commit 0be30fb

Please sign in to comment.