Skip to content

Commit

Permalink
Make the pjit docs clear about who does local and global communication
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 405421833
  • Loading branch information
yashk2810 authored and jax authors committed Oct 25, 2021
1 parent 0f47712 commit 821fcaa
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions docs/jax-101/08-pjit.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ mesh = maps.Mesh(devices, ('x', 'y'))

### Input data

In a multi-host environment, all the devices connected to one host have to contain a subslice of a single continuous large slice of the data. In JAX SPMD, there are no direct communications between hosts, so hosts only talk to each other via collective communication between devices. As a result, users need to handle distributed data loading on hosts.
In a multi-host environment, all the devices connected to one host have to contain a subslice of a single continuous large slice of the data. In JAX SPMD, there are no direct communications between hosts, so hosts only talk to each other via collective communication between devices. As a result, users need to handle distributed data loading on hosts.

In this example, the input array of size (32,2) is manually split into quarters of size (8,2) along the `x` axis by user and assigned to each host.

Expand All @@ -473,11 +473,18 @@ else:
input_data = np.arange(48,64).reshape(8,2)
```

Pjit always assumes that the input is the local data chunk of a global array. If the local chunk it to be sharded over multiple local devices and is not partitioned as expected, pjit will put the right slices on the right **local devices** for you. Once all of the local chunks are on the devices on all the
hosts, then XLA will run the computation.

XLA operates on the global data so if `in_axis_resources` is different than `out_axis_resources` then XLA will do data redistribution cross-host. So global communication doesn't happen in preparation for the launch of the XLA executable that pjit represents, but only in the XLA executable itself.

One way to do data redistribution cross-host is to use an identity pjit with `in_axis_resources` different from `out_axis_resources`. XLA will do the global data reordering for you via pjit.

+++ {"id": "gTS_bgtkdch1"}

### in_axis_resources & out_axis_resources

- `in_axis_resources`: PartitionSpec(('x', 'y'),). This partitions the first dimension of input data over both `x` and `y` axes. This lets Pjit know that the (32, 2) input data is already split evenly across hosts (done by user). Since input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, pjit sends the (8, 2) on each host to its devices based on `in_axis_resources`. Since each host has a logical mesh of size (4, 2) within the entire logical mesh, each device has a (1, 2) slice.
- `in_axis_resources`: PartitionSpec(('x', 'y'),). This partitions the first dimension of input data over both `x` and `y` axes. Since input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, pjit sends the (8, 2) on each host to its devices based on `in_axis_resources`. Since each host has a logical mesh of size (4, 2) within the entire logical mesh, each device has a (1, 2) slice.
- `out_axis_resources`: PartitionSpec('x', 'y'). It specifies that the two dimensions of output data are sharded over `x` and `y` respectively, so each device gets a (2,1) slice.

**Note**: in_axis_resources and out_axis_resources are different. Here, in_axis_resources shards input data's first dimension over both `x` and `y`, whereas out_axis_resources shards input data's first dimension only over `x`.
Expand Down

0 comments on commit 821fcaa

Please sign in to comment.