[mlir][mesh] Add collective communication operations (#71960)
Add all-gather, all-reduce, all-to-all and reduce-scatter. These
operations have device mesh semantics.
sogartar committed Nov 21, 2023
1 parent ac75171 commit 5f7c8c1
Showing 8 changed files with 1,179 additions and 3 deletions.
43 changes: 43 additions & 0 deletions mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,43 @@
# 'mesh' Dialect

The `mesh` dialect contains a set of attributes, operations and interfaces that
are useful for representing sharding and communication on a device mesh
cluster.

[TOC]

## Collective Communication Operations
There are a number of operations in the Mesh dialect to facilitate
communication between devices in a mesh.
It is assumed that the user is familiar with collective operations.
[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
explanation.
The main addition is that the collectives in this dialect have mesh
semantics.

The `mesh` attribute names the device mesh, and `mesh_axes` specifies a list of
device mesh axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices whose coordinates agree on all axes outside of `mesh_axes` are in the
same group.
For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
axes list is `[0, 1]`, then the devices are partitioned into the groups
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
Device (1, 0, 2, 4) will be in another group.
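The grouping rule above can be sketched in Python (`device_groups` is a
hypothetical helper for illustration, not part of the dialect): devices whose
coordinates match on every axis outside `mesh_axes` share a group.

```python
from itertools import product

def device_groups(mesh_shape, mesh_axes):
    """Partition all device coordinates of a mesh into disjoint groups.

    Devices whose coordinates agree on every axis *outside* mesh_axes
    fall into the same group, keyed by those outside coordinates.
    """
    groups = {}
    for coord in product(*(range(s) for s in mesh_shape)):
        key = tuple(c for ax, c in enumerate(coord) if ax not in mesh_axes)
        groups.setdefault(key, []).append(coord)
    return groups

# The 2x3x4x5 mesh partitioned over axes [0, 1] from the example.
groups = device_groups((2, 3, 4, 5), mesh_axes=[0, 1])
```

This yields 4\*5 = 20 groups of 2\*3 = 6 devices each; (1, 0, 2, 3) and
(1, 1, 2, 3) land in the same group, keyed by their shared outside
coordinates (2, 3), while (1, 0, 2, 4) does not.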
Some collective operations, like all-to-all and all-gather, care about the
order of devices.
The order of devices in a device group is induced by the order of axes in
`mesh_axes`.
The axes are ordered from outer to inner.
If we have the axis list `[3, 1]`, then device `(i, 1, k, 0)` will precede
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
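A minimal sketch of the induced in-group order (`group_order_key` is an
assumed helper for illustration, not dialect API): a device's rank within its
group is given by its coordinates on `mesh_axes`, compared lexicographically
with the outermost listed axis first.

```python
def group_order_key(coord, mesh_axes):
    # mesh_axes is ordered outer-to-inner, so earlier axes in the list
    # vary slowest: comparing these tuples lexicographically yields the
    # device order within a group.
    return tuple(coord[ax] for ax in mesh_axes)

# Axis list [3, 1] from the example: (i, 1, k, 0) precedes both
# (i, 0, k, 1) and (i, 2, k, 0).
key = lambda c: group_order_key(c, [3, 1])
```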


## Operations

[include "Dialects/MeshOps.md"]

## Attributes

[include "Dialects/MeshAttributes.md"]
8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
let cppNamespace = "::mlir::mesh";

let description = [{
The `mesh` dialect contains a set of attributes, operations, interfaces that
are useful for representing sharding and communication on device mesh
cluster.
See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
}];

let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}

def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
let assemblyFormat = "`<` $value `>`";
}

// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,12 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <algorithm>

#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"

229 changes: 229 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
@@ -77,6 +79,18 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
attr-dict
}];
let extraClassDeclaration = [{
// The `dim_sizes` attribute may have size less than the rank of the mesh.
// Returns the shape of the mesh with missing trailing dimensions
// explicitly set as dynamic.
::mlir::SmallVector<int64_t> canonicalDimSizes();

template <typename OutIt>
void canonicalDimSizes(OutIt outIt) {
std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
}
}];
let hasVerifier = 1;
}

@@ -171,4 +185,219 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}

//===----------------------------------------------------------------------===//
// Collective Communication Ops
//===----------------------------------------------------------------------===//

class Mesh_CollectiveCommunicationOpBase<
string mnemonic, list<Trait> traits = []> :
Mesh_Op<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
);
}

def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
let summary = "All-gather over a device mesh.";
let description = [{
Gathers along the `gather_axis` tensor axis.

Example:
```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
```
Input:
```
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
```
Result:
```
gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$gather_axis
));
let results = (outs
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}

def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device mesh.";
let description = [{
The accumulation element type is specified by the result type and
it does not need to match the input element type.
Input elements are converted to the result element type before
performing the reduction.

Attributes:
`reduction`: Indicates the reduction method.

Example:
```
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
```
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}

def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device mesh.";
let description = [{
Performs an all-to-all on tensor pieces split along `split_axis`.
The resulting pieces are concatenated along `concat_axis` on each device.

Example:
```
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
```
Input:
```
 device  device  device
  (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
```
Result:
```
 device  device  device
  (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$split_axis,
IndexAttr:$concat_axis
));
let results = (outs
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}

def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device mesh.";
let description = [{
After the reduction, the result is scattered within each device group.
The tensor is split along `scatter_axis` and the pieces are distributed
across the device group.

Example:
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <max> scatter_axis = 0
: tensor<2x4xf32> -> tensor<1x4xf64>
```
Input:
```
                          device
                          (0, 1)
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                          device
                          (1, 1)
```
Result:
```
+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
IndexAttr:$scatter_axis
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}
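The numbers in the diagram above (which use a sum reduction, unlike the
`<max>` in the textual example) can be reproduced for the row-0 device group
with a NumPy sketch (illustration only, not part of the dialect):

```python
import numpy as np

# Shards of devices (0, 0) and (0, 1); mesh_axes = [1] makes each mesh
# row a device group.
shard_00 = np.array([[1, 2], [3, 4]])
shard_01 = np.array([[5, 6], [7, 8]])

reduced = shard_00 + shard_01          # reduce (sum) within the group
parts = np.split(reduced, 2, axis=0)   # scatter along scatter_axis = 0
# Device (0, 0) keeps parts[0]; device (0, 1) keeps parts[1].
```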

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD