From 0a1c67afdaa22eb9805997ee51e0f019f6614692 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank"
Date: Thu, 16 Oct 2025 15:36:27 +0200
Subject: [PATCH 1/2] improving shard docs

---
 mlir/docs/Dialects/Shard.md                   |   6 +-
 .../include/mlir/Dialect/Shard/IR/ShardOps.td | 117 ++++++++--------
 2 files changed, 69 insertions(+), 54 deletions(-)

diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md
index eb6ff6150e474..573e888e6541f 100644
--- a/mlir/docs/Dialects/Shard.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -27,9 +27,16 @@ the tensor is sharded - not specified manually.
 
 ### Device Groups
 
-Each collective operation runs within a group of devices. You define groups
-using the `grid` and `grid_axes` attributes, which describe how to slice the
-full device grid into smaller groups.
+Collective operations run within groups of devices, which are defined
+using the `grid` and `grid_axes` attributes. These describe
+how the full device grid is sliced into smaller groups.
 
 Devices that have the same coordinates *outside* the listed `grid_axes`
 belong to the same group.
 
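+For example, for a device grid of shape `2x3` and `grid_axes = [1]`, devices
+that share the same coordinate along axis 0 form one group each:
+
+```
+group of (0, *): (0, 0), (0, 1), (0, 2)
+group of (1, *): (1, 0), (1, 1), (1, 2)
+```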
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index b9d7163ea4c1e..60461b9ddc826 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
 ]> {
   let summary = "All-gather over a device grid.";
   let description = [{
-    Gathers along the `gather_axis` tensor axis.
+    Concatenates all tensor slices from a device group defined by `grid_axes` along
+    the tensor dimension `gather_axis` and replicates the result across all devices
+    in the group.
 
     Example:
     ```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
     SameOperandsAndResultShape]> {
   let summary = "All-reduce over a device grid.";
   let description = [{
-    The accumulation element type is specified by the result type and
-    it does not need to match the input element type.
-    The input element is converted to the result element type before
-    performing the reduction.
+    Reduces the input tensor across all devices within the groups defined by
+    `grid_axes`, using the specified reduction method. The operation performs an
+    element-wise reduction over the tensor slices from all devices in each group.
+    Each device in a group receives a replicated copy of the reduction result.
+    The accumulation element type is determined by the result type and does not
+    need to match the input element type. Before performing the reduction, each
+    input element is converted to the result element type.
 
     Attributes:
     `reduction`: Indicates the reduction method.
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank
 ]> {
-  let summary = "All-slice over a device grid. This is the inverse of all-gather.";
+  let summary = "All-slice over a device grid.";
   let description = [{
-    Slice along the `slice_axis` tensor axis.
-    This operation can be thought of as the inverse of all-gather.
-    Technically, it is not required that all processes have the same input tensor.
-    Each process will slice a piece of its local tensor based on its in-group device index.
-    The operation does not communicate data between devices.
+    Within each device group defined by `grid_axes`, slices the input tensor along
+    the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
+    the input data is replicated along the `slice_axis`.
+    Each process simply crops its local data to the slice corresponding to its
+    in-group device index.
+    Note that `AllSliceOp` does not involve any communication between devices, and
+    devices within a group are not required to have replicated input data.
 
     Example:
     ```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
     ```
     Result:
     ```
-    gather tensor
+    slice tensor
     axis 1
     ------------>
                      +-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
     SameOperandsAndResultRank]> {
   let summary = "All-to-all over a device grid.";
   let description = [{
-    Performs an all-to-all on tensor pieces split along `split_axis`.
-    The resulting pieces are concatenated along `concat_axis` on ech device.
+    Each participant logically splits its input along `split_axis`,
+    then scatters the resulting pieces across the group defined by `grid_axes`.
+    After receiving data pieces from other participants' scatters,
+    it concatenates them along `concat_axis` to produce the final result.
 
     Example:
     ```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
 ]> {
   let summary = "Broadcast over a device grid.";
   let description = [{
-    Broadcast the tensor on `root` to all devices in each respective group.
-    The operation broadcasts along grid axes `grid_axes`.
-    The `root` device specifies the in-group multi-index that is broadcast to
-    all other devices in the group.
+    Copies the input tensor on `root` to the all devices in each group defined by
+    `grid_axes`. The `root` device is defined by its in-group multi-index.
+    The contents of input tensors on non-root devices are ignored.
 
     Example:
     ```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
     +-------+-------+                          | broadcast
     device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)   | along axis 0
     +-------+-------+                          ↓
-    device (1, 0) -> |       |       | <- device (1, 1)
+    device (1, 0) -> |  * *  |  * *  | <- device (1, 1)
     +-------+-------+
     ```
   }];
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
 ]> {
   let summary = "Gather over a device grid.";
   let description = [{
-    Gathers on device `root` along the `gather_axis` tensor axis.
-    `root` specifies the coordinates of a device along `grid_axes`.
-    It uniquely identifies the root device for each device group.
-    The result tensor on non-root devices is undefined.
-    Using it will result in undefined behavior.
+    Concatenates all tensor slices from a device group defined by `grid_axes` along
+    the tensor dimension `gather_axis` and returns the resulting tensor on the
+    `root` device of each group. The result on all other (non-root) devices is
+    undefined. The `root` device is defined by its in-group multi-index.
 
     Example:
     ```mlir
@@ -821,7 +828,18 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
 ]> {
-  let summary = "Send over a device grid.";
+  let summary = "Receive over a device grid.";
   let description = [{
-    Receive from a device within a device group.
+    Receive a tensor from device `source`, which is defined by its in-group
+    multi-index. The groups are defined by `grid_axes`.
+    The content of the input tensor is ignored.
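+
+    Example:
+    ```
+    shard.grid @grid0(shape = 2x2)
+    ...
+    // Receive a tensor from the device with in-group index 0.
+    %1 = shard.recv %0 on @grid0 grid_axes = [1] source = [0]
+      : (tensor<2xi8>) -> tensor<2xi8>
+    ```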
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
 ]> {
   let summary = "Reduce over a device grid.";
   let description = [{
-    Reduces on device `root` within each device group.
-    `root` specifies the coordinates of a device along `grid_axes`.
-    It uniquely identifies the root device within its device group.
-    The accumulation element type is specified by the result type and
-    it does not need to match the input element type.
-    The input element is converted to the result element type before
-    performing the reduction.
+    Reduces the input tensor across all devices within the groups defined by
+    `grid_axes`, using the specified reduction method. The operation performs an
+    element-wise reduction over the tensor slices from all devices in each group.
+    The reduction result will be returned on the `root` device of each group.
+    It is undefined on all other (non-root) devices.
+    The `root` device is defined by its in-group multi-index.
+    The accumulation element type is determined by the result type and does not
+    need to match the input element type. Before performing the reduction, each
+    input element is converted to the result element type.
 
     Attributes:
     `reduction`: Indicates the reduction method.
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     SameOperandsAndResultRank]> {
   let summary = "Reduce-scatter over a device grid.";
   let description = [{
-    After the reduction, the result is scattered within each device group.
-    The tensor is split along `scatter_axis` and the pieces distributed
-    across the device group.
+    Reduces the input tensor across all devices within the groups defined by
+    `grid_axes` using the specified reduction method. The reduction is performed
+    element-wise across the tensor pieces from all devices in the group.
+    The reduction result is then scattered (split and distributed) across the
+    device group along `scatter_axis`.
 
     Example:
     ```
     shard.grid @grid0(shape = 2x2)
     ...
     %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
       reduction = <max> scatter_axis = 0
-      : tensor<3x4xf32> -> tensor<1x4xf64>
+      : tensor<2x2xf32> -> tensor<1x2xf64>
     ```
     Input:
     ```
                      +-------+-------+  | scatter tensor
     device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                      |  3  4 |  7  8 |  ↓
                      +-------+-------+
     device (1, 0) -> |  9 10 | 13 14 |
                      | 11 12 | 15 16 |
                      +-------+-------+
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     Result:
     ```
     +-------+
-    |  6  8 | <- devices (0, 0)
+    |  5  6 | <- devices (0, 0)
     +-------+
-    | 10 12 | <- devices (0, 1)
+    |  7  8 | <- devices (0, 1)
     +-------+
-    | 22 24 | <- devices (1, 0)
+    | 13 14 | <- devices (1, 0)
     +-------+
-    | 26 28 | <- devices (1, 1)
+    | 15 16 | <- devices (1, 1)
     +-------+
     ```
   }];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
 ]> {
   let summary = "Scatter over a device grid.";
   let description = [{
-    For each device group split the input tensor on the `root` device along
-    axis `scatter_axis` and scatter the parts across the group devices.
+    For each device group defined by `grid_axes`, the input tensor on the `root`
+    device is split along axis `scatter_axis` and distributed across the group.
+    The content of the input on all other (non-root) devices is ignored.
+    The `root` device is defined by its in-group multi-index.
 
     Example:
     ```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
     (0, 1)
        ↓
     +-------+-------+  | scatter tensor
-    device (0, 0) -> |       |       |  | axis 0
-                     |       |       |  ↓
+    device (0, 0) -> |  * *  |  * *  |  | axis 0
+                     |  * *  |  * *  |  ↓
     +-------+-------+
     device (1, 0) -> |  1  2 |  5  6 |
                      |  3  4 |  7  8 |
@@ -1018,7 +1033,17 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
 ]> {
   let summary = "Send over a device grid.";
   let description = [{
-    Send from one device to another within a device group.
+    Send the input tensor to device `destination`, which is defined by its
+    in-group multi-index. The groups are defined by `grid_axes`.
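+
+    Example:
+    ```
+    shard.grid @grid0(shape = 2x2)
+    ...
+    // Send the local tensor to the device with in-group index 1.
+    %1 = shard.send %0 on @grid0 grid_axes = [1] destination = [1]
+      : (tensor<2xi8>) -> tensor<2xi8>
+    ```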
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
 ]> {
   let summary = "Shift over a device grid.";
   let description = [{
-    Within each device group shift along grid axis `shift_axis` by an offset
-    `offset`.
-    The result on devices that do not have a corresponding source is undefined.
-    `shift_axis` must be one of `grid_axes`.
-    If the `rotate` attribute is present,
-    instead of a shift a rotation is done.
+    Within each device group defined by `grid_axes`, shifts the input tensors along
+    the device grid axis `shift_axis` by the specified offset. The `shift_axis` must
+    be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
+    That is, the offset wraps around according to the group size along `shift_axis`.
+    Otherwise, the results on devices without a corresponding source are undefined.
 
     Example:
     ```

From e8baffe5006e90b9904a7558835ea8e7ee64c482 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach
Date: Thu, 16 Oct 2025 15:43:59 +0200
Subject: [PATCH 2/2] Update mlir/include/mlir/Dialect/Shard/IR/ShardOps.td

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 mlir/include/mlir/Dialect/Shard/IR/ShardOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 60461b9ddc826..5e68f75ee08bf 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -711,7 +711,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
 ]> {
   let summary = "Broadcast over a device grid.";
   let description = [{
-    Copies the input tensor on `root` to the all devices in each group defined by
+    Copies the input tensor on `root` to all devices in each group defined by
     `grid_axes`. The `root` device is defined by its in-group multi-index.
     The contents of input tensors on non-root devices are ignored.