-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][mesh] Add collective communication operations #71960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.
@llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesAdd all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics. I have not included ops like reduce, gather, send and recv to see first if reviewers notice any systemic issues. Also this PR is already big enough. Patch is 43.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71960.diff 8 Files Affected:
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..6dd4f79022061ee
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,34 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have mesh
+semantics.
+The operation attributes `mesh` and `mesh_axes` specifies a set of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates outside of axes `mesh_axes` are in the
+same group.
+For example if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes set is `{0, 1}` then devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
let cppNamespace = "::mlir::mesh";
let description = [{
- The `mesh` dialect contains a set of attributes, operations, interfaces that
- are useful for representing sharding and communication on device mesh
- cluster.
+ See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
}];
let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..7698d60813a8f10 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,11 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..15354babe870599 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -77,6 +79,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
attr-dict
}];
+ let extraClassDeclaration = [{
+ ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+ template <typename OutIt>
+ void canonicalDimSizes(OutIt outIt) {
+ std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+ std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+ }
+ }];
let hasVerifier = 1;
}
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+ string mnemonic, list<Trait> traits = []> :
+ Mesh_Op<mnemonic,
+ !listconcat(traits,
+ [SymbolUserOpInterface])> {
+ let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+ code extraClassDeclarationBase = [{
+ ::mlir::LogicalResult verifySymbolUses(
+ ::mlir::SymbolTableCollection &symbolTable);
+ }];
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-gather over a device mesh.";
+ let description = [{
+ Gathers along the `gather_axis` tensor axis.
+ The order of input tensors in the resulting tensor is the same as the
+ order of the corresponding devices' multi-index in the mesh.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.all_gather %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+ } : tensor<2x2xi8> -> tensor<2x4xi8>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ APIntAttr:$gather_axis
+ );
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+ SameOperandsAndResultShape]> {
+ let summary = "All-reduce over a device mesh.";
+ let description = [{
+ The accumulation element type is specified by the result type and
+ it does not need to match the input element type.
+ The input element is converted to the result element type before
+ performing the reduction.
+
+ Attributes:
+ `reduction`: Indicates the reduction method.
+
+ Example:
+ ```
+ %1 = mesh.all_reduce %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+ } : tensor<3x4xf32> -> tensor<3x4xf64>
+ ```
+ }];
+ let arguments = (ins
+ AnyRankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
+ let summary = "All-to-all over a device mesh.";
+ let description = [{
+ Performs an all-to-all on tensor pieces split along `split_axis`.
+ The resulting pieces are concatenated along `concat_axis` on ech device.
+ Example:
+ ```
+ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+ ...
+ %1 = mesh.all_to_all %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+ } : tensor<3x6xi8> -> tensor<3x6xi8>
+ ```
+ Input:
+ ```
+ device device device
+ (0) (1) (2)
+ +-------+-------+-------+
+ | 11 12 | 21 22 | 31 32 |
+ | 13 14 | 23 24 | 33 34 |
+ | 15 16 | 25 26 | 35 36 |
+ +-------+-------+-------+
+ ```
+ Result:
+ ```
+ device device device
+ (0) (1) (2)
+ +-------+-------+-------+
+ | 11 12 | 13 14 | 15 16 |
+ | 21 22 | 23 24 | 25 26 |
+ | 31 32 | 33 34 | 35 36 |
+ +-------+-------+-------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ APIntAttr:$split_axis,
+ APIntAttr:$concat_axis
+ );
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+ SameOperandsAndResultRank]> {
+ let summary = "Reduce-scatter over a device mesh.";
+ let description = [{
+ After the reduction scatters the result within each device group.
+ The tensor is split along `scatter_axis` and the pieces distributed
+ across the device group.
+ Example:
+ ```
+ mesh.cluster @mesh0(rank = 1, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.reduce_scatter %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<max>, scatter_axis = 0
+ } : tensor<3x4xf32> -> tensor<1x4xf64>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------+
+ | 6 8 | <- devices (0, 0)
+ +-------+
+ | 10 12 | <- devices (0, 1)
+ +-------+
+ | 22 24 | <- devices (1, 0)
+ +-------+
+ | 26 28 | <- devices (1, 1)
+ +-------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+ APIntAttr:$scatter_axis
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..6efc4c4ecc326ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,26 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +37,60 @@ using namespace mlir::mesh;
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+namespace {
+
+template <typename It>
+It canonicalizeSetAsArray(It begin, It end) {
+ std::sort(begin, end);
+ return std::unique(begin, end);
+}
+
+template <typename R>
+auto canonicalizeSetAsArray(R &&range) {
+ return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
+}
+
+template <typename T>
+SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+ auto newEnd = canonicalizeSetAsArray(vec);
+ vec.resize(newEnd - vec.begin());
+ return vec;
+}
+
+template <typename DimSize>
+bool isMeshDimensionDynamic(DimSize size) {
+ return size <= DimSize(0);
+}
+
+using MeshAxis = int16_t;
+
+struct DimensionSize {
+ static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
+ DimensionSize(int64_t val) : val(val) {}
+ int64_t value() const { return val; }
+ operator int64_t() const { return val; }
+ bool isDynamic() const { return ShapedType::isDynamic(val); }
+
+private:
+ int64_t val;
+};
+
+DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+ if (lhs.isDynamic() || rhs.isDynamic()) {
+ return DimensionSize::dynamic();
+ }
+ return lhs.value() / rhs.value();
+}
+
+DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+ if (lhs.isDynamic() || rhs.isDynamic()) {
+ return DimensionSize::dynamic();
+ }
+ return lhs.value() * rhs.value();
+}
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
return success();
}
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+ SmallVector<int64_t> result;
+ canonicalDimSizes(std::back_inserter(result));
+ return result;
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+ if (!attr) {
+ return std::nullopt;
+ }
+ SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
+ canonicalizeSetAsVector(axes);
+ if (axes.empty()) {
+ return std::nullopt;
+ }
+ return DenseI16ArrayAttr::get(attr.getContext(), axes);
+}
+
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+ AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+ : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+ LogicalResult matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const override {
+ auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(
+ op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
+ if (!canonicalMeshAxesAttr) {
+ op->removeAttr(axisSetAttr);
+ } else {
+ op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
+ }
+ return success();
+ }
+
+ std::string axisSetAttr;
+};
+
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
+}
+
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+ FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+ if (!symbolAttr) {
+ return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+ }
+ SymbolTableCollection symbolTableCollection;
+ mesh::ClusterOp mesh =
+ symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op.getOperation(), symbolAttr);
+ if (!mesh) {
+ return op.emitError() << "Undefined required mesh symbol \""
+ << symbolAttr.getValue() << "\".";
+ }
+ DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+ if (!meshAxes) {
+ return success();
+ }
+ MeshAxis rank = mesh.getRank();
+ for (auto axis : meshAxes.asArrayRef()) {
+ if (axis >= rank || axis < 0) {
+ return op.emitError()
+ << "0-based mesh axis index " << axis
+ << " is out of bounds. The referenced mesh \""
+ << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+ }
+ }
+
+ return success();
+}
+
+template <typename It>
+auto product(It begin, It end) {
+ using ElementType = std::decay_t<decltype(*begin)>;
+ return std::accumulate(begin, end, ElementType(1),
+ std::multiplies<ElementType>());
+}
+
+template <typename R>
+auto product(R &&range) {
+ return product(adl_begin(range), adl_end(range));
+}
+
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ int64_t res = 1;
+ for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
+ if (llvm::find(meshAxes, axis) == meshAxes.end()) {
+ continue;
+ }
+ if (isMeshDimensionDynamic(meshShape[axis])) {
+ return ShapedType::kDynamic;
+ }
+ res *= meshShape[axis];
+ }
+ return res;
+}
+
+LogicalResult verifyDimensionCompatibility(Location loc,
+ int64_t expectedDimSize,
+ int64_t resultDimSize,
+ int64_t resultAxis) {
+ if (!ShapedType::isDynamic(resultDimSize) &&
+ expectedDimSize != resultDimSize) {
+ return emitError(loc) << "Dimension size mismatch for result axis "
+ << resultAxis << ". Expected "
+ << (ShapedType::isDynamic(expectedDimSize)
+ ? Twine("dynamic")
+ : Twine(expectedDimSize))
+ << ", but got " << resultDimSize << ".";
+ }
+
+ return success();
+}
+
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+ int64_t gatherAxis,
+ ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ ShapedType operandType = operand.getType().cast<ShapedType>();
+ ShapedType resultType = result.getType().cast<ShapedType>();
+ auto deviceGroupSize =
+ DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+ for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+ auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
+ auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
+ auto expectedResultDimSize =
+ axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+ SymbolTableCollection symbolTableCollection;
+ if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+ // We need to check the symbol here since this runs before
+ // SymbolUserOpInterface.
+ return failure();
+ }
+ return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op.getOperation(), op.getMeshAttr());
+}
+
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+ auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
+ auto gatherAxis = op.getGatherAxis().getSExtValue();
+ if (gatherAxis < 0 || gatherAxis >= rank) {
+ return op.emitError() << "Gather axis " << gatherAxis
+ << " is out of bounds [0, " << rank << ").";
+ }
+
+ auto mesh = getMesh(op);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyGatherOperandAndResultShape(op.getOperand(), op.get...
[truncated]
|
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() { | |||
return success(); | |||
} | |||
|
|||
SmallVector<int64_t> ClusterOp::canonicalDimSizes() { | |||
SmallVector<int64_t> result; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SmallVector<int64_t> result; | |
SmallVector<int64_t> result; | |
result.reserve(getRank()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
template <typename It> | ||
auto product(It begin, It end) { | ||
using ElementType = std::decay_t<decltype(*begin)>; | ||
return std::accumulate(begin, end, ElementType(1), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to suggest to use static_cast<ElementType>(1)
in case some classes use explicit constructor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes, | ||
ArrayRef<int64_t> meshShape) { | ||
int64_t res = 1; | ||
for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on that there's no duplication in meshAxes
, it could change to
for (MeshAxis axis : meshAxes) {
if (isMeshDimensionDynamic(meshShape[axis])) {
return ShapedType::kDynamic;
}
res *= meshShape[axis];
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
return success(); | ||
} | ||
|
||
LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you make all the ccl ops inherit InferShapedTypeOpInterface
and implement verification based on it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is the problem that InferShapedTypeOpInterface
does not assume that the operation is created. You only have the operands (Value
s). But we need to have the operation to find the referenced CusterOp
. The resolution of the referenced ClusterOp
is dependent on which block the collective operation is defined in.
I am not sure how to resolve this yet. It seems shape/type inference is unequipped to deal with this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you get it referenced ClusterOp
by using symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(operands[0].getParentRegion()->getParentOp(), meshAttr)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically the operand can be defined higher up the tree then the cluster op.
Something like
{
%operand = ...
{
mesh.cluster ...
%res = mesh.all_reduce (%operand) ...
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure what are the restrictions on symbol resolution. We may have to impose additional restrictions on the mesh operations to be able to use this.
It is likely no one else has ran into this problem. func.call
for example does not have type inference. It would have hit this issue if it did.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
func.call
doesn't need to have type inference because its inference logic depends on its body.
@joker-eph do you have any suggestion? The problem is that here we need to loopup the symbol without the operation.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
namespace { | ||
|
||
std::optional<DenseI16ArrayAttr> | ||
canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you regard mesh_axis
as a set? I think the duplication should not be allowed in verification.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking to change its meaning. In the Mesh.md doc it is stated that the order of devices is the same as the order of their multi-index in the mesh. This would make it a set.
We can change the meaning where the order of devices could be induced by the order in mesh_axes
. For example if mesh_axes = [2, 1]
then device (i, j, k)
would precede device (i, l, m)
if k > m
. This would influence operations like all-gather where order device order is important.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But it seems your reply cannot explain why we have to make mesh_axes
support duplications like [2, 1, 2]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to use a set but there is no such attribute, so I went for an array. In math the set {2, 1, 2}
is the same as {2, 1}
. I can see why you want to forbid duplication as it may result from a programming error where the intention was something else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think not-accepting {2, 1, 2}
could make the semantic more clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
I also change the meaning of mesh_axes
to induce the order of devices in the device groups. I added documentation about this.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
} | ||
|
||
template <typename Op> | ||
LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you make verifyMeshSymbolUses
a non-template function? I'm afraid of code bloat.
The signature could change to verifyMeshSymbolUses(Operation*, FlatSymbolRefAttr, DenseI16ArrayAttr, SymbolTableCollection &)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think extracting attributes from the concrete operation type through class methods like op.getMeshAttr()
does not incur attribute lookup. Also there is some benefit to inlining this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you think this is not that important I can change it to accept Operation*
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might be a bit nit-picky. It's a good point that inlining + template can improve performance at the expense of binary size.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also add the specific attributes likes getMeshAttr()
to the API of the non-templated function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made getMesh
and verifyMeshSymbolUses
non-template.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
} | ||
|
||
template <typename Op> | ||
FailureOr<ClusterOp> getMesh(Op op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you change getMesh
to a non-template function to avoid code bloat?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same reasoning as in the comment about verifyMeshSymbolUses
. Let's discuss there.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
return success(); | ||
} | ||
|
||
LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you change name to verifyAllGatherOperandAndResultShape
because both gather
and all_gather
exist in ccl.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking ahead that the verification of gather and all-gather will be the same. I renamed it to verifyAllGatherOperandAndResultShape
to avoid the confusion.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
} | ||
|
||
template <typename Op> | ||
LogicalResult verifyGather(Op op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you change name to verifyAllGather
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
|
||
LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) { | ||
SmallVector<MeshAxis> sorted = llvm::to_vector(axes); | ||
std::sort(sorted.begin(), sorted.end()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to suggest to check std::set(axes.begin(), axes.end()).size() == axes.size()
instead so we don't need to sort axes. And the code will be more concise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually construction of std::set
would be more expensive then doing a copy + sort + linear search for consecutive duplicate elements. I have not measured it, but my intuition is that for small numbers like < 10 this would be true.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sense, but I'd like suggest to use std::unique(begin, end) == end
instead of implementing isUnique
ourselves.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would even speculate that for all size a vector would be faster.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, this is fine to me.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
RewritePatternSet &patterns, MLIRContext *context) { | ||
populateMeshAxesSetCanonicalizationPatterns<AllReduceOp>(patterns, context); | ||
LogicalResult mlir::mesh::AllReduceOp::verify() { | ||
return verifyMeshAxes(getLoc(), getMeshAxes()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should verify the shape of all_reduce.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has the trait SameOperandsAndResultShape
. It should take care of that verification.
What is not verified here is whether the referenced mesh has the required rank by the mesh_axes
attribute. I will add it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never mind, I actually do the check in verifyMeshSymbolUses
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean this check?
if (!meshAxes) {
return success();
}
But an ArrayAttr with an empty array is not nullptr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed this check since I made the mesh_axes
attribute non-optional. It still has a default value.
return %0 : tensor<4xf64> | ||
} | ||
|
||
// CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If empty mesh_axes
is not allowed, we'd better check it in theverifyMeshAxes
method.
@sogartar seems you missed the 9 hidden coversations. |
::mlir::SmallVector<int64_t> canonicalDimSizes(); | ||
|
||
template <typename OutIt> | ||
void canonicalDimSizes(OutIt outIt) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you document these methods? What is the "canonical" form?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyNon0RankedTensor:$input, | ||
APIntAttr:$gather_axis |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I32Attr?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed all tensor axes to IndexAttr
on all occasions.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
SymbolTableCollection symbolTableCollection; | ||
mesh::ClusterOp mesh = | ||
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>( | ||
op.getOperation(), symbolAttr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will be costly, what about verifySymbolUses
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is called form verifySymbolUses
and getMesh
. We need the mesh shape for verification. Do you mean to put all verification related to the mesh symbol in verifySymbolUses
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes: verifySymbolUses
is meant to be able to build a symbol table once and forall to avoid repeated traversal of the IR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All verifications are now in verifySymbolUses
.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
return failure(); | ||
} | ||
return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>( | ||
op.getOperation(), op.getMeshAttr()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(same as above)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All verifications are now in verifySymbolUses
.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
} | ||
|
||
template <typename Op> | ||
LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also add the specific attributes likes getMeshAttr()
to the API of the non-templated function?
Improve ops doc.
It is non-optional now.
Thanks, LGTM |
Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.
I have not included ops like reduce, gather, send and recv to see first if reviewers notice any systemic issues. Also this PR is already big enough.