diff --git a/README.md b/README.md index da3572c3ec91..4e0ca7f6efac 100644 --- a/README.md +++ b/README.md @@ -432,9 +432,6 @@ operating system, CUDA, and CuDNN are possible, but require [building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). * CUDA 11.1 or newer is *required*. - * You may be able to use older CUDA versions if you [build from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source), - but there are known bugs in CUDA in all CUDA versions older than 11.1, so we - do not ship prebuilt binaries for older CUDA versions. * The supported cuDNN versions for the prebuilt wheels are: * cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality. diff --git a/docs/jax.distributed.rst b/docs/jax.distributed.rst index 92718cee0df7..b3024e4cca86 100644 --- a/docs/jax.distributed.rst +++ b/docs/jax.distributed.rst @@ -8,4 +8,5 @@ jax.distributed module .. autosummary:: :toctree: _autosummary - initialize \ No newline at end of file + initialize + shutdown \ No newline at end of file diff --git a/docs/multi_process.md b/docs/multi_process.md index 692f8bbe6c96..b838cc0c4764 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -2,72 +2,114 @@ ## Introduction -This guide explains how to use JAX in environments such as [Cloud -TPU](https://cloud.google.com/tpu) pods where accelerators are spread across -multiple CPU hosts or JAX processes. We’ll refer to these as “multi-process” -environments. +This guide explains how to use JAX in environments such as +GPU clusters and [Cloud TPU](https://cloud.google.com/tpu) pods where +accelerators are spread across multiple CPU hosts or JAX processes. We’ll refer +to these as “multi-process” environments. This guide specifically focuses on how to use collective communication -operations (e.g. {func}`jax.lax.psum`) in multi-process settings, although other -communication methods may be useful too depending on your use case (e.g. RPC, -[mpi4jax](https://github.com/mpi4jax/mpi4jax)). If you’re not already familiar -with JAX’s collective operations, we recommend starting with the -{doc}`/jax-101/06-parallelism` notebook. An important requirement of multi-process -environments in JAX is direct communication links between accelerators, e.g. the -high-speed interconnects for Cloud TPUs or -[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links are what allow -collective operations to run across multiple processes’ worth of accelerators. - +operations (e.g. {func}`jax.lax.psum` ) in multi-process settings, although +other communication methods may be useful too depending on your use case (e.g. +RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If you’re not already +familiar with JAX’s collective operations, we recommend starting with the +{doc}`/jax-101/06-parallelism` notebook. An important requirement of +multi-process environments in JAX is direct communication links between +accelerators, e.g. the high-speed interconnects for Cloud TPUs or +[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow +collective operations to run across multiple processes’ worth of accelerators +with high performance. ## Multi-process programming model Key concepts: -* You must run at least one JAX process per host. -* Each process has a distinct set of _local_ devices it can address. The - _global_ devices are the set of all devices across all processes. -* Use standard JAX parallelism APIs like {func}`~jax.pmap` and - {func}`~jax.experimental.maps.xmap`. Each process “sees” _local_ input and + + * You must run at least one JAX process per host. + * You should initialize the cluster with {func}`jax.distributed.initialize`. + * Each process has a + distinct set of *local* devices it can address. The *global* devices are the set + of all devices across all processes. + * Use standard JAX parallelism APIs like {func}`~jax.pmap` and + {func}`~jax.experimental.maps.xmap` . Each process “sees” *local* input and output to parallelized functions, but communication inside the computations - is _global_. -* Make sure all processes run the same parallel computations in the same + is *global*. + * Make sure all processes run the same parallel computations in the same order. - ### Launching JAX processes Unlike other distributed systems where a single controller node manages many worker nodes, JAX uses a “multi-controller” programming model where each JAX -Python process runs independently, sometimes referred to as a -{term}`Single Program, Multiple Data (SPMD)` model. Generally, the same -JAX Python program is run in each process, with only slight differences between -each process’s execution (e.g. different processes will load different input -data). Furthermore, **you must manually run your JAX program on each host!** JAX +Python process runs independently, sometimes referred to as a {term}`Single +Program, Multiple Data (SPMD)` model. Generally, the same JAX Python +program is run in each process, with only slight differences between each +process’s execution (e.g. different processes will load different input data). +Furthermore, **you must manually run your JAX program on each host!** JAX doesn’t automatically start multiple processes from a single program invocation. -(This is why this guide isn’t offered as a notebook -- we don’t currently have a -good way to manage multiple Python processes from a single notebook.) +(The requirement for multiple processes is why this guide isn’t offered as a +notebook -- we don’t currently have a good way to manage multiple Python +processes from a single notebook.) + +### Initializing the cluster + +To initialize the cluster, you should call {func}`jax.distributed.initialize` at +the start of each process. {func}`jax.distributed.initialize` must be called +early in the program, before any JAX computations are executed. + +The API {func}`jax.distributed.initialize` takes several arguments, namely: + + * `coordinator_address`: the IP address of process 0 in your cluster, together + with a port available on that process. Process 0 will start a JAX service + exposed via that IP address and port, to which the other processes in the + cluster will connect. + * `num_processes`: the number of processes in the cluster + * `process_id`: the ID number of this process, in the range `[0 .. + num_processes)`. + +For example on GPU, a typical usage is: + +```python +import jax + +jax.distributed.initialize(coordinator_address="192.168.0.1:1234", + num_processes=2, + process_id=0) +``` + +On Cloud TPU, you can simply call {func}`jax.distributed.initialize()` with no +arguments. Default values for the arguments will be chosen automatically using +the TPU pod metadata: + +```python +import jax + +jax.distributed.initialize() +``` +On TPU at present calling {func}`jax.distributed.initialize` is optional, but +recommanded since it enables additional checkpointing and health checking features. ### Local vs. global devices Before we get to running multi-process computations from your program, it’s -important to understand the distinction between _local_ and _global_ devices. - -**A process’s _local_ devices are those that it can directly address and launch -computations on.** For example, in a Cloud TPU pod, each host can only launch -computations on the 8 TPU cores attached directly to that host (see the [Cloud -TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture) +important to understand the distinction between *local* and *global* devices. + +**A process’s *local* devices are those that it can directly address and launch +computations on.** For example, on a GPU cluster, each host can only launch +computations on the directly attached GPUs. On a Cloud TPU pod, each host can +only launch computations on the 8 TPU cores attached directly to that host (see +the +[Cloud TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture) documentation for more details). You can see a process’s local devices via -{func}`jax.local_devices()`. +{func}`jax.local_devices()` . -**The _global_ devices are the devices across all processes.** A computation can +**The *global* devices are the devices across all processes.** A computation can span devices across processes and perform collective operations via the direct communication links between devices, as long as each process launches the computation on its local devices. You can see all available global devices via -{func}`jax.devices()`. A process’s local devices are always a subset of the +{func}`jax.devices()` . A process’s local devices are always a subset of the global devices. - ### Running multi-process computations So how do you actually run a computation involving cross-process communication? @@ -77,22 +119,23 @@ For example, {func}`~jax.pmap` can be used to run a parallel computation across multiple processes. (If you’re not already familiar with how to use {func}`~jax.pmap` to run across multiple devices within a single process, check out the {doc}`/jax-101/06-parallelism` notebook.) Each process should call the -same pmapped function and pass in arguments to be mapped across its _local_ -devices (i.e., the pmapped axis size is equal to the number of local -devices). Similarly, the function will return outputs sharded across _local_ -devices only. Inside the function, however, collective communication operations -are run across all _global_ devices, across all processes. Conceptually, this -can be thought of as running a pmap over a single array sharded across hosts, -where each host “sees” only its local shard of the input and output. +same pmapped function and pass in arguments to be mapped across its *local* +devices (i.e., the pmapped axis size is equal to the number of local devices). +Similarly, the function will return outputs sharded across *local* devices only. +Inside the function, however, collective communication operations are run across +all *global* devices, across all processes. Conceptually, this can be thought of +as running a pmap over a single array sharded across hosts, where each host +“sees” only its local shard of the input and output. Here’s an example of multi-process pmap in action: ```python -# The following is run in parallel on each host in a Cloud TPU v3-32 pod slice +# The following is run in parallel on each host on a GPU cluster or TPU pod slice. >>> import jax ->>> jax.device_count() # total number of TPU cores in pod slice +>>> jax.distributed.initialize() # On GPU, see above for the necessary arguments. +>>> jax.device_count() # total number of accelerator devices in the cluster 32 ->>> jax.local_device_count() # number of TPU cores attached to this host +>>> jax.local_device_count() # number of accelerator devices attached to this host 8 # The psum is performed over all mapped devices across the pod slice >>> xs = jax.numpy.ones(jax.local_device_count()) @@ -102,12 +145,10 @@ ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) {func}`~jax.experimental.maps.xmap` works similarly when using a physical hardware mesh (see the {doc}`xmap tutorial` if you’re -not familiar with the single-process version). Like {func}`~jax.pmap`, the +not familiar with the single-process version). Like {func}`~jax.pmap` , the inputs and outputs are local and any parallel communication inside the xmapped function is global. The mesh is also global. -TODO: xmap example - **It’s very important that all processes run the same cross-process computations in the same order.** Running the same JAX Python program in each process is usually sufficient. Some common pitfalls to look out for that may cause diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index dd264481c4f7..8d6d04216abb 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -102,45 +102,53 @@ def initialize_preemption_sync_manager(self): def initialize(coordinator_address: Optional[str] = None, num_processes: Optional[int] = None, process_id: Optional[int] = None): - """Initialize distributed system for topology discovery. + """Initializes the JAX distributed system. - Currently, calling ``initialize`` sets up the multi-host GPU backend and Cloud - TPU backend. + Calling :func:`~jax.distributed.initialize` prepares JAX for execution on + multi-host GPU and Cloud TPU. :func:`~jax.distributed.initialize` must be + called before performing any JAX computations. - If you are on GPU platform, you will have to provide the coordinator_address - and other args to the `initialize` API. + The JAX distributed system serves a number of roles: - If you are on TPU platform, the coordinator_address and other args will be - auto detected but you have the option to provide it too. + * it allows JAX processes to discover each other and share topology information, + * it performs health checking, ensuring that all processes shut down if any process dies, and + * it is used for distributed checkpointing. + + If you are using GPU, you must provide the ``coordinator_address``, + ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`. + + If you are using TPU, all arguments are optional: if omitted, they + will be chosen automatically from the Cloud TPU metadata. Args: - coordinator_address: IP address and port of the coordinator. The choice of + coordinator_address: the IP address of process `0` and a port on which that + process should launch a coordinator service. The choice of port does not matter, so long as the port is available on the coordinator and all processes agree on the port. - Can be None only for TPU platform. If coordinator_address is None on TPU, - then it will be auto detected. - num_processes: Number of processes. Can be None only for TPU platform and - if None will be determined from the TPU slice metadata. - process_id: Id of the current process. Can be None only for TPU platform and - if None will default to the current TPU worker id determined via the TPU - slice metadata. + May be ``None`` only on TPU, in which case it will be chosen automatically. + num_processes: Number of processes. May be ``None`` only on TPU, in + which case it will be chosen automatically based on the TPU slice. + process_id: The ID number of the current process. The ``process_id`` values across + the cluster must be a dense range ``0``, ``1``, ..., ``num_processes - 1``. + May be ``None`` only on TPU; if ``None`` it will be chosen from the TPU slice + metadata. Raises: - RuntimeError: If `distributed.initialize` is called more than once. + RuntimeError: If :func:`~jax.distributed.initialize` is called more than once. Example: - Suppose there are two GPU hosts, and host 0 is the designated coordinator + Suppose there are two GPU processs, and process 0 is the designated coordinator with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the following commands before anything else. - On host 0: + On process 0: - >>> jax.distributed.initialize('10.0.0.1:1234', 2, 0) # doctest: +SKIP + >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0) # doctest: +SKIP - On host 1: + On process 1: - >>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP + >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1) # doctest: +SKIP """ global_state.initialize(coordinator_address, num_processes, process_id) atexit.register(shutdown)