|
| 1 | +//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD |
| 10 | +#define MLIR_DIALECT_MESH_IR_MESHOPS_TD |
| 11 | + |
| 12 | +include "mlir/Dialect/Mesh/IR/MeshBase.td" |
| 13 | +include "mlir/Interfaces/InferTypeOpInterface.td" |
| 14 | +include "mlir/Interfaces/SideEffectInterfaces.td" |
| 15 | +include "mlir/IR/BuiltinTypes.td" |
| 16 | +include "mlir/IR/SymbolInterfaces.td" |
| 17 | + |
| 18 | +//===----------------------------------------------------------------------===// |
| 19 | +// Mesh Dialect operations. |
| 20 | +//===----------------------------------------------------------------------===// |
| 21 | + |
// Common base class for Mesh dialect operations: binds the op to
// `Mesh_Dialect` and forwards the mnemonic and trait list to `Op` unchanged.
class Mesh_Op<string mnemonic, list<Trait> traits = []>
    : Op<Mesh_Dialect, mnemonic, traits>;
| 25 | + |
def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
  let summary = "Representing a mesh cluster";
  let description = [{
    The mesh.cluster operation is a symbol operation that identifies a specific
    mesh cluster. The operation has three attributes:

    1. `sym_name`: This attribute uniquely identifies the name of the mesh
    cluster. This name serves as a symbolic reference to the cluster throughout
    the MLIR module, allowing for consistent referencing and easier debugging.

    2. `rank`: This attribute specifies the number of axes of the cluster. The
    rank indicates the dimensionality of the mesh cluster and can be used to
    determine the layout and the addressing space of the computation distributed
    across the mesh.

    3. `dim_sizes`: This attribute represents the device assignment along the
    axes of the cluster. Each integer in the array corresponds to the number of
    devices along a specific axis. If an integer value is 0, it implies that the
    number of devices along that axis is unknown. The array may list fewer than
    `rank` entries, or be omitted entirely; as the examples below show, any axis
    without an entry is also treated as having an unknown size. This flexibility
    allows for dynamic device assignment or configurations where the exact
    number of devices might not be determined at compile time.

    Example:
    ```
    // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
    // The dimension sizes are 4, 8, 12
    mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])

    // A device mesh cluster with 2 axes, the total device number is unknown
    // The first dimension size is 4 and the second is unknown
    mesh.cluster @mesh1(rank = 2, dim_sizes = [4])

    // A device mesh cluster with 2 axes, the total device number is unknown
    // The first dimension size is unknown and the second is 4
    mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])

    // A device mesh cluster with 2 axes, the number of devices along both axes
    // is unknown
    mesh.cluster @mesh3(rank = 2)

    // Used in the mesh sharding attribute to extend the standard tensor to
    // distributed
    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
    ```
  }];
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    I8Attr:$rank,
    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
  );
  let assemblyFormat = [{
    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
      attr-dict
  }];
  let hasVerifier = 1;
}
| 82 | + |
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
  let summary = "Annotate how a tensor is sharded across a mesh cluster";
  let description = [{
    The mesh.shard operation is designed to specify and guide the sharding
    behavior of a tensor value across a mesh topology. This operation has one
    operand and two attributes:

    1. `src`: This operand represents the tensor value that needs to be
    annotated for sharding.

    2. `shard`: This attribute is of type `MeshSharding`, which is the core data
    structure to represent a distributed tensor in a mesh cluster.

    3. `annotate_for_users`: A unit attribute addressing the scenario when a
    tensor's sharding annotation differs based on its context of use (either as
    a result or an operand). If specified, the sharding pertains to specific
    users of the tensor value, indicating how it should be considered when used
    as an operand in subsequent operations. If not, the sharding applies to the
    operation that defines the tensor value.

    Example:
    ```
    func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      ...
    }

    func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
      ...
    }

    // The first mesh.shard op applies to %arg0, the second mesh.shard op
    // applies to the operand of op0, the third mesh.shard op applies to the
    // operand of op1
    func.func @both_result_and_multi_operands_annotated(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
      %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
      "op0"(%1) : ...
      "op1"(%2) : ...
      ...
    }
    ```

    The following usages are undefined:
    ```
    func.func @annotate_on_same_result_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
      ...
    }

    func.func @annotate_on_same_result_same_value_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
      %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
      ...
    }

    func.func @annotate_on_same_operand_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
      ...
    }

    func.func @result_annotated_after_operand(
        %arg0 : tensor<4x8xf32>) -> () {
      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
      ...
    }
    ```
  }];
  let arguments = (ins
    Builtin_RankedTensor:$src,
    MeshSharding:$shard,
    UnitAttr:$annotate_for_users
  );
  let results = (outs
    Builtin_RankedTensor:$result
  );
  let assemblyFormat = [{
    $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
      type($result)
  }];
}
| 173 | + |
| 174 | +#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD |
0 commit comments