6 changes: 3 additions & 3 deletions mlir/docs/Dialects/Shard.md
@@ -27,9 +27,9 @@ the tensor is sharded - not specified manually.

### Device Groups

Each collective operation runs within a group of devices. You define groups
using the `grid` and `grid_axes` attributes, which describe how to slice the
full device grid into smaller groups.
Collective operations run within groups of devices, which are defined
using the `grid` and `grid_axes` attributes. These describe
how the full device grid is sliced into smaller groups.

Devices that have the same coordinates *outside* the listed `grid_axes` belong
to the same group.
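For instance, assuming a hypothetical `2x3` device grid and `grid_axes = [1]`,
the devices split into two groups, one per coordinate along grid axis 0:

```
{ (0, 0), (0, 1), (0, 2) }
{ (1, 0), (1, 1), (1, 2) }
```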
117 changes: 66 additions & 51 deletions mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
]> {
let summary = "All-gather over a device grid.";
let description = [{
Gathers along the `gather_axis` tensor axis.
Concatenates all tensor slices from a device group defined by `grid_axes` along
the tensor dimension `gather_axis` and replicates the result across all devices
in the group.

Example:
```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device grid.";
let description = [{
The accumulation element type is specified by the result type and
it does not need to match the input element type.
The input element is converted to the result element type before
performing the reduction.
Reduces the input tensor across all devices within the groups defined by
`grid_axes`, using the specified reduction method. The operation performs an
element-wise reduction over the tensor slices from all devices in each group.
Each device in a group receives a replicated copy of the reduction result.
The accumulation element type is determined by the result type and does not
need to match the input element type. Before performing the reduction, each
input element is converted to the result element type.

Attributes:
`reduction`: Indicates the reduction method.
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
let summary = "All-slice over a device grid. This is the inverse of all-gather.";
let summary = "All-slice over a device grid.";
let description = [{
Slice along the `slice_axis` tensor axis.
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
The operation does not communicate data between devices.
Within each device group defined by `grid_axes`, slices the input tensor along
the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
the input data is replicated along the `slice_axis`.
Each process simply crops its local data to the slice corresponding to its
in-group device index.
Note: `AllSliceOp` does not involve any communication between devices, and the
input data on the devices within a group is not required to be replicated.

Example:
```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
```
Result:
```
gather tensor
slice tensor
axis 1
------------>
+-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device grid.";
let description = [{
Performs an all-to-all on tensor pieces split along `split_axis`.
The resulting pieces are concatenated along `concat_axis` on ech device.
Each participant logically splits its input along `split_axis`,
then scatters the resulting pieces across the group defined by `grid_axes`.
After receiving the data pieces scattered by the other participants,
it concatenates them along `concat_axis` to produce the final result.

Example:
```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
]> {
let summary = "Broadcast over a device grid.";
let description = [{
Broadcast the tensor on `root` to all devices in each respective group.
The operation broadcasts along grid axes `grid_axes`.
The `root` device specifies the in-group multi-index that is broadcast to
all other devices in the group.
Copies the input tensor on `root` to all devices in each group defined by
`grid_axes`. The `root` device is defined by its in-group multi-index.
The contents of input tensors on non-root devices are ignored.

Example:
```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
device (1, 0) -> | | | <- device (1, 1)
device (1, 0) -> | * * | * * | <- device (1, 1)
+-------+-------+
```

@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
]> {
let summary = "Gather over a device grid.";
let description = [{
Gathers on device `root` along the `gather_axis` tensor axis.
`root` specifies the coordinates of a device along `grid_axes`.
It uniquely identifies the root device for each device group.
The result tensor on non-root devices is undefined.
Using it will result in undefined behavior.
Concatenates all tensor slices from a device group defined by `grid_axes` along
the tensor dimension `gather_axis` and returns the resulting tensor on each
`root` device. The result on all other (non-root) devices is undefined.
The `root` device is defined by its in-group multi-index.

Example:
```mlir
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
]> {
let summary = "Send over a device grid.";
let description = [{
Receive from a device within a device group.
Receives the tensor from device `source`, which is defined by its in-group
multi-index. The groups are defined by `grid_axes`.
The content of the input tensor is ignored.
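
For illustration, a receive might look roughly like the following sketch
(the grid name, shapes, indices and exact assembly syntax are assumptions,
not taken from the source):
```mlir
shard.grid @grid0(shape = 2x2)
...
// Receive the tensor sent by the in-group device [1] along grid axis 1.
%1 = shard.recv %0 on @grid0 grid_axes = [1] source = [1]
  : (tensor<2x2xi8>) -> tensor<2x2xi8>
```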
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
]> {
let summary = "Reduce over a device grid.";
let description = [{
Reduces on device `root` within each device group.
`root` specifies the coordinates of a device along `grid_axes`.
It uniquely identifies the root device within its device group.
The accumulation element type is specified by the result type and
it does not need to match the input element type.
The input element is converted to the result element type before
performing the reduction.
Reduces the input tensor across all devices within the groups defined by
`grid_axes`, using the specified reduction method. The operation performs an
element-wise reduction over the tensor slices from all devices in each group.
The reduction result will be returned on the `root` device of each group.
It is undefined on all other (non-root) devices.
The `root` device is defined by its in-group multi-index.
The accumulation element type is determined by the result type and does not
need to match the input element type. Before performing the reduction, each
input element is converted to the result element type.

Attributes:
`reduction`: Indicates the reduction method.
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device grid.";
let description = [{
After the reduction, the result is scattered within each device group.
The tensor is split along `scatter_axis` and the pieces distributed
across the device group.
Reduces the input tensor across all devices within the groups defined by
`grid_axes` using the specified reduction method. The reduction is performed
element-wise across the tensor pieces from all devices in the group.
After the reduction, the result is scattered (split and distributed) across the
device group along `scatter_axis`.

Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
: tensor<3x4xf32> -> tensor<1x4xf64>
: tensor<2x2xf32> -> tensor<1x2xf64>
```
Result:
```
+-------+
| 6 8 | <- devices (0, 0)
| 5 6 | <- devices (0, 0)
+-------+
| 10 12 | <- devices (0, 1)
| 7 8 | <- devices (0, 1)
+-------+
| 22 24 | <- devices (1, 0)
| 13 14 | <- devices (1, 0)
+-------+
| 26 28 | <- devices (1, 1)
| 15 16 | <- devices (1, 1)
+-------+
```
}];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
]> {
let summary = "Scatter over a device grid.";
let description = [{
For each device group split the input tensor on the `root` device along
axis `scatter_axis` and scatter the parts across the group devices.
For each device group defined by `grid_axes`, the input tensor on the `root`
device is split along axis `scatter_axis` and distributed across the group.
The content of the input on all other (non-root) devices is ignored.
The `root` device is defined by its in-group multi-index.

Example:
```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
(0, 1)
+-------+-------+ | scatter tensor
device (0, 0) -> | | | | axis 0
| | | ↓
device (0, 0) -> | * * | * * | | axis 0
| * * | * * | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
]> {
let summary = "Send over a device grid.";
let description = [{
Send from one device to another within a device group.
Sends the input tensor to device `destination`, which is defined by its in-group
multi-index. The groups are defined by `grid_axes`.
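
A usage sketch might look like the following (the grid name, shapes, indices
and exact assembly syntax are illustrative assumptions):
```mlir
shard.grid @grid0(shape = 2x2)
...
// Send the local tensor to the in-group device [1] along grid axis 1.
%1 = shard.send %0 on @grid0 grid_axes = [1] destination = [1]
  : (tensor<2x2xi8>) -> tensor<2x2xi8>
```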
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
]> {
let summary = "Shift over a device grid.";
let description = [{
Within each device group shift along grid axis `shift_axis` by an offset
`offset`.
The result on devices that do not have a corresponding source is undefined.
`shift_axis` must be one of `grid_axes`.
If the `rotate` attribute is present,
instead of a shift a rotation is done.
Within each device group defined by `grid_axes`, shifts input tensors along the
device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
That is, the offset wraps around according to the group size along `shift_axis`.
Otherwise, the results on devices without a corresponding source are undefined.

Example:
```