diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index bd9ab3545ccb0..a6de3e9597d72 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -210,6 +210,25 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", flexibility allows to progressively drop unit dimensions while lowering between different flavors of ops on that operate on tensors. + Verification vs Inference in the rank-reduced case: + =================================================== + Note that there may be multiple ways to infer a resulting rank-reduced type. + e.g. 1x6x1 could potentially rank-reduce to either 1x6 or 6x1 2-D shapes. + + To disambiguate, the inference helpers `inferCanonicalRankReducedResultType` + only drop the first unit dimensions, in order: + e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6. + + Verification, however, has access to the result type and does not need to + infer it. + The verifier calls `isRankReducedType(getSource(), getResult())` to + determine whether the result type is rank-reduced from the source type. + This computes a so-called rank-reduction mask, consisting of dropped unit + dims, to map the rank-reduced type to the source type by dropping ones: + e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2} + 6x1 is a rank-reduced version of 1x6x1 by mask {0} + 1x2x1x4 is a rank-reduced version of 1x1x2x1x1x4x1 by mask {1, 4, 6} + (remaining common 1 dimensions are matched eagerly) + Example: ``` @@ -274,26 +293,43 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", return getResult().getType().cast(); } - /// An extract_slice result type can be fully inferred from the source type - /// and the static representation of offsets, sizes and strides. Special - /// sentinels encode the dynamic case. 
+ /// Compute the rank-reduction mask that can be applied to map the source + /// tensor type to the result tensor type by dropping unit dims. + llvm::Optional> + computeRankReductionMask() { + return ::mlir::computeRankReductionMask(getSourceType().getShape(), + getType().getShape()); + }; + + /// An extract_slice result type can be inferred, when it is not + /// rank-reduced, from the source type and the static representation of + /// offsets, sizes and strides. Special sentinels encode the dynamic case. static RankedTensorType inferResultType( - RankedTensorType sourceRankedTensorType, + ShapedType sourceShapedTensorType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); static RankedTensorType inferResultType( - RankedTensorType sourceRankedTensorType, + ShapedType sourceShapedTensorType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); - static RankedTensorType inferRankReducedResultType( + + /// If the rank is reduced (i.e. the desired result rank is smaller than the + /// number of sizes), drop as many size-1 dimensions as needed to produce an + /// inferred type with the desired rank. + /// + /// Note that there may be multiple ways to compute this rank-reduced type: + /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors. + /// + /// To disambiguate, this function always drops the first size-1 dimensions. 
+ static RankedTensorType inferCanonicalRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); - static RankedTensorType inferRankReducedResultType( + static RankedTensorType inferCanonicalRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, ArrayRef staticOffsets, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp index 10760685301fd..6c6bcabb499d9 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp @@ -228,7 +228,7 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep( return b.create(loc, target, dim).getResult(); return b.getIndexAttr(shapedType.getDimSize(dim)); }); - auto t = tensor::ExtractSliceOp::inferRankReducedResultType( + auto t = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( insertOp.getSourceType().getRank(), insertOp.getDest().getType().cast(), mixedOffsets, mixedSizes, mixedStrides); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 970b628a15c4b..e1e7ed76d23cc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -499,10 +499,11 @@ struct UseRankReducedExtractSliceOp if (!reassociation || reassociation->size() == static_cast(resultType.getRank())) return failure(); - auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType( - reassociation->size(), sliceOp.getSourceType(), - offsets, sizes, strides) - .cast(); + auto rankReducedType = + tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( + reassociation->size(), sliceOp.getSourceType(), offsets, sizes, + strides) + .cast(); Location loc = sliceOp.getLoc(); Value 
newSlice = rewriter.create( diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 897af9fcee6f3..305e8f7e42394 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -957,25 +957,24 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { // ExtractSliceOp //===----------------------------------------------------------------------===// -/// An extract_slice op result type can be fully inferred from the source type -/// and the static representation of offsets, sizes and strides. Special -/// sentinels encode the dynamic case. +/// An extract_slice result type can be inferred, when it is not +/// rank-reduced, from the source type and the static representation of +/// offsets, sizes and strides. Special sentinels encode the dynamic case. RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceRankedTensorType, ArrayRef staticOffsets, + ShapedType sourceShapedTensorType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides) { // An extract_slice op may specify only a leading subset of offset/sizes/ // strides in which case we complete with offset=0, sizes from memref type and // strides=1. 
- unsigned rank = sourceRankedTensorType.getRank(); - (void)rank; - assert(staticSizes.size() == rank && + assert(static_cast(staticSizes.size()) == + sourceShapedTensorType.getRank() && "unexpected staticSizes not equal to rank of source"); return RankedTensorType::get(staticSizes, - sourceRankedTensorType.getElementType()); + sourceShapedTensorType.getElementType()); } RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceRankedTensorType, ArrayRef offsets, + ShapedType sourceShapedTensorType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; @@ -985,26 +984,33 @@ RankedTensorType ExtractSliceOp::inferResultType( ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); - return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, + return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets, staticSizes, staticStrides); } -/// An extract_slice op result type can be fully inferred from the source type -/// and the static representation of offsets, sizes and strides. Special -/// sentinels encode the dynamic case. -RankedTensorType ExtractSliceOp::inferRankReducedResultType( - unsigned resultRank, RankedTensorType sourceRankedTensorType, +/// If the rank is reduced (i.e. the desiredResultRank is smaller than the +/// number of sizes), drop as many size-1 dimensions as needed to produce an +/// inferred type with the desired rank. +/// +/// Note that there may be multiple ways to compute this rank-reduced type: +/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors. +/// +/// To disambiguate, this function always drops the first size-1 dimensions. 
+RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( + unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { + // Type inferred in the absence of rank-reducing behavior. auto inferredType = inferResultType(sourceRankedTensorType, offsets, sizes, strides) .cast(); - int rankDiff = inferredType.getRank() - resultRank; + int rankDiff = inferredType.getRank() - desiredResultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); llvm::SmallBitVector dimsToProject = getPositionsOfShapeOne(rankDiff, shape); SmallVector projectedShape; + // Best effort rank-reducing: drop 1s in order. for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) if (!dimsToProject.test(pos)) projectedShape.push_back(shape[pos]); @@ -1014,8 +1020,8 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType( return inferredType; } -RankedTensorType ExtractSliceOp::inferRankReducedResultType( - unsigned resultRank, RankedTensorType sourceRankedTensorType, +RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( + unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { SmallVector staticOffsets, staticSizes, staticStrides; @@ -1026,8 +1032,8 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType( ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); - return ExtractSliceOp::inferRankReducedResultType( - resultRank, sourceRankedTensorType, staticOffsets, staticSizes, + return ExtractSliceOp::inferCanonicalRankReducedResultType( + desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); } @@ -1123,26 +1129,6 @@ LogicalResult ExtractSliceOp::verify() { return produceSliceErrorMsg(result, *this, expectedType); } -/// Infer the canonical type of the result of an extract_slice op. 
Returns a -/// type with rank `resultRank` that is either the rank of the rank-reduced -/// type, or the non-rank-reduced type. -static RankedTensorType -getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType, - ArrayRef mixedOffsets, - ArrayRef mixedSizes, - ArrayRef mixedStrides) { - auto resultType = - ExtractSliceOp::inferRankReducedResultType( - resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides) - .cast(); - if (resultType.getRank() != resultRank) { - resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets, - mixedSizes, mixedStrides) - .cast(); - } - return resultType; -} - llvm::SmallBitVector ExtractSliceOp::getDroppedDims() { ArrayRef resultShape = getType().getShape(); SmallVector mixedSizes = getMixedSizes(); @@ -1205,7 +1191,7 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { - // Any constant operand, just return to let SubViewOpConstantFolder kick in. + // Any constant operand, just return to let the constant folder kick in. if (llvm::any_of(sliceOp.getOperands(), [](Value operand) { return matchPattern(operand, matchConstantIndex()); })) @@ -1219,10 +1205,11 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern { return failure(); /// Deduce the type of the result to use for the canonicalized operation. 
- RankedTensorType resultType = getCanonicalSliceResultType( - sliceOp.getType().getRank(), sliceOp.getSourceType(), - sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), - sliceOp.getMixedStrides()); + RankedTensorType resultType = + ExtractSliceOp::inferCanonicalRankReducedResultType( + sliceOp.getType().getRank(), sliceOp.getSourceType(), + sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), + sliceOp.getMixedStrides()); Value newSlice = rewriter.create( sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), @@ -1366,9 +1353,9 @@ struct SliceReturnTypeCanonicalizer { ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { - return getCanonicalSliceResultType(op.getType().getRank(), - op.getSourceType(), mixedOffsets, - mixedSizes, mixedStrides); + return ExtractSliceOp::inferCanonicalRankReducedResultType( + op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes, + mixedStrides); } }; @@ -1506,9 +1493,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType, ArrayAttr staticStrides, ShapedType *expectedType = nullptr) { // insert_slice is the inverse of extract_slice, use the same type inference. - auto expected = ExtractSliceOp::inferRankReducedResultType( - srcType.getRank(), dstType.cast(), - extractFromI64ArrayAttr(staticOffsets), + auto expected = ExtractSliceOp::inferResultType( + dstType, extractFromI64ArrayAttr(staticOffsets), extractFromI64ArrayAttr(staticSizes), extractFromI64ArrayAttr(staticStrides)) .cast(); @@ -1600,7 +1586,7 @@ class InsertSliceOpConstantArgumentFolder final canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); // Create the new op in canonical form. 
- auto sourceType = ExtractSliceOp::inferRankReducedResultType( + auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(), mixedOffsets, mixedSizes, mixedStrides); Value toInsert = insertSliceOp.getSource();