Commit 08545e8

[MLIR] Add a new Mesh dialect (#68007)
This is the first PR of the [Mesh sharding RFC](https://discourse.llvm.org/t/open-mlir-meeting-9-28-2023-rfc-sharding-framework-design-for-device-mesh/73695). It includes:

- the `mesh.cluster` op
- the `mesh.shard` op (the `mesh.annotate` op in the RFC slides; the name was adjusted on @stellaraccident's advice, as it is a bit more concise)
- the `MeshSharding` attribute
1 parent ebea930 commit 08545e8

File tree

14 files changed: +686, −1 lines changed


mlir/include/mlir/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(Mesh)
 add_subdirectory(MLProgram)
 add_subdirectory(NVGPU)
 add_subdirectory(OpenACC)
mlir/include/mlir/Dialect/Mesh/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
add_subdirectory(IR)
mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
add_mlir_dialect(MeshOps mesh)
add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc)

set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
add_mlir_doc(MeshOps MeshAttributes Dialects/ -gen-attrdef-doc)

set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
#define MLIR_DIALECT_MESH_IR_MESHBASE_TD

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// Mesh Dialect
//===----------------------------------------------------------------------===//

def Mesh_Dialect : Dialect {
  let name = "mesh";
  let cppNamespace = "::mlir::mesh";

  let description = [{
    The `mesh` dialect contains a set of attributes, operations, interfaces that
    are useful for representing sharding and communication on device mesh
    cluster.
  }];

  let dependentDialects = [
    "arith::ArithDialect" // For materializeConstant()
  ];

  let useDefaultAttributePrinterParser = 1;
  let hasConstantMaterializer = 1;
}
//===----------------------------------------------------------------------===//
// Mesh Enums.
//===----------------------------------------------------------------------===//

def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
  I32EnumAttrCase<"Sum", 1, "sum">,
  I32EnumAttrCase<"Max", 2, "max">,
  I32EnumAttrCase<"Min", 3, "min">,
  I32EnumAttrCase<"Generic", 100, "generic">
]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::mesh";
}

//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//

def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
  let mnemonic = "shard";

  let parameters = (ins
    AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
    ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
    OptionalArrayRefParameter<"int8_t">:$partial_axes,
    OptionalParameter<"::mlir::mesh::Partial">:$partial_type
  );

  let summary = "Attribute that extends tensor type to distributed tensor type.";

  let description = [{
    The MeshSharding attribute could be used in the encoding of a
    `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:

    1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
    cluster where the distributed tensor is placed. The symbol must resolve to a
    `mesh.cluster` operation.

    2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
    its value is [x, y], it indicates that the tensor's i-th dimension is splitted
    along the x and y axes of the device mesh.

    3. `partial_axes`: if not empty, this signifies that the tensor is partial
    one along the specified mesh axes. An all-reduce should be applied to obtain
    the complete tensor, with reduction type being specified by `partial_type`.

    4. `partial_type`: indicates the reduction type of the possible all-reduce
    op. It has 4 possible values:
    - `partial_sum`: denotes it's an all-reduce-sum
    - `partial_max`: denotes it's an all-reduce-max
    - `partial_min`: denotes it's an all-reduce-min
    - `partial_generic`: denotes that the all-reduce type is complex and cannot
    be represented merely by a simple sum, max, or min. The exact reduction
    computation may be derived from the semantics of the corresponding operation
    or from the reduction computation IR

    Example:

    ```
    mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])

    // The tensor is fully replicated on @mesh0.
    // Currently, there must be at least one sub-array present in axes, even
    // if it's empty. Otherwise, a parsing error will occur.
    tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>

    // The tensor is sharded on the first dimension along axis 0 of @mesh0
    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>

    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
    // it is also a partial_sum along mesh axis 1.
    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>

    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
    // it is also a partial_max along mesh axis 1.
    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>

    // Could be used in the attribute of mesh.shard op
    %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
    ```
  }];
  let assemblyFormat = [{
    `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
        $partial_axes^ `]`)? `>`
  }];

  let genVerifyDecl = 1;
}

#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
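For readers wiring this attribute up from C++, below is a minimal sketch of how the tablegen-generated `MeshShardingAttr` might be built to match the textual form `#mesh.shard<@mesh0, [[0]], partial = max[1]>` from the example above. It assumes the default generated `get` builder takes the parameters in the order declared in `MeshBase.td`; the exact generated signature is not shown in this commit, so verify against `MeshOpsAttributes.h.inc`.

```
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper: builds #mesh.shard<@mesh0, [[0]], partial = max[1]>.
// Assumes the generated MeshShardingAttr::get(ctx, cluster, split_axes,
// partial_axes, partial_type) builder; check the generated header.
static mesh::MeshShardingAttr makeSharding(MLIRContext *ctx) {
  // Refers to a `mesh.cluster @mesh0` symbol defined elsewhere in the module.
  auto cluster = SymbolRefAttr::get(ctx, "mesh0");
  // Tensor dim 0 is split along mesh axis 0; remaining dims are replicated.
  SmallVector<DenseI8ArrayAttr> splitAxes = {DenseI8ArrayAttr::get(ctx, {0})};
  // The value is additionally partial (here: max-reduced) along mesh axis 1.
  SmallVector<int8_t> partialAxes = {1};
  return mesh::MeshShardingAttr::get(ctx, cluster, splitAxes, partialAxes,
                                     mesh::Partial::Max);
}
```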
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
#define MLIR_DIALECT_MESH_IR_MESHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"

#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD

include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// Mesh Dialect operations.
//===----------------------------------------------------------------------===//

class Mesh_Op<string mnemonic, list<Trait> traits = []> :
    Op<Mesh_Dialect, mnemonic, traits> {
}

def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
  let summary = "representing a mesh cluster";
  let description = [{
    The mesh.cluster operation is a symbol operation that identifies a specific
    mesh cluster. The operation has three attributes:

    1. `sym_name`: This attribute uniquely identifies the name of the mesh
    cluster. This name serves as a symbolic reference to the cluster throughout
    the MLIR module, allowing for consistent referencing and easier debugging.

    2. `rank`: This attribute specifies the number of axes of the cluster. The
    rank indicates the dimensionality of the mesh cluster and can be used to
    determine the layout and the addressing space of the computation distributed
    across the mesh.

    3. `dim_sizes`: This attribute represents the device assignment along the
    axes of the cluster. Each integer in the array corresponds to the number of
    devices along a specific axis. If an integer value is 0, it implies that the
    number of devices along that axis is unknown. This flexibility allows for
    dynamic device assignment or configurations where the exact number of
    devices might not be determined during compile time.

    Example:
    ```
    // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
    // The dimension sizes are 4, 8, 12
    mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])

    // A device mesh cluster with 2 axes, the total device number is unknown
    // The first dimension size is 4 and the second is unknown
    mesh.cluster @mesh1(rank = 2, dim_sizes = [4])

    // A device mesh cluster with 2 axes, the total device number is unknown
    // The first dimension size is unknown and the second is 4
    mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])

    // A device mesh cluster with 2 axes, the number of devices along both axes
    // is unknown
    mesh.cluster @mesh3(rank = 2)

    // Used in the mesh sharding attribute to extend the standard tensor to
    // distributed
    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
    ```
  }];
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    I8Attr:$rank,
    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
  );
  let assemblyFormat = [{
    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
    attr-dict
  }];
  let hasVerifier = 1;
}

def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
  let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
  let description = [{
    The mesh.shard operation is designed to specify and guide the sharding
    behavior of a tensor value across a mesh topology. This operation has one
    operand and two attributes:

    1. `input`: This operand represents the tensor value that needs to be
    annotated for sharding.

    2. `shard`: This attribute is type of `MeshSharding`, which is the core data
    structure to represent distributed tensor in mesh cluster.

    3. `annotate_for_users`: A unit attribute addressing the scenario when a
    tensor's sharding annotation differs based on its context of use (either as
    a result or an operand). If specified, the sharding pertains to specific
    users of the tensor value, indicating how it should be considered when used
    as an operand in subsequent operations. If not, the sharding applies to the
    operation that defines the tensor value.

    Example:
    ```
    func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      ...
    }

    func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
      ...
    }

    // The first mesh.shard op applies to %arg0, the second mesh.shard op
    // applies for the operand of op0, the third mesh.shard op applies for the
    // operand of op2
    func.func @both_result_and_multi_operands_annotated(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
      %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
      "op0"(%1) : ...
      "op1"(%2) : ...
      ...
    }
    ```

    The following usages are undefined:
    ```
    func.func @annotate_on_same_result_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
      ...
    }

    func.func @annotate_on_same_result_same_value_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
      ...
    }

    func.func @annotate_on_same_operand_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
      ...
    }

    func.func @result_annotated_after_operand(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
      ...
    }
    ```
  }];
  let arguments = (ins
    Builtin_RankedTensor:$src,
    MeshSharding:$shard,
    UnitAttr:$annotate_for_users
  );
  let results = (outs
    Builtin_RankedTensor:$result
  );
  let assemblyFormat = [{
    $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
    type($result)
  }];
}

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
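As a quick smoke test of the two new ops from C++, one plausible approach is to parse a small module and print it back through the custom assembly formats defined above. This is a minimal sketch, not part of this commit: the textual IR simply mirrors the examples in the op descriptions, `parseSourceString` is the standard MLIR parser entry point, and error handling is elided.

```
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  // Load the dialects the snippet uses: mesh for the new ops, func for func.func.
  ctx.loadDialect<mlir::mesh::MeshDialect, mlir::func::FuncDialect>();

  const char *ir = R"mlir(
    mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
    func.func @f(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      return %0 : tensor<4x8xf32>
    }
  )mlir";

  // Round-trip: parse the module, then print it back using the declarative
  // assembly formats from MeshOps.td.
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, &ctx);
  if (module)
    module->print(llvm::outs());
  return 0;
}
```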

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 13 additions & 1 deletion
@@ -350,7 +350,8 @@ template <typename AsmPrinterT, typename T,
           !std::is_convertible<T &, Attribute &>::value &&
           !std::is_convertible<T &, ValueRange>::value &&
           !std::is_convertible<T &, APFloat &>::value &&
-          !llvm::is_one_of<T, bool, float, double>::value,
+          !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
+                           double>::value,
           T> * = nullptr>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
@@ -366,6 +367,17 @@ operator<<(AsmPrinterT &p, bool value) {
   return p << (value ? StringRef("true") : "false");
 }
 
+/// Specialization for 8-bit integers to ensure values are printed as integers
+// and not characters.
+template <
+    typename AsmPrinterT, typename T,
+    std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+                        AsmPrinterT &>
+operator<<(AsmPrinterT &p, T value) {
+  return p << static_cast<int16_t>(value);
+}
+
 template <typename AsmPrinterT, typename ValueRangeT>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>

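The cast through `int16_t` matters because streaming an `int8_t`/`uint8_t` into a character-oriented output otherwise prints the byte as a character rather than a number. A standalone C++ illustration of that pitfall, using plain `std::ostream` rather than the AsmPrinter itself:

```
#include <cstdint>
#include <iostream>

int main() {
  int8_t axis = 65;
  // Without a cast, the 8-bit value is treated as a char and prints "A".
  std::cout << axis << "\n";
  // Widening first prints the numeric value "65", which is what the
  // AsmPrinter specialization above achieves with static_cast<int16_t>.
  std::cout << static_cast<int16_t>(axis) << "\n";
  return 0;
}
```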
mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -117,6 +118,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   LLVM::LLVMDialect,
                   math::MathDialect,
                   memref::MemRefDialect,
+                  mesh::MeshDialect,
                   ml_program::MLProgramDialect,
                   nvgpu::NVGPUDialect,
                   NVVM::NVVMDialect,
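Downstream tools that do not call `registerAllDialects` can register the new dialect directly. A minimal sketch (tool boilerplate omitted), using the same `mesh::MeshDialect` insertion that `registerAllDialects` now performs:

```
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::DialectRegistry registry;
  // Make the mesh dialect available to contexts built from this registry.
  registry.insert<mlir::mesh::MeshDialect>();

  mlir::MLIRContext context(registry);
  // Loading the dialect makes mesh ops and #mesh.shard parsing usable now.
  context.loadDialect<mlir::mesh::MeshDialect>();
  return 0;
}
```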
