Skip to content

Commit

Permalink
Automate arguments for jax.distributed.initialize for cloud TPU envir…
Browse files Browse the repository at this point in the history
…onments.

PiperOrigin-RevId: 586892544
  • Loading branch information
jax authors committed Dec 1, 2023
1 parent a07ed22 commit 8ad774f
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list
* Changes
* The minimum jaxlib version is now 0.4.19.
* Released wheels are built now with clang instead of gcc.
* Enforce that the device backend has not been initialized prior to calling `jax.distributed.initialize()`.
* Automate arguments to `jax.distributed.initialize()` in cloud TPU environments.

* Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/clusters/__init__.py
Expand Up @@ -22,4 +22,6 @@
# available one from the list will be picked.
from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster
from .cloud_tpu_cluster import TpuCluster
from .cloud_tpu_cluster import GkeTpuCluster
from .cloud_tpu_cluster import MultisliceGceTpuCluster
from .cloud_tpu_cluster import SingleSliceGceTpuCluster
121 changes: 110 additions & 11 deletions jax/_src/clusters/cloud_tpu_cluster.py
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import os
import re
import socket
import time
from typing import Optional
from jax._src import xla_bridge
from jax._src import clusters
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

Expand Down Expand Up @@ -43,32 +45,129 @@ def get_metadata(key):
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
return api_resp.text

def get_tpu_env_value(key):
def get_tpu_env_value_from_metadata(key):
tpu_env_data = get_metadata('tpu-env')
key_value_pairs = tpu_env_data.split('\n')
for key_value_pair in key_value_pairs:
# Typical line is MEGASCALE_NUM_SLICES: '2'
if ':' in key_value_pair:
row_key, value = re.split(':', key_value_pair, 1)
row_key = row_key.strip()
if row_key == key:
return value.strip().strip("'")
return None

value = os.environ.get(key, None)
return value if value is not None else get_tpu_env_value_from_metadata(key)

def is_gce_env():
worker_number_string = get_metadata('agent-worker-number')
try:
worker_number = int(worker_number_string)
return True
except:
return False

def is_multislice_gce_env():
return is_gce_env() and get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None

def is_gke_env():
return os.environ.get("TPU_WORKER_HOSTNAMES", None) is not None

class TpuCluster(clusters.ClusterEnv):
def get_gce_worker_endpoints() -> str:
return get_metadata('worker-network-endpoints').split(',')

class SingleSliceGceTpuCluster(clusters.ClusterEnv):
@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm
return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env()

@classmethod
def get_coordinator_address(cls) -> str:
return cls._get_worker_endpoints()[0].split(':')[2] + ':8476'
return get_gce_worker_endpoints()[0].split(':')[2] + ':8476'

@classmethod
def get_process_count(cls) -> int:
return xla_bridge.process_count()
return len(get_gce_worker_endpoints())

@classmethod
def get_process_id(cls) -> int:
if cls.get_process_count() != len(cls._get_worker_endpoints()):
raise RuntimeError('Number of workers does not equal the number of '
'processes. Auto detecting process_id is not possible.'
'Please pass process_id to jax.distributed.initialize() manually.')
return int(get_metadata('agent-worker-number'))

@classmethod
def get_local_process_id(cls) -> Optional[int]:
return None

class MultisliceGceTpuCluster(clusters.ClusterEnv):
@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm and is_multislice_gce_env()

@classmethod
def get_coordinator_address(cls) -> str:
coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
coordinator_address = coordinator_address.split(':')[0]

# The coordinator may not be up before the other hosts try to
# communicate with it. We check for its existence with retries.
coordinator_found = False
lookup_attempt = 1
max_coordinator_lookups = 50
while not coordinator_found and lookup_attempt <= max_coordinator_lookups:
try:
ip_address = socket.gethostbyname(coordinator_address)
coordinator_found = True
except socket.gaierror:
print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...")
lookup_attempt += 1
time.sleep(5)

if not coordinator_found:
raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")

# Use a different port for the jax coordinator than the MXLA coordinator,
# which is set to 8080 in multislice GCE.
coordinator_address = coordinator_address + ':8476'
return coordinator_address

@classmethod
def get_process_count(cls) -> int:
processes_per_slice = cls._get_process_count_per_slice()
num_slices = int(get_tpu_env_value('MEGASCALE_NUM_SLICES'))
return processes_per_slice * num_slices

@classmethod
def get_process_id(cls) -> int:
process_id_in_slice = cls._get_process_id_in_slice()
slice_id = int(get_tpu_env_value('MEGASCALE_SLICE_ID'))
processes_per_slice = cls._get_process_count_per_slice()
return process_id_in_slice + slice_id * processes_per_slice

@classmethod
def get_local_process_id(cls) -> Optional[int]:
return None

@staticmethod
def _get_process_count_per_slice() -> int:
return len(get_gce_worker_endpoints())

@staticmethod
def _get_process_id_in_slice() -> int:
return int(get_metadata('agent-worker-number'))

class GkeTpuCluster(MultisliceGceTpuCluster):
# This class handles both single and multislice GKE as the environment
# variables are set the same in both cases.
@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm and is_gke_env()

@staticmethod
def _get_process_count_per_slice() -> int:
tpu_worker_hostnames = str(os.environ.get('TPU_WORKER_HOSTNAMES', None))
return len(tpu_worker_hostnames.split(','))

@staticmethod
def _get_worker_endpoints() -> str:
return get_metadata('worker-network-endpoints').split(',')
def _get_process_id_in_slice() -> int:
return int(str(os.environ.get('TPU_WORKER_ID')))

0 comments on commit 8ad774f

Please sign in to comment.