diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 2453cf5b5b5a4..ca2464f6272d3 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -486,17 +486,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", /// An extract_slice result type can be inferred, when it is not /// rank-reduced, from the source type and the static representation of - /// offsets, sizes and strides. Special sentinels encode the dynamic case. + /// sizes. Special sentinels encode the dynamic case. static RankedTensorType inferResultType( RankedTensorType sourceTensorType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); + ArrayRef staticSizes); static RankedTensorType inferResultType( RankedTensorType sourceTensorType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); + ArrayRef staticSizes); /// If the rank is reduced (i.e. the desiredResultRank is smaller than the /// number of sizes), drop as many size 1 as needed to produce an inferred type @@ -509,15 +505,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", static RankedTensorType inferCanonicalRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); + ArrayRef staticSizes); static RankedTensorType inferCanonicalRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); + ArrayRef staticSizes); /// Return the expected rank of each of the`static_offsets`, `static_sizes` /// and `static_strides` attributes. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 22690daa4f9e1..9e6c1e6036cba 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp SmallVector sizes = sliceOp.getMixedSizes(); auto rankReducedType = cast( tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - reassociation->size(), sliceOp.getSourceType(), offsets, sizes, - strides)); + reassociation->size(), sliceOp.getSourceType(), sizes)); Location loc = sliceOp.getLoc(); Value newSlice = tensor::ExtractSliceOp::create( diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 110bfdce72ea4..125db6249b23d 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2293,9 +2293,9 @@ void ExtractSliceOp::getAsmResultNames( /// An extract_slice result type can be inferred, when it is not /// rank-reduced, from the source type and the static representation of /// offsets, sizes and strides. Special sentinels encode the dynamic case. -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef staticOffsets, - ArrayRef staticSizes, ArrayRef staticStrides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef staticSizes) { // An extract_slice op may specify only a leading subset of offset/sizes/ // strides in which case we complete with offset=0, sizes from memref type // and strides=1. @@ -2307,11 +2307,12 @@ RankedTensorType ExtractSliceOp::inferResultType( } // TODO: This uses neither offsets nor strides! -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef offsets, - ArrayRef sizes, ArrayRef strides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef sizes) { SmallVector staticSizes; std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes); + assert(static_cast(staticSizes.size()) == sourceTensorType.getRank() && "unexpected staticSizes not equal to rank of source"); @@ -2329,11 +2330,10 @@ RankedTensorType ExtractSliceOp::inferResultType( /// To disambiguate, this function always drops the first 1 sizes occurrences. RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides) { + ArrayRef sizes) { // Type inferred in the absence of rank-reducing behavior. auto inferredType = llvm::cast( - inferResultType(sourceRankedTensorType, offsets, sizes, strides)); + inferResultType(sourceRankedTensorType, sizes)); int rankDiff = inferredType.getRank() - desiredResultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); @@ -2352,16 +2352,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides) { - SmallVector staticOffsets, staticSizes, staticStrides; - SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + ArrayRef sizes) { + SmallVector staticSizes; + SmallVector dynamicSizes; dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); return ExtractSliceOp::inferCanonicalRankReducedResultType( - desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, - staticStrides); + desiredResultRank, sourceRankedTensorType, staticSizes); } /// Build an ExtractSliceOp with mixed static and dynamic entries and custom @@ -2380,8 +2376,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, auto sourceRankedTensorType = llvm::cast(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = llvm::cast(ExtractSliceOp::inferResultType( - sourceRankedTensorType, staticOffsets, staticSizes, staticStrides)); + resultType = llvm::cast( + ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes)); } result.addAttributes(attrs); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, @@ -2456,8 +2452,8 @@ LogicalResult ExtractSliceOp::verify() { RankedTensorType sourceType = getSourceType(); // Verify result type against inferred type. - RankedTensorType expectedType = ExtractSliceOp::inferResultType( - sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides()); + RankedTensorType expectedType = + ExtractSliceOp::inferResultType(sourceType, getMixedSizes()); SliceVerificationResult result = isRankReducedType(expectedType, getType()); if (result != SliceVerificationResult::Success) return produceSliceErrorMsg(result, *this, expectedType); @@ -2697,8 +2693,7 @@ struct SliceReturnTypeCanonicalizer { ArrayRef mixedSizes, ArrayRef mixedStrides) { return ExtractSliceOp::inferCanonicalRankReducedResultType( - op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes, - mixedStrides); + op.getType().getRank(), op.getSourceType(), mixedSizes); } }; @@ -2839,8 +2834,8 @@ static SliceVerificationResult verifyInsertSliceOp( ArrayRef staticStrides, RankedTensorType *expectedType = nullptr) { // insert_slice is the inverse of extract_slice, use the same type // inference. - RankedTensorType expected = ExtractSliceOp::inferResultType( - dstType, staticOffsets, staticSizes, staticStrides); + RankedTensorType expected = + ExtractSliceOp::inferResultType(dstType, staticSizes); if (expectedType) *expectedType = expected; return isRankReducedType(expected, srcType); @@ -2968,7 +2963,7 @@ class InsertSliceOpConstantArgumentFolder final // Create the new op in canonical form. auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), - mixedOffsets, mixedSizes, mixedStrides); + mixedSizes); Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) { OpBuilder::InsertionGuard g(rewriter); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 7ec61c7df81cf..421f9ab7ceff7 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract // supported. Moreover, only simple cases where the resulting ExtractSliceOp // has no rank-reduction anymore are supported at the moment. RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType( - srcType, extractSliceOp.getStaticOffsets(), - extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides()); + srcType, extractSliceOp.getStaticSizes()); if (nonReducingExtractType != resultType) return failure();