-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][mesh] Add all-scatter operation #81218
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir-core Author: Boian Petkantchin (sogartar) Changes: This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic. Add lowering for the op that slices from the tensor based on the in-group process index. Make resharding generate an all-scatter instead of inserting the slicing logic directly. Patch is 55.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81218.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
class DominanceInfo;
class Operation;
class PostDominanceInfo;
+class ImplicitLocOpBuilder;
namespace func {
class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
+// Generate IR that extracts the linear index form a multi-index according to
+// a shape.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> shape,
+ ImplicitLocOpBuilder &builder);
/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
namespace mlir {
@@ -81,6 +82,22 @@ struct ArithBuilder {
OpBuilder &b;
Location loc;
};
+
+namespace arith {
+
+// Build the product of a sequence.
+// If values = (v0, v1, ..., vn) than the returned
+// value is v0 * v1 * ... * vn.
+// All values must have the same type.
+//
+// The version without `resultType` must contain at least one element in values.
+// Then the result will have the same type as the elements in `values`.
+// If `values` is empty in the version with `resultType` returns 1 with type
+// `resultType`.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+ Type resultType);
+} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+ return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+ mesh.getShape());
+}
+
// Get the size of a sharded dimension.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
let builders = [
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+ Pure,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+ let description = [{
+ Scatter along the `scatter_axis` tensor axis.
+ This operation can be taught of as the inverse of all-gather.
+ Technically, it is not required that all processes have the same input tensor.
+ Each process will slice a piece of its local tensor based on its in-group device index.
+ The operation does not communicate data between devices.
+
+ Example:
+ ```mlir
+ mesh.mesh @mesh0(shape = 2x2)
+ ...
+ %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+ : tensor<2x4xi8> -> tensor<2x2xi8>
+ ```
+ Input:
+ ```
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ Result:
+ ```
+ gather tensor
+ axis 1
+ ------------>
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ }];
+ let arguments = !con(commonArgs, (ins
+ AnyNon0RankedTensor:$input,
+ IndexAttr:$scatter_axis
+ ));
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let assemblyFormat = [{
+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+ attr-dict `:` type($input) `->` type($result)
+ }];
+ let hasCanonicalizer = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+ OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+ ];
+}
+
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Pure,
SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
namespace mlir {
class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
+class ImplicitLocOpBuilder;
namespace mesh {
void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+void allScatterOpLoweringPopulatePatterns(
+ RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+void populateAllOpLoweringPatterns(
+ RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+ ImplicitLocOpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
// ranked tensor of them. Returns null attribute otherwise.
TypedAttr getZeroAttr(Type type);
+ // Returns a 1-valued attribute of the given `type`.
+ // Type constraints are the same as `getZeroAttr`.
+ TypedAttr getOneAttr(Type type);
// Convenience methods for fixed types.
FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
results.push_back(residual);
return results;
}
+
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> shape,
+ ImplicitLocOpBuilder &builder) {
+ assert(multiIndex.size() == shape.size());
+ SmallVector<AffineExpr> shapeAffine;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+ }
+
+ SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(stridesAffine.size());
+ llvm::transform(stridesAffine, std::back_inserter(strides),
+ [&builder, &shape](AffineExpr strideExpr) {
+ return affine::makeComposedFoldedAffineApply(
+ builder, builder.getLoc(), strideExpr, shape);
+ });
+
+ auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+ OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+ return affine::makeComposedFoldedAffineApply(
+ builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
using namespace mlir;
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
}
+
+namespace mlir::arith {
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+ return createProduct(builder, loc, values, values.front().getType());
+}
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+ Type resultType) {
+ Value one = builder.create<ConstantOp>(loc, resultType,
+ builder.getOneAttr(resultType));
+ ArithBuilder arithBuilder(builder, loc);
+ return std::accumulate(
+ values.begin(), values.end(), one,
+ [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshOp mesh) {
+ build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ MeshOp mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
- mesh.getSymName(),
- MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+ SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+ odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
+ assert(!axes.empty());
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
return success();
}
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+ ArrayRef<MeshAxis> meshAxes,
+ int64_t scatterAxis) {
+ RankedTensorType operandRankedTensorType =
+ operandType.cast<RankedTensorType>();
+ DimensionSize operandScatterAxisSize =
+ operandRankedTensorType.getShape()[scatterAxis];
+ SmallVector<int64_t> resultShape =
+ llvm::to_vector(operandRankedTensorType.getShape());
+
+ resultShape[scatterAxis] =
+ operandScatterAxisSize /
+ DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+ return operandRankedTensorType.clone(resultShape);
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_gather op
//===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyScatterOperandAndResultShape(
+ getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+ mesh.value().getShape());
+}
+
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+ int64_t scatterAxis) {
+ Type resultType =
+ scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+ build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+ scatterAxis);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Type resultType, Value input, StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+ build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+ APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
LINK_LIBS PUBLIC
MLIRAffineDialect
+ MLIRAffineUtils
MLIRArithDialect
+ MLIRArithUtils
MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include <iterator>
#include <numeric>
#include <utility>
@@ -56,13 +56,10 @@ namespace {
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
- template <typename... OpRewritePatternArgs>
- MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
- OpRewritePatternArgs &&...opRewritePatternArgs)
- : OpRewritePattern(
- std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
- symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+ : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+ using OpRewritePatternWithSymbolTableCollection::
+ OpRewritePatternWithSymbolTableCollection;
LogicalResult matchAndRewrite(MeshShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
return success();
}
-
-private:
- SymbolTableCollection &symbolTableCollection;
};
} // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
- int64_t splitTensorAxis,
- int64_t splitCount) {
- SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
- targetShape[splitTensorAxis] =
- shardDimension(targetShape[splitTensorAxis], splitCount);
- return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
// Split a replicated tensor along a mesh axis.
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- MLIRContext *ctx = builder.getContext();
- builder.setInsertionPointAfterValue(sourceShard);
-
- Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
- Value processIndexAlongAxis =
+ TypedValue<ShapedType> targetShard =
builder
- .create<ProcessMultiIndexOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
- .getResult()[0];
-
+ .create<AllScatterOp>(sourceShard, mesh,
+ ArrayRef<MeshAxis>(splitMeshAxis),
+ splitTensorAxis)
+ .ge...
[truncated]
|
I am not sure what the best name for this operation is. Scatter is the inverse of gather. It makes sense that the inverse of all-gather would be all-scatter. But then all-scatter implies that all devices scatter their tensor, when we are actually slicing. |
@yaochengji, can you review this? |
Here is a bit more context from a previous PR |
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ | |||
let hasCanonicalizer = 1; | |||
} | |||
|
|||
def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if it's appropriate to include scatter in the operation's name, as scatter typically implies that the operation involves communication.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. Maybe `all_slice` or something like that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed it to `all_slice`.
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); }); | ||
} | ||
|
||
} // namespace mlir::arith |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: new line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added one. I thought the formatter would take care of this.
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs); | |||
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc, | |||
Value linearIndex, | |||
ArrayRef<Value> basis); | |||
// Generate IR that extracts the linear index form a multi-index according to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: s/form/from/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs); | |||
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc, | |||
Value linearIndex, | |||
ArrayRef<Value> basis); | |||
// Generate IR that extracts the linear index form a multi-index according to | |||
// a shape. | |||
OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`linearizeIndex`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed it. `delinearizeIndex` above it uses shape/basis as well. I changed the argument to basis for consistency.
There is another function, `computeLinearIndex`, in indexing utils that uses stride instead of basis.
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ | |||
let hasCanonicalizer = 1; | |||
} | |||
|
|||
def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. Maybe `all_slice` or something like that?
let summary = "All-scatter over a device mesh. This is the inverse of all-gather."; | ||
let description = [{ | ||
Scatter along the `scatter_axis` tensor axis. | ||
This operation can be taught of as the inverse of all-gather. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: s/taught/thought/ ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry); | ||
|
||
void allScatterOpLoweringPopulatePatterns( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`populateAllScatterOpLoweringPatterns` to follow the naming convention?
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
ArrayRef<MeshAxis> meshAxes, | ||
int64_t scatterAxis) { | ||
RankedTensorType operandRankedTensorType = | ||
operandType.cast<RankedTensorType>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This style of casting is deprecated? Prefer to use `cast<RankedTensorType>(operandType)`.
@@ -46,6 +51,7 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> { | |||
builder.setInsertionPointAfter(op.getOperation()); | |||
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh); | |||
ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults(); | |||
llvm::errs() << "meshShape.size() = " << meshShape.size() << "\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Debug leftover?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. I removed it.
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), | ||
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); | ||
|
||
// extract slice |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be nice to flesh out the comment here a bit, explaining the indexing logic below--easier for readers to follow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added some documentation at the beginning of the function describing the algorithm.
This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic. Add lowering for the op that slices from the tensor based on the in-group process index. Make resharding generate an all-slice instead of inserting the slicing logic directly.
145ca9e
to
729d75b
Compare
Thank you, I squashed and rebased before merging. |
This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic.
Add lowering for the op that slices from the tensor based on the in-group process index.
Make resharding generate an all-scatter instead of inserting the slicing logic directly.