-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
distributed.py
164 lines (130 loc) · 6.34 KB
/
distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import os
import functools
from typing import Any, Optional
from absl import logging
from jax._src import cloud_tpu_init
from jax._src.config import config
from jax._src.lib import xla_bridge
from jax._src.lib import xla_extension
class State:
  """Process-wide bookkeeping for the JAX distributed runtime.

  Holds the handles created by :meth:`initialize` so that :meth:`shutdown`
  can release them later.
  """
  # Index of this process within the cluster; overwritten by `initialize`.
  process_id: int = 0
  # Distributed runtime service handle; only created when process_id == 0.
  service: Optional[Any] = None
  # Distributed runtime client handle connected to the coordinator.
  client: Optional[Any] = None
  # Only created when `config.jax_coordination_service` is enabled.
  preemption_sync_manager: Optional[Any] = None

  def initialize(self,
                 coordinator_address: Optional[str] = None,
                 num_processes: Optional[int] = None,
                 process_id: Optional[int] = None):
    """Starts (on process 0) and connects to the distributed runtime.

    Args:
      coordinator_address: ``host:port`` of the coordinator. Falls back to the
        ``JAX_COORDINATOR_ADDRESS`` environment variable and, on a Cloud TPU
        VM, to the first worker endpoint with port 8476.
      num_processes: Total number of processes; auto-detected on a Cloud TPU
        VM.
      process_id: Index of this process; auto-detected on a Cloud TPU VM.

    Raises:
      RuntimeError: If this state already holds a service or client, or if on
        a Cloud TPU VM the worker count differs from ``num_processes``.
      ValueError: If any of the three settings is still undetermined.
    """
    # Reject a repeated call before touching any state, so that a rejected
    # call cannot overwrite `self.process_id` set by the successful one.
    if self.service is not None or self.client is not None:
      raise RuntimeError('distributed.initialize should only be called once.')

    coordinator_address = (coordinator_address or
                           os.environ.get('JAX_COORDINATOR_ADDRESS'))

    if cloud_tpu_init.running_in_cloud_tpu_vm:
      worker_endpoints = cloud_tpu_init.get_metadata(
          'worker-network-endpoints').split(',')
      if coordinator_address is None:
        # NOTE(review): assumes every endpoint has >= 3 ':'-separated fields
        # with the host address in the third one -- confirm metadata format.
        coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
      if num_processes is None:
        num_processes = xla_bridge.process_count()
      if process_id is None:
        process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))

      if num_processes != len(worker_endpoints):
        raise RuntimeError('Number of workers does not equal the number of '
                           'processes. Auto detecting process_id is not '
                           'possible. Please pass process_id manually.')

    if coordinator_address is None:
      raise ValueError('coordinator_address should be defined.')
    if num_processes is None:
      raise ValueError('Number of processes must be defined.')
    if process_id is None:
      raise ValueError('The process id of the current process must be defined.')

    self.process_id = process_id

    # Process 0 hosts the service that every client connects to.
    if process_id == 0:
      logging.info('Starting JAX distributed service on %s', coordinator_address)
      self.service = xla_extension.get_distributed_runtime_service(
          coordinator_address, num_processes, config.jax_coordination_service)

    # Set init_timeout to 5 min to leave time for all the processes to connect
    self.client = xla_extension.get_distributed_runtime_client(
        coordinator_address, process_id, config.jax_coordination_service,
        init_timeout=300)
    logging.info('Connecting to JAX distributed service on %s',
                 coordinator_address)
    self.client.connect()

    if config.jax_coordination_service:
      self.initialize_preemption_sync_manager()

  def shutdown(self):
    """Shuts down the client, then the service, and drops all handles.

    Handles that were never created are skipped, so calling this on an
    uninitialized state does nothing.
    """
    if self.client is not None:
      self.client.shutdown()
      self.client = None
    if self.service is not None:
      self.service.shutdown()
      self.service = None
    # Only the reference is dropped; no explicit shutdown call is made on it.
    if self.preemption_sync_manager is not None:
      self.preemption_sync_manager = None

  def initialize_preemption_sync_manager(self):
    """Creates the preemption sync manager and binds it to ``self.client``.

    Raises:
      RuntimeError: If a preemption sync manager already exists.
    """
    if self.preemption_sync_manager is not None:
      raise RuntimeError(
          'Preemption sync manager should only be initialized once.')
    self.preemption_sync_manager = (
        xla_extension.create_preemption_sync_manager())
    # Expects `initialize` to have already set and connected `self.client`.
    self.preemption_sync_manager.initialize(self.client)
# Module-level State instance used by the `initialize` and `shutdown`
# functions below.
global_state: State = State()
def initialize(coordinator_address: Optional[str] = None,
               num_processes: Optional[int] = None,
               process_id: Optional[int] = None):
  """Initializes the JAX distributed system.

  Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
  multi-host GPU and Cloud TPU. :func:`~jax.distributed.initialize` must be
  called before performing any JAX computations.

  The JAX distributed system serves a number of roles:

    * it allows JAX processes to discover each other and share topology information,
    * it performs health checking, ensuring that all processes shut down if any process dies, and
    * it is used for distributed checkpointing.

  If you are using GPU, you must provide the ``coordinator_address``,
  ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
  The ``coordinator_address`` may instead be supplied through the
  ``JAX_COORDINATOR_ADDRESS`` environment variable.

  If you are using TPU, all arguments are optional: if omitted, they
  will be chosen automatically from the Cloud TPU metadata.

  Args:
    coordinator_address: the IP address of process ``0`` and a port on which that
      process should launch a coordinator service. The choice of
      port does not matter, so long as the port is available on the coordinator
      and all processes agree on the port.
      If ``None``, the ``JAX_COORDINATOR_ADDRESS`` environment variable is
      used; if that is unset too, it is chosen automatically on TPU.
    num_processes: Number of processes. May be ``None`` only on TPU, in
      which case it will be chosen automatically based on the TPU slice.
    process_id: The ID number of the current process. The ``process_id`` values across
      the cluster must be a dense range ``0``, ``1``, ..., ``num_processes - 1``.
      May be ``None`` only on TPU; if ``None`` it will be chosen from the TPU slice
      metadata.

  Raises:
    RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
    ValueError: If ``coordinator_address``, ``num_processes`` or ``process_id``
      is not given and cannot be determined automatically.

  Example:

  Suppose there are two GPU processes, and process 0 is the designated coordinator
  with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the
  following commands before anything else.

  On process 0:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  # doctest: +SKIP

  On process 1:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)  # doctest: +SKIP
  """
  global_state.initialize(coordinator_address, num_processes, process_id)
  # Only reached when initialization did not raise; ensures the runtime is
  # torn down at interpreter exit even if the user never calls `shutdown`.
  atexit.register(shutdown)
def shutdown():
  """Tears down the JAX distributed runtime held by this module.

  Safe to call when the distributed system was never started (or was already
  shut down): in that case nothing happens.
  """
  state = global_state
  state.shutdown()