minor edits to clarify pjit docs
ngam committed May 9, 2022
1 parent 78fafc5 commit 7132fba
Showing 1 changed file with 32 additions and 33 deletions.
65 changes: 32 additions & 33 deletions docs/jax-101/08-pjit.md

# Introduction to pjit

This guide explains how to use `pjit` to compile and automatically partition functions in both single and multi-host environments. `pjit` is `jit` with `in_axis_resources` and `out_axis_resources` to specify how a function should be partitioned. `pjit` can be useful when the jitted version of `fun` would not fit in a single device’s memory, or to speed up `fun` by running each operation in parallel across multiple devices. `pjit` enables users to shard computations without rewriting them, by using the SPMD partitioner. The returned function has semantics equivalent to those of `fun`, but is compiled to an XLA computation that runs across multiple devices (e.g. multiple GPUs or multiple TPU cores).

```python
jax.experimental.pjit.pjit(fun, in_axis_resources, out_axis_resources, static_argnums=(), donate_argnums=())
```

Two examples are shown in this guide to demonstrate how `pjit` works in both single and multi-host environments. The identity function (`lambda x: x`) is chosen to better show how other input parameters are used.

+++ {"id": "hATDpVCRBuXg"}

## Background
The core infrastructure that supports parallel model training is the XLA SPMD partitioner. It takes in an XLA program that represents the complete neural net, as if there is only one giant virtual device. In addition to the program, it also takes in partitioning specifications for both function inputs and outputs. The output of the XLA SPMD partitioner is an identical program for N devices that performs communications between devices through collective operations. The program only compiles once per host. **`pjit` is the API exposed for the XLA SPMD partitioner in JAX.**

+++ {"id": "qF01HxLwC-s0"}

## How it works:
The partitioning over devices happens automatically based on the propagation of the input partitioning specified in `in_axis_resources` and the output partitioning specified in `out_axis_resources`. The resources specified in those two arguments must refer to `mesh axes`, as defined by the `jax.experimental.maps.Mesh()` context manager. Note that the `Mesh` definition at `pjit` application time is ignored, and the returned function will use the `Mesh` definition available at each call site.

Inputs to a `pjit`ted function will be automatically partitioned across devices if they’re not already correctly partitioned based on `in_axis_resources`. In some scenarios, ensuring that the inputs are already correctly pre-partitioned can increase performance. For example, if passing the output of one `pjit`ted function to another `pjit`ted function (or the same `pjit`ted function in a loop), make sure the relevant `out_axis_resources` match the corresponding `in_axis_resources`.
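As a rough sketch of that pattern (not from the original notebook; it assumes 8 available devices and uses a simple element-wise function in place of a real computation):

```python
# Sketch: chaining pjit-compiled calls. Because `out_axis_resources` of `step`
# matches its `in_axis_resources`, the intermediate result is already laid out
# correctly and no re-sharding is needed between the two calls.
import numpy as np
import jax
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit

spec = PartitionSpec('x', 'y')
step = pjit(lambda x: x + 1, in_axis_resources=spec, out_axis_resources=spec)

devices = np.asarray(jax.devices()).reshape(4, 2)  # assumes 8 devices
with maps.Mesh(devices, ('x', 'y')):
    y = step(np.arange(16.0).reshape(8, 2))  # input is partitioned on entry
    z = step(y)                              # y already matches the input spec
```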

+++ {"id": "iZNcsMjoHRtj"}

+++ {"id": "ej2PCAVGHcIU"}

In this example, we have
- `mesh`: a mesh of shape (4, 2) and axes named 'x' and 'y' respectively.
- `input data`: an 8 by 2 NumPy array.
- `in_axis_resources`: `None`. So the (8, 2) input data is replicated across all devices.
- `out_axis_resources`: `PartitionSpec('x', 'y')`. It specifies that the two dimensions of output data are sharded over 'x' and 'y' respectively.
- `function`: `lambda x: x`. It is the identity function.

As a result, the `pjit`ted function applied with the given mesh replicates the input across all devices based on `in_axis_resources`, and then keeps only what each device should have based on `out_axis_resources`. It effectively shards the data on CPU across all accelerators according to the `PartitionSpec`.

Each parameter is explained in detail below:
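The setup cell with the imports is collapsed in this view; the sketches that follow assume imports along these lines (standard `jax.experimental` paths for `pjit`):

```python
# Imports assumed by the sketches below (not the original notebook cell).
import numpy as np
import jax
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit
```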

+++ {"id": "quolviRCIoZG"}

### Mesh
The mesh is defined in [jax/interpreters/pxla](https://github.com/google/jax/blob/main/jax/interpreters/pxla.py#L1389): It is a NumPy array of JAX devices in a multi-dimensional grid, alongside names for the axes of this mesh. It is also called the logical mesh.

+++ {"id": "V3Vnmcb7Jq1L"}

In the example we are working with, the first (vertical) axis is named ‘x’ and has length 4, and the second (horizontal) axis is named ‘y’ and has length 2. If a dimension of data is sharded across an axis, each device has a slice of the size of `data.shape[dim]` divided by `mesh_shape[axis]`. If data is replicated across an axis, devices on that axis should have the same data.
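The cell constructing this mesh is collapsed in this view; a minimal sketch (reusing the imports above and assuming 8 available devices):

```python
# Sketch: arrange 8 devices into a (4, 2) logical mesh with named axes.
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = maps.Mesh(devices, ('x', 'y'))
```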

+++ {"id": "JGNV0XCJKPlN"}

### Input Data
A NumPy array of size (8,2)
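Any (8, 2) array will do; a sketch using `np.arange`:

```python
# Sketch: an (8, 2) input array (the values themselves are arbitrary).
input_data = np.arange(16).reshape(8, 2)
```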

`in_axis_resources` is a pytree with a structure matching that of the arguments to `fun`, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.

The valid resource assignment specifications are:
- `None`, in which case the value will be replicated on all devices
- `PartitionSpec`, a tuple of length at most equal to the rank of the partitioned value. Each element can be `None`, a mesh axis or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec. More details are discussed in the section below (*More information on PartitionSpec*).

The size of every dimension has to be a multiple of the total number of resources assigned to it.

`out_axis_resources` is like `in_axis_resources`, but specifies resource assignment for function outputs.
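Putting the pieces together, a sketch of the `pjit` call for this example (reusing `mesh` and `input_data` from the sketches above):

```python
# Sketch: identity function, replicated input, output sharded over the mesh.
f = pjit(lambda x: x,
         in_axis_resources=None,                      # replicate input on all devices
         out_axis_resources=PartitionSpec('x', 'y'))  # shard output over 'x' and 'y'

with mesh:
    data = f(input_data)

data.device_buffers  # one buffer per device, each a (2, 1) slice
```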


+++ {"id": "LgZbCPrrMg54"}

The result after applying the `pjit`ted function is a `ShardedDeviceArray`, and `device_buffers` shows what data each device has.

+++ {"id": "qSQhEouSMjV2"}


### More information on PartitionSpec:

`PartitionSpec` is a named tuple, whose elements can each be `None`, a mesh axis or a tuple of mesh axes. Each element describes which mesh dimension the input’s dimension is partitioned across. For example, `PartitionSpec('x', 'y')` is a `PartitionSpec` where the first dimension of data is sharded across the 'x' axis of the mesh, and the second dimension is sharded across the 'y' axis of the mesh.
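For instance, the specs used in the examples below are plain tuple-like constructions (a sketch, reusing the `PartitionSpec` import above; no mesh is needed to build them):

```python
# Sketch: each PartitionSpec entry maps one data dimension onto mesh axes.
PartitionSpec('x', 'y')          # dim 0 over 'x', dim 1 over 'y'
PartitionSpec('x', None)         # dim 0 over 'x', dim 1 replicated
PartitionSpec(('x', 'y'), None)  # dim 0 over both axes jointly, dim 1 replicated
PartitionSpec(None, 'y')         # dim 0 replicated, dim 1 over 'y'
```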

+++ {"id": "lciyhDKOyvqr"}

**Examples of other possible PartitionSpecs:**

Reminder: mesh is of shape (4, 2), input data is of shape (8, 2).
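The cells for the individual examples are collapsed in this view; a small hypothetical helper along these lines (not part of the original notebook, reusing `mesh` and `input_data` from the sketches above) can be used to try each `out_axis_resources` and inspect the result:

```python
# Hypothetical helper: apply an identity pjit with a given output spec and
# return the per-device buffers of the resulting ShardedDeviceArray.
def shard_with(spec):
    f = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)
    with mesh:
        out = f(input_data)
    return out.device_buffers

shard_with(PartitionSpec('x', None))  # each device holds a (2, 2) slice
```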

### **- `PartitionSpec('x', None)`**


+++ {"id": "Tr7KUJpIL026"}

The first dimension of the input is sharded across the 'x' axis and the other dimensions are replicated across the other axes. `None` can also be omitted in the `PartitionSpec`. If `out_axis_resources = PartitionSpec('x', None)` in the example above, the result visualization will be the following:

![spmd](../_static/partition_spec_x_none.png)

+++ {"id": "EwN6tPSLy9Tz"}

### **- `PartitionSpec('y', None)`**

The first dimension of the input is sharded across the 'y' axis and the other dimensions are replicated across the other axes.

+++ {"id": "N0WBRBOmZQI2"}

### **- `PartitionSpec(('x', 'y'), None)`**

The first dimension of the input is sharded across both the 'x' and 'y' axes and the other dimensions are replicated across the other axes. We can think of this as stretching the 2D mesh into a 1D mesh and then doing the partition. If `out_axis_resources = PartitionSpec(('x', 'y'), None)` in the example above, the result visualization will be the following:

![spmd](../_static/partition_spec_xy.png)


+++ {"id": "siqNb-an3eUd"}

### **- `PartitionSpec(None, 'y')`**

The second dimension of the input is sharded over the 'y' axis and the first dimension is replicated across the other axes. If `out_axis_resources = PartitionSpec(None, 'y')` in the example above, the result visualization will be the following:

![spmd](../_static/partition_spec_none_y.png)


+++ {"id": "NjPu86IE34Xk"}

`pjit` will complain when `out_axis_resources` is set to `PartitionSpec(None, 'x')`. This is because the second dimension of the input data is of size 2, but the mesh's 'x' dimension has size 4: size 2 cannot be sharded over size 4. It is important to note that the size of the input data has to be divisible by the size of the mesh on the corresponding dimensions.
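The constraint is plain arithmetic; a quick sketch of the check:

```python
# Divisibility check: a data dimension sharded over a mesh axis must be a
# multiple of that axis' size.
data_shape = (8, 2)            # input data
mesh_sizes = {'x': 4, 'y': 2}  # mesh axis sizes

assert data_shape[0] % mesh_sizes['x'] == 0  # 8 over 'x' (4): fine
assert data_shape[1] % mesh_sizes['y'] == 0  # 2 over 'y' (2): fine
assert data_shape[1] % mesh_sizes['x'] != 0  # 2 over 'x' (4): pjit will complain
```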

+++ {"id": "0iQTcTM-aCkK"}

In this example, we have

- `mesh`: a mesh of shape (16, 2) and axes named ‘x’ and ‘y’ respectively.
- `input data`: Each host contains a quarter (8, 2) of the input data of size (32, 2).
- `in_axis_resources`: `PartitionSpec(('x', 'y'),)`. This lets `pjit` know that the (32, 2) input data is already split evenly across hosts (done by user).
- `out_axis_resources`: `PartitionSpec('x', 'y')`. It specifies that the two dimensions of output data are sharded over `x` and `y` respectively.
- `function`: `lambda x: x`. It is the identity function.

The `pjit`ted function applied with a given mesh distributes an even slice to each device. It effectively shards the data on hosts across all accelerators based on the `PartitionSpec`.

+++ {"id": "pinw32PtcD3W"}

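The setup cell is collapsed in this view; it ends by building the (16, 2) mesh. A sketch of such a setup, assuming 4 processes with 8 local devices each (32 devices in total):

```python
# Sketch of the multi-host setup: jax.devices() returns the *global* device
# list, so the mesh below spans all 32 devices across the 4 hosts.
import numpy as np
import jax
from jax.experimental import maps

devices = np.asarray(jax.devices()).reshape(16, 2)
mesh = maps.Mesh(devices, ('x', 'y'))
```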

In a multi-host environment, all the devices connected to one host have to contain a subslice of a single continuous large slice of the data. In JAX SPMD, there are no direct communications between hosts, so hosts only talk to each other via collective communication between devices. As a result, users need to handle distributed data loading on hosts.

In this example, the input array of size (32,2) is manually split into quarters of size (8,2) along the 'x' axis by the user and assigned to each host.

```python
if jax.process_index() == 0:
    input_data = np.arange(0,16).reshape(8,2)
elif jax.process_index() == 1:
    input_data = np.arange(16,32).reshape(8,2)
elif jax.process_index() == 2:
    input_data = np.arange(32,48).reshape(8,2)
else:
    input_data = np.arange(48,64).reshape(8,2)
```

`pjit` always assumes that the input is the local data chunk of a global array. If the local chunk is to be sharded over multiple local devices and is not partitioned as expected, `pjit` will put the right slices on the right **local devices** for you. Once all of the local chunks are on the devices on all the hosts, then XLA will run the computation.

XLA operates on the global data, so if `in_axis_resources` is different from `out_axis_resources`, then XLA will do cross-host data redistribution. Global communication therefore doesn't happen in preparation for the launch of the XLA executable that `pjit` represents, but only in the XLA executable itself.

One way to do data redistribution cross-host is to use an identity `pjit` with `in_axis_resources` different from `out_axis_resources`. XLA will do the global data reordering for you via `pjit`.
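The example in this section is itself such an identity `pjit`; a sketch of the call (reusing `mesh` and the per-process `input_data` defined above):

```python
# Sketch: identity pjit whose in/out specs differ, so XLA redistributes the
# data across hosts inside the compiled executable.
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit

f = pjit(lambda x: x,
         in_axis_resources=PartitionSpec(('x', 'y'),),  # dim 0 over both axes
         out_axis_resources=PartitionSpec('x', 'y'))    # dim 0 over 'x', dim 1 over 'y'

with mesh:
    data = f(input_data)  # each process passes its local (8, 2) chunk
```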

+++ {"id": "gTS_bgtkdch1"}

### in_axis_resources & out_axis_resources

- `in_axis_resources`: `PartitionSpec(('x', 'y'),)`. This partitions the first dimension of the input data over both the 'x' and 'y' axes. Since input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, `pjit` sends the (8, 2) data on each host to its devices based on `in_axis_resources`. Since each host has a logical mesh of size (4, 2) within the entire logical mesh, each device has a (1, 2) slice.
- `out_axis_resources`: `PartitionSpec('x', 'y')`. It specifies that the two dimensions of output data are sharded over the 'x' and 'y' axes respectively, so each device gets a (2, 1) slice.

**Note**: `in_axis_resources` and `out_axis_resources` are different. Here, `in_axis_resources` shards the input data's first dimension over both 'x' and 'y', whereas `out_axis_resources` shards the input data's first dimension only over 'x'.

+++ {"id": "8b4kf6GtgPgD"}
