diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index b3ccbff3002fb..67c7a964feefd 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -15,12 +15,14 @@ #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/OpDefinition.h" #include <optional> namespace mlir { class DominanceInfo; class Operation; class PostDominanceInfo; +class ImplicitLocOpBuilder; namespace func { class FuncOp; @@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs); FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef<Value> basis); +// Generate IR that extracts the linear index from a multi-index according to +// a basis/shape. +OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex, + ArrayRef<OpFoldResult> basis, + ImplicitLocOpBuilder &builder); /// Ensure that all operations that could be executed after `start` /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index 402bd196f0736..2111a7c581029 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -20,6 +20,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "llvm/ADT/ArrayRef.h" namespace mlir { @@ -81,6 +82,22 @@ struct ArithBuilder { OpBuilder &b; Location loc; }; + +namespace arith { + +// Build the product of a sequence. +// If values = (v0, v1, ..., vn) then the returned +// value is v0 * v1 * ... * vn. +// All values must have the same type. +// +// The version without `resultType` must contain at least one element in values. +// Then the result will have the same type as the elements in `values`. +// If `values` is empty in the version with `resultType`, returns 1 with type +// `resultType`.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values); +Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, + Type resultType); +} // namespace arith } // namespace mlir #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 9154e6fd80310..fb9425b96e68e 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, return res; } +template <typename MeshAxesRange> +int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) { + return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes), + mesh.getShape()); +} + // Get the size of a sharded dimension. inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) { if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount)) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index da372706ec724..96636d5347ff6 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [ let builders = [ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>, + OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>, OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)> ]; } @@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ let hasCanonicalizer = 1; } +def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [ + Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultRank + ]> { + let summary = "All-slice over a device mesh. This is the inverse of all-gather."; + let description = [{ + Slice along the `slice_axis` tensor axis. + This operation can be thought of as the inverse of all-gather.
+ Technically, it is not required that all processes have the same input tensor. + Each process will slice a piece of its local tensor based on its in-group device index. + The operation does not communicate data between devices. + + Example: + ```mlir + mesh.mesh @mesh0(shape = 2x2) + ... + %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1 + : tensor<2x4xi8> -> tensor<2x2xi8> + ``` + Input: + ``` + +-------------+ + | 1 2 5 6 | <- devices (0, 0) and (0, 1) + | 3 4 7 8 | + +-------------+ + | 9 10 13 14 | <- devices (1, 0) and (1, 1) + | 11 12 15 16 | + +-------------+ + ``` + Result: + ``` + gather tensor + axis 1 + ------------> + +-------+-------+ + device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1) + | 3 4 | 7 8 | + +-------+-------+ + device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1) + | 11 12 | 15 16 | + +-------+-------+ + ``` + }]; + let arguments = !con(commonArgs, (ins + AnyNon0RankedTensor:$input, + IndexAttr:$slice_axis + )); + let results = (outs + AnyNon0RankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 
`slice_axis` `=` $slice_axis + attr-dict `:` type($input) `->` type($result) + }]; + let hasCanonicalizer = 1; + let builders = [ + OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef":$meshAxes, "int64_t":$sliceAxis)>, + OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef":$meshAxes, "int64_t":$sliceAxis)> + ]; +} + def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ Pure, SameOperandsAndResultElementType, diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h index 10a965daac71b..aeab28961a4e1 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h @@ -9,16 +9,33 @@ #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H +#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + namespace mlir { class RewritePatternSet; class SymbolTableCollection; class DialectRegistry; +class ImplicitLocOpBuilder; namespace mesh { -void processMultiIndexOpLoweringPopulatePatterns( +void populateProcessMultiIndexOpLoweringPatterns( + RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); +void registerProcessMultiIndexOpLoweringDialects(DialectRegistry ®istry); + +void populateAllSliceOpLoweringPatterns( + RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); +void registerAllSliceOpLoweringDialects(DialectRegistry ®istry); + +void populateAllOpLoweringPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); +void registerAllOpLoweringDialects(DialectRegistry ®istry); -void processMultiIndexOpLoweringRegisterDialects(DialectRegistry ®istry); +TypedValue +createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, + ImplicitLocOpBuilder &builder); } // namespace mesh } // namespace mlir diff 
--git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 2fe1495b2b593..43b6d2b384169 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -118,6 +118,9 @@ class Builder { // supports boolean, integer, and 16-/32-/64-bit float types, and vector or // ranked tensor of them. Returns null attribute otherwise. TypedAttr getZeroAttr(Type type); + // Returns a 1-valued attribute of the given `type`. + // Type constraints are the same as `getZeroAttr`. + TypedAttr getOneAttr(Type type); // Convenience methods for fixed types. FloatAttr getF16FloatAttr(float value); diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 4d4adb94a9fc8..3dc5539cde3d9 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -20,9 +20,11 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include @@ -1869,3 +1871,27 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, results.push_back(residual); return results; } + +OpFoldResult mlir::affine::linearizeIndex(ArrayRef multiIndex, + ArrayRef basis, + ImplicitLocOpBuilder &builder) { + assert(multiIndex.size() == basis.size()); + SmallVector basisAffine; + for (size_t i = 0; i < basis.size(); ++i) { + basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext())); + } + + SmallVector stridesAffine = computeStrides(basisAffine); + SmallVector strides; + strides.reserve(stridesAffine.size()); + llvm::transform(stridesAffine, std::back_inserter(strides), + [&builder, &basis](AffineExpr strideExpr) { + return 
affine::makeComposedFoldedAffineApply( + builder, builder.getLoc(), strideExpr, basis); + }); + + auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex( + OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex); + return affine::makeComposedFoldedAffineApply( + builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides); +} diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index bf274d4ae27ed..aa239f5e05396 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/SmallBitVector.h" +#include using namespace mlir; @@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) { Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { return b.create(loc, cmp, lhs, rhs); } + +namespace mlir::arith { + +Value createProduct(OpBuilder &builder, Location loc, ArrayRef values) { + return createProduct(builder, loc, values, values.front().getType()); +} + +Value createProduct(OpBuilder &builder, Location loc, ArrayRef values, + Type resultType) { + Value one = builder.create(loc, resultType, + builder.getOneAttr(resultType)); + ArithBuilder arithBuilder(builder, loc); + return std::accumulate( + values.begin(), values.end(), one, + [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); }); +} + +} // namespace mlir::arith diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index a65b8f2e5a237..3291010d27428 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, MeshOp mesh) { + build(odsBuilder, odsState, mesh, SmallVector()); +} + +void MeshShapeOp::build(OpBuilder &odsBuilder, 
OperationState &odsState, + MeshOp mesh, ArrayRef axes) { build(odsBuilder, odsState, - SmallVector(mesh.getRank(), odsBuilder.getIndexType()), - mesh.getSymName(), - MeshAxesAttr::get(odsBuilder.getContext(), SmallVector())); + SmallVector(axes.empty() ? mesh.getRank() : axes.size(), + odsBuilder.getIndexType()), + mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes)); } void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, StringRef mesh, ArrayRef axes) { + assert(!axes.empty()); build(odsBuilder, odsState, SmallVector(axes.size(), odsBuilder.getIndexType()), mesh, MeshAxesAttr::get(odsBuilder.getContext(), axes)); @@ -552,13 +558,13 @@ static LogicalResult verifyAllToAllOperandAndResultShape( return success(); } -static LogicalResult verifyScatterOperandAndResultShape( - Value operand, Value result, int64_t scatterAxis, +static LogicalResult verifyScatterOrSliceOperandAndResultShape( + Value operand, Value result, int64_t tensorAxis, ArrayRef meshAxes, ArrayRef meshShape) { ShapedType operandType = operand.getType().cast(); ShapedType resultType = result.getType().cast(); for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { - if (axis != scatterAxis) { + if (axis != tensorAxis) { if (failed(verifyDimensionCompatibility( result.getLoc(), operandType.getDimSize(axis), resultType.getDimSize(axis), axis))) { @@ -570,26 +576,42 @@ static LogicalResult verifyScatterOperandAndResultShape( auto deviceGroupSize = DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); auto operandScatterDimSize = - DimensionSize(operandType.getDimSize(scatterAxis)); + DimensionSize(operandType.getDimSize(tensorAxis)); if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() && int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) { return emitError(result.getLoc()) << "Operand dimension size " << int64_t(operandScatterDimSize) << " is not divisible by collective device group size " - << int64_t(deviceGroupSize) 
<< " for scatter axis " << scatterAxis + << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis << "."; } - DimensionSize expectedResultScatterDimSize = + DimensionSize expectedResultTensorDimSize = operandScatterDimSize / deviceGroupSize; if (failed(verifyDimensionCompatibility( - result.getLoc(), expectedResultScatterDimSize.value(), - resultType.getDimSize(scatterAxis), scatterAxis))) { + result.getLoc(), expectedResultTensorDimSize.value(), + resultType.getDimSize(tensorAxis), tensorAxis))) { return failure(); } return success(); } +static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, + ArrayRef meshAxes, + int64_t sliceAxis) { + RankedTensorType operandRankedTensorType = + cast(operandType); + DimensionSize operandSliceAxisSize = + operandRankedTensorType.getShape()[sliceAxis]; + SmallVector resultShape = + llvm::to_vector(operandRankedTensorType.getShape()); + + resultShape[sliceAxis] = + operandSliceAxisSize / + DimensionSize(collectiveProcessGroupSize(meshAxes, mesh)); + return operandRankedTensorType.clone(resultShape); +} + //===----------------------------------------------------------------------===// // mesh.all_gather op //===----------------------------------------------------------------------===// @@ -625,6 +647,40 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add>(context); } +//===----------------------------------------------------------------------===// +// mesh.all_slice op +//===----------------------------------------------------------------------===// + +LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto mesh = getMeshAndVerifyAxes(*this, symbolTable); + if (failed(mesh)) { + return failure(); + } + return verifyScatterOrSliceOperandAndResultShape( + getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(), + mesh.value().getShape()); +} + +void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + 
MLIRContext *context) { + patterns.add>(context); +} + +void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, + Value input, MeshOp mesh, ArrayRef meshAxes, + int64_t sliceAxis) { + Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis); + build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes, + sliceAxis); +} + +void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, + Type resultType, Value input, StringRef mesh, + ArrayRef meshAxes, int64_t sliceAxis) { + build(odsBuilder, odsState, resultType, mesh, meshAxes, input, + APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis)); +} + //===----------------------------------------------------------------------===// // mesh.all_to_all op //===----------------------------------------------------------------------===// @@ -752,7 +808,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return failure(); } - return verifyScatterOperandAndResultShape( + return verifyScatterOrSliceOperandAndResultShape( getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape()); } @@ -778,9 +834,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } auto scatterAxis = getScatterAxis().getSExtValue(); - return verifyScatterOperandAndResultShape(getInput(), getResult(), - scatterAxis, getMeshAxes(), - mesh.value().getShape()); + return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(), + scatterAxis, getMeshAxes(), + mesh.value().getShape()); } void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt index dccb75848c94f..28af820440076 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt @@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms LINK_LIBS PUBLIC 
MLIRAffineDialect + MLIRAffineUtils MLIRArithDialect + MLIRArithUtils MLIRControlFlowDialect MLIRFuncDialect MLIRIR diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp index c0273cdaef714..7fcac2312444f 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Mesh/Transforms/Simplifications.h" +#include "TransformsDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -16,7 +17,6 @@ #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include #include #include @@ -56,13 +56,10 @@ namespace { // symbol tables. // We can't use DialectFoldInterface since the cache may be invalidated by some // pass changing the referenced MeshOp ops. 
-struct MeshShapeFolder : OpRewritePattern { - template - MeshShapeFolder(SymbolTableCollection &symbolTableCollection, - OpRewritePatternArgs &&...opRewritePatternArgs) - : OpRewritePattern( - std::forward(opRewritePatternArgs)...), - symbolTableCollection(symbolTableCollection) {} +struct MeshShapeFolder + : OpRewritePatternWithSymbolTableCollection { + using OpRewritePatternWithSymbolTableCollection:: + OpRewritePatternWithSymbolTableCollection; LogicalResult matchAndRewrite(MeshShapeOp op, PatternRewriter &rewriter) const override { ImplicitLocOpBuilder builder(op->getLoc(), rewriter); @@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern { return success(); } - -private: - SymbolTableCollection &symbolTableCollection; }; } // namespace diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index b649157a9e46d..7cbe0de048769 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -8,9 +8,6 @@ #include "mlir/Dialect/Mesh/Transforms/Spmdization.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" @@ -128,92 +125,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); } -static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape, - int64_t splitTensorAxis, - int64_t splitCount) { - SmallVector targetShape = llvm::to_vector(sourceShape.getShape()); - targetShape[splitTensorAxis] = - shardDimension(targetShape[splitTensorAxis], splitCount); - return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); -} - // Split a replicated tensor along a mesh axis. // e.g. [[0, 1]] -> [[0, 1, 2]]. 
// Returns the spmdized target value with its sharding. -// -// The implementation is the extract the tensor slice corresponding -// to the current device. static std::tuple, MeshShardingAttr> splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { - MLIRContext *ctx = builder.getContext(); - builder.setInsertionPointAfterValue(sourceShard); - - Value zero = builder.create(builder.getIndexAttr(0)); - - Value processIndexAlongAxis = + TypedValue targetShard = builder - .create(mesh.getSymName(), - SmallVector({splitMeshAxis})) - .getResult()[0]; - + .create(sourceShard, mesh, + ArrayRef(splitMeshAxis), + splitTensorAxis) + .getResult() + .cast>(); MeshShardingAttr targetSharding = targetShardingInSplitLastAxis( - ctx, sourceSharding, splitTensorAxis, splitMeshAxis); - ShapedType targetShape = targetShapeInSplitLastAxis( - sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]); - - Value meshAxisSize = - builder - .create(mesh.getSymName(), - SmallVector({splitMeshAxis})) - .getResult()[0]; - - Value sourceAxisSize = - builder.create(sourceShard, splitTensorAxis); - Value sourceAxisSizeModMeshAxisSize = - builder.create(sourceAxisSize, meshAxisSize); - Value isTargetShapeExactlyDivisible = builder.create( - arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero); - builder.create( - isTargetShapeExactlyDivisible, - "Sharding a tensor with axis size that is not exactly divisible by the " - "mesh axis size is not supported."); - Value targetAxisSize = - builder.create(sourceAxisSize, meshAxisSize); - Value axisOffset = - builder.create(targetAxisSize, processIndexAlongAxis); - SmallVector staticOffsets(targetShape.getRank(), 0); - staticOffsets[splitTensorAxis] = ShapedType::kDynamic; - DenseI64ArrayAttr staticOffsetsAttr = - DenseI64ArrayAttr::get(ctx, staticOffsets); - SmallVector dynamicOffsets(1, axisOffset); - - 
DenseI64ArrayAttr staticSizesAttr = - DenseI64ArrayAttr::get(ctx, targetShape.getShape()); - SmallVector dynamicSizes; - for (int64_t i = 0; i < targetShape.getRank(); ++i) { - if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) { - if (i == splitTensorAxis) { - dynamicSizes.push_back(targetAxisSize); - } else { - Value dimSize = builder.create(sourceShard, i); - dynamicSizes.push_back(dimSize); - } - } - } - - DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get( - ctx, SmallVector(targetShape.getRank(), 1)); - TypedValue targetShard = - builder - .create( - targetShape, sourceShard, dynamicOffsets, dynamicSizes, - SmallVector({}), staticOffsetsAttr, staticSizesAttr, - staticStridesAttr) - .getResult(); - return {targetShard.cast>(), targetSharding}; + builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); + return {targetShard, targetSharding}; } // Detect if the resharding is of type e.g. @@ -587,8 +516,7 @@ TypedValue reshard(OpBuilder &builder, ShardOp source, } void reshardingRegisterDependentDialects(DialectRegistry ®istry) { - registry.insert(); + registry.insert(); } #define GEN_PASS_DEF_SPMDIZATION diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index 03b1d9b349802..d59b9119dea54 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -7,12 +7,21 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Mesh/Transforms/Transforms.h" +#include "TransformsDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include 
"mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "llvm/ADT/STLExtras.h" @@ -26,18 +35,14 @@ namespace { /// Lower `mesh.process_multi_index` into expression using /// `mesh.process_linear_index` and `mesh.mesh_shape`. -struct ProcessMultiIndexOpLowering : OpRewritePattern { - template - ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection, - OpRewritePatternArgs &&...opRewritePatternArgs) - : OpRewritePattern( - std::forward(opRewritePatternArgs)...), - symbolTableCollection(symbolTableCollection) {} +struct ProcessMultiIndexOpLowering + : OpRewritePatternWithSymbolTableCollection { + using OpRewritePatternWithSymbolTableCollection:: + OpRewritePatternWithSymbolTableCollection; LogicalResult matchAndRewrite(ProcessMultiIndexOp op, PatternRewriter &rewriter) const override { - MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom( - op.getOperation(), op.getMeshAttr()); + MeshOp mesh = getMesh(op, symbolTableCollection); if (!mesh) { return failure(); } @@ -64,21 +69,143 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern { rewriter.replaceAllUsesWith(op.getResults(), multiIndex); return success(); } +}; + +struct AllSliceOpLowering + : OpRewritePatternWithSymbolTableCollection { + using OpRewritePatternWithSymbolTableCollection:: + OpRewritePatternWithSymbolTableCollection; + + LogicalResult matchAndRewrite(AllSliceOp op, + PatternRewriter &rewriter) const override { + // 1. Compute the process linear index inside the process group from its + // multi-index. + // + // 2. Extract a slice from the input tensor. + // All axes except the slicing axis are not interesting and take the full + // axis. 
+ // The slice axis is split into equisized parts with count + // the number of processes in the collective process group induced by + // the mesh axes. + // The part for each process is determined by the corresponding + // linear-index in the process group. + // + // There are no collectives that require communication. + // Each process operates on its local tensor. + + MeshOp mesh = getMesh(op, symbolTableCollection); + if (!mesh) { + return failure(); + } -private: - SymbolTableCollection &symbolTableCollection; + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + builder.setInsertionPointAfter(op.getOperation()); + + Value zero = builder.create(builder.getIndexAttr(0)); + + Operation::result_range processInGroupMultiIndex = + builder.create(mesh.getSymName(), op.getMeshAxes()) + .getResults(); + + Operation::result_range processGroupShape = + builder.create(mesh.getSymName(), op.getMeshAxes()) + .getResult(); + Value processGroupSize = + createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder); + + int64_t sliceAxis = op.getSliceAxis().getSExtValue(); + Value operandSliceAxisSize = + builder.create(op.getOperand(), sliceAxis); + Value operandSliceAxisSizeModProcessGroupSize = + builder.create(operandSliceAxisSize, processGroupSize); + Value isTargetShapeExactlyDivisible = builder.create( + arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize, + zero); + builder.create(isTargetShapeExactlyDivisible, + "Slicing a tensor with axis size that is " + "not exactly divisible by the " + "mesh process group size is not supported."); + Value resultSliceAxisSize = + builder.create(operandSliceAxisSize, processGroupSize); + OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( + llvm::to_vector_of(processInGroupMultiIndex), + llvm::to_vector_of(processGroupShape), builder); + + // insert tensor.extract_slice + RankedTensorType operandType = + op.getOperand().getType().cast(); + SmallVector sizes; + for (int64_t i = 0; i < 
operandType.getRank(); ++i) { + if (i == sliceAxis) { + sizes.emplace_back(resultSliceAxisSize); + } else { + Value dimSize = builder.create(op.getOperand(), i); + sizes.emplace_back(dimSize); + } + } + SmallVector offsets( + operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0)); + offsets[sliceAxis] = + ArithBuilder(builder, builder.getLoc()) + .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(), + processInGroupLinearIndex), + resultSliceAxisSize); + SmallVector strides( + operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1)); + Value slice = builder.create( + op.getOperand(), offsets, sizes, strides); + Value newResult = + builder.create(op.getResult().getType(), slice); + rewriter.replaceAllUsesWith(op.getResult(), newResult); + + return success(); + } }; } // namespace -void processMultiIndexOpLoweringPopulatePatterns( +void populateProcessMultiIndexOpLoweringPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { patterns.add(symbolTableCollection, patterns.getContext()); } -void processMultiIndexOpLoweringRegisterDialects(DialectRegistry ®istry) { +void registerProcessMultiIndexOpLoweringDialects(DialectRegistry ®istry) { registry.insert(); } +void populateAllSliceOpLoweringPatterns( + RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { + patterns.add(symbolTableCollection, + patterns.getContext()); +} + +void registerAllSliceOpLoweringDialects(DialectRegistry ®istry) { + registry.insert(); +} + +void populateAllOpLoweringPatterns( + RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { + populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection); + populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); +} + +void registerAllOpLoweringDialects(DialectRegistry ®istry) { + registerProcessMultiIndexOpLoweringDialects(registry); + registerAllSliceOpLoweringDialects(registry); +} + +TypedValue 
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, + ImplicitLocOpBuilder &builder) { + Operation::result_range meshShape = + builder.create(mesh, axes).getResults(); + return arith::createProduct(builder, builder.getLoc(), + llvm::to_vector_of(meshShape), + builder.getIndexType()) + .cast>(); +} + } // namespace mlir::mesh diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h new file mode 100644 index 0000000000000..3e3f584caca24 --- /dev/null +++ b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h @@ -0,0 +1,35 @@ +//===- TransformsDetail.h - -------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H +#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H + +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +namespace mesh { + +template +struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern { + template + OpRewritePatternWithSymbolTableCollection( + SymbolTableCollection &symbolTableCollection, + OpRewritePatternArgs &&...opRewritePatternArgs) + : OpRewritePattern( + std::forward(opRewritePatternArgs)...), + symbolTableCollection(symbolTableCollection) {} + +protected: + SymbolTableCollection &symbolTableCollection; +}; + +} // namespace mesh +} // namespace mlir + +#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 2e42c4e870716..18ca3c332e020 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -346,6 +346,24 @@ TypedAttr Builder::getZeroAttr(Type type) { return {}; } +TypedAttr 
Builder::getOneAttr(Type type) { + if (llvm::isa<FloatType>(type)) + return getFloatAttr(type, 1.0); + if (llvm::isa<IndexType>(type)) + return getIndexAttr(1); + if (llvm::dyn_cast<IntegerType>(type)) + return getIntegerAttr(type, + APInt(llvm::cast<IntegerType>(type).getWidth(), 1)); + if (llvm::isa<RankedTensorType, VectorType>(type)) { + auto vtType = llvm::cast<ShapedType>(type); + auto element = getOneAttr(vtType.getElementType()); + if (!element) + return {}; + return DenseElementsAttr::get(vtType, element); + } + return {}; +} + //===----------------------------------------------------------------------===// // Affine Expressions, Affine Maps, and Integer Sets. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir new file mode 100644 index 0000000000000..4f54607a1c7ff --- /dev/null +++ b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir @@ -0,0 +1,74 @@ +// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s + +mesh.mesh @mesh_1d(shape = ?)
+ +// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh +func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh( + // CHECK: %[[ARG:.*]]: tensor + %arg0: tensor +// CHECK-SAME: -> tensor { +) -> tensor { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %c0 : tensor + // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index + // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index + // CHECK: cf.assert %[[AXIS_SIZE_CHECK]] + // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index + // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index + // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor to tensor + %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor -> tensor + // CHECK: return %[[RESULT]] : tensor + return %0 : tensor +} + +// ----- + +mesh.mesh @mesh_1d(shape = 2) + +// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh +func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh( + // CHECK: %[[ARG:.*]]: tensor<2xf16> + %arg0: tensor<2xf16> +// CHECK-SAME: -> tensor<1xf16> { +) -> tensor<1xf16> { + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor + // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor to tensor<1xf16> + %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 
0 : tensor<2xf16> -> tensor<1xf16> + // CHECK: return %[[RESULT]] : tensor<1xf16> + return %0 : tensor<1xf16> +} + +// ----- + +// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)> + +mesh.mesh @mesh_4d(shape = ?x?x?x?) + +// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh +func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh( + // CHECK: %[[ARG:.*]]: tensor + %arg0 : tensor +// CHECK-SAME: -> tensor { +) -> tensor { + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index + // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index + // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index + // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor + // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index + // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index + // CHECK: cf.assert %[[AXIS_SIZE_CHECK]] + // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index + // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1] + // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor + // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index + // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor to tensor + %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor -> tensor + // CHECK: return 
%[[RESULT]] : tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir index 23c5b253b4c07..633324ae680eb 100644 --- a/mlir/test/Dialect/Mesh/canonicalization.mlir +++ b/mlir/test/Dialect/Mesh/canonicalization.mlir @@ -63,6 +63,19 @@ func.func @all_gather_empty_mesh_axes( return %0 : tensor<4xf32> } +// CHECK-LABEL: func @all_slice_empty_mesh_axes +func.func @all_slice_empty_mesh_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: mesh.scatter + %0 = mesh.all_slice %arg0 on @mesh0 + mesh_axes = [] + slice_axis = 0 + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + // CHECK-LABEL: func @broadcast_empty_mesh_axes func.func @broadcast_empty_mesh_axes( // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir index 3fa3ebd67b15e..6d7df86d78406 100644 --- a/mlir/test/Dialect/Mesh/invalid.mlir +++ b/mlir/test/Dialect/Mesh/invalid.mlir @@ -316,6 +316,58 @@ func.func @all_gather_invalid_negative_gather_axis( // ----- +mesh.mesh @mesh0(shape = 3) + +func.func @all_slice_duplicate_mesh_axis( + %arg0 : tensor) -> tensor { + // expected-error@+1 {{Mesh axes contains duplicate elements.}} + %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0] + slice_axis = 0 + : tensor -> tensor + return %0 : tensor +} + +// ----- + +mesh.mesh @mesh0(shape = 3) + +func.func @all_slice_invalid_dynamic_dimension( + %arg0 : tensor) -> tensor<2xf32> { + // expected-error@+1 {{Dimension size mismatch for result axis 0. 
Expected dynamic, but got 2.}} + %0 = mesh.all_slice %arg0 on @mesh0 + slice_axis = 0 + : tensor -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +mesh.mesh @mesh0(shape = 3) + +func.func @all_slice_invalid_static_dimension_size( + %arg0 : tensor<3xf32>) -> tensor<2xf32> { + // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} + %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0] + slice_axis = 0 + : tensor<3xf32> -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +mesh.mesh @mesh0(shape = 3) + +func.func @all_slice_invalid_operand_static_dimension_size( + %arg0 : tensor<4xf32>) -> tensor { + // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} + %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0] + slice_axis = 0 + : tensor<4xf32> -> tensor + return %0 : tensor +} + +// ----- + func.func @all_to_all_invalid_mesh_symbol( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} @@ -660,7 +712,7 @@ mesh.mesh @mesh0(shape = 3) func.func @reduce_scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor { - // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}} + // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0 : tensor<4xf32> -> tensor return %0 : tensor @@ -711,7 +763,7 @@ mesh.mesh @mesh0(shape = 3) func.func @scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor { - // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}} + // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} %0 = 
mesh.scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0 root = [1] : (tensor<4xf32>) -> tensor diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index 40a8469b26464..6e5df86b13106 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -208,6 +208,30 @@ func.func @all_gather_dynamic_dims_in_mesh( return %0 : tensor<5x?xf32> } +// CHECK-LABEL: func @all_slice_static_dimensions +func.func @all_slice_static_dimensions( + // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> + %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> { + // CHECK-NEXT: mesh.all_slice %[[ARG]] + // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1 + // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32> + %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1 + : tensor<3x4xf32> -> tensor<3x1xf32> + return %0 : tensor<3x1xf32> +} + +// CHECK-LABEL: func @all_slice_dynamic_dimensions +func.func @all_slice_dynamic_dimensions( + // CHECK-SAME: %[[ARG:.*]]: tensor + %arg0 : tensor) -> tensor { + // CHECK-NEXT: mesh.all_slice %[[ARG]] + // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0 + // CHECK-SAME: : tensor -> tensor + %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0 + : tensor -> tensor + return %0 : tensor +} + // CHECK-LABEL: func @all_to_all func.func @all_to_all( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir index f1d686135c28e..ba05306598bcc 100644 --- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir +++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir @@ -19,17 +19,9 @@ func.func @split_replicated_tensor_axis( // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32> %arg0: tensor<3x14xf32> ) -> tensor<3x14xf32> { - // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index - // CHECK: %[[PROCESS_INDEX:.*]] = 
mesh.process_multi_index on @mesh_1d axes = [0] : index - // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index - // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index - // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index - // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]] - // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index - // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index - // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32> - // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32> + // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1 + // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32> + // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32> %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32> %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32> // CHECK: return %[[RESULT]] : tensor<3x14xf32> @@ -41,22 +33,11 @@ func.func @split_replicated_tensor_axis_dynamic( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0: tensor ) -> tensor { - // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index - // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index - // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index - // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor - // CHECK: 
%[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index - // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index - // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]] - // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index - // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index - // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor - // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0] - // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor to tensor + // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor -> tensor %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor - // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor + // CHECK: return %[[RESULT]] : tensor return %1 : tensor } diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index 9993c1518e9ea..2fb8029dfe64a 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -63,9 +63,8 @@ func.func @unary_elementwise_with_resharding( %arg0: tensor<2xi8> // CHECK-SAME: -> tensor<2xi8> { ) -> tensor<2xi8> { - // We don't care about the whole resharding IR, just that it happens. 
- // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1] - // CHECK-SAME: tensor<2xi8> to tensor<1xi8> + // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8> %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8> @@ -109,9 +108,8 @@ func.func @multiple_chained_ops( %arg0: tensor<2xi8> // CHECK-SAME: -> tensor<1xi8> { ) -> tensor<2xi8> { - // We don't care about the whole resharding IR, just that it happens. - // CHECK: %[[RESHARD1:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1] - // CHECK-SAME: tensor<2xi8> to tensor<1xi8> + // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8> %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8> @@ -122,8 +120,8 @@ func.func @multiple_chained_ops( %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8> // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8> %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RESHARD3:.*]] = tensor.extract_slice %[[ABS2]][%{{.*}}] [1] [1] - // CHECK-SAME: tensor<2xi8> to tensor<1xi8> + // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8> %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> // CHECK: return %[[RESHARD3]] : tensor<1xi8> diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt index 00931e6c94fc5..07e9bb6f9f238 100644 --- 
a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt @@ -1,6 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRMeshTest - TestProcessMultiIndexOpLowering.cpp + TestOpLowering.cpp TestReshardingSpmdization.cpp TestSimplifications.cpp diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp new file mode 100644 index 0000000000000..321b6a42bf966 --- /dev/null +++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp @@ -0,0 +1,79 @@ +//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Mesh/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { + +struct TestAllSliceOpLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass) + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + SymbolTableCollection symbolTableCollection; + mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); + LogicalResult status = + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)status; + assert(succeeded(status) && "applyPatternsAndFoldGreedily failed."); + } + void getDependentDialects(DialectRegistry ®istry) const override { + mesh::registerAllSliceOpLoweringDialects(registry); + } + StringRef getArgument() const final { + return "test-mesh-all-slice-op-lowering"; + } + 
StringRef getDescription() const final { + return "Test lowering of all-slice."; + } +}; + +struct TestMultiIndexOpLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass) + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + SymbolTableCollection symbolTableCollection; + mesh::populateProcessMultiIndexOpLoweringPatterns(patterns, + symbolTableCollection); + LogicalResult status = + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)status; + assert(succeeded(status) && "applyPatternsAndFoldGreedily failed."); + } + void getDependentDialects(DialectRegistry ®istry) const override { + mesh::registerProcessMultiIndexOpLoweringDialects(registry); + } + StringRef getArgument() const final { + return "test-mesh-process-multi-index-op-lowering"; + } + StringRef getDescription() const final { + return "Test lowering of mesh.process_multi_index op."; + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestOpLoweringPasses() { + PassRegistration(); + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp deleted file mode 100644 index 0bcc403a2734e..0000000000000 --- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp +++ /dev/null @@ -1,54 +0,0 @@ -//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -struct TestMultiIndexOpLoweringPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass) - - void runOnOperation() override; - void getDependentDialects(DialectRegistry ®istry) const override { - mesh::processMultiIndexOpLoweringRegisterDialects(registry); - } - StringRef getArgument() const final { - return "test-mesh-process-multi-index-op-lowering"; - } - StringRef getDescription() const final { - return "Test lowering of mesh.process_multi_index op."; - } -}; -} // namespace - -void TestMultiIndexOpLoweringPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - SymbolTableCollection symbolTableCollection; - mesh::processMultiIndexOpLoweringPopulatePatterns(patterns, - symbolTableCollection); - LogicalResult status = - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - (void)status; - assert(succeeded(status) && "applyPatternsAndFoldGreedily failed."); -} - -namespace mlir { -namespace test { -void registerTestMultiIndexOpLoweringPass() { - PassRegistration(); -} -} // namespace test -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index cec1e5225d5a6..f11c6b4355fdd 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -119,7 +119,7 @@ void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); void registerTestMeshSimplificationsPass(); void 
registerTestMeshReshardingSpmdizationPass(); -void registerTestMultiIndexOpLoweringPass(); +void registerTestOpLoweringPasses(); void registerTestNextAccessPass(); void registerTestOneToNTypeConversionPass(); void registerTestOpaqueLoc(); @@ -241,7 +241,7 @@ void registerTestPasses() { mlir::test::registerTestMathToVCIXPass(); mlir::test::registerTestMemRefDependenceCheck(); mlir::test::registerTestMemRefStrideCalculation(); - mlir::test::registerTestMultiIndexOpLoweringPass(); + mlir::test::registerTestOpLoweringPasses(); mlir::test::registerTestMeshSimplificationsPass(); mlir::test::registerTestMeshReshardingSpmdizationPass(); mlir::test::registerTestNextAccessPass(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 54c9f19c6ab1e..0e32864429e40 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3384,7 +3384,9 @@ cc_library( includes = ["include"], deps = [ ":AffineDialect", + ":AffineUtils", ":ArithDialect", + ":ArithUtils", ":ControlFlowDialect", ":DialectUtils", ":FuncDialect",