-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][tensor] Simplify ExtractSliceOp::inferResultType (nfc) #169313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The `offsets` and `strides` arguments are neither used nor required - removed them and simplify this hook.
|
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-linalg Author: Andrzej Warzyński (banach-space) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/169313.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..ca2464f6272d3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -486,17 +486,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
- /// offsets, sizes and strides. Special sentinels encode the dynamic case.
+ /// sizes. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size 1 as needed to produce an inferred type
@@ -509,15 +505,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690daa4f9e1..9e6c1e6036cba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides));
+ reassociation->size(), sliceOp.getSourceType(), sizes));
Location loc = sliceOp.getLoc();
Value newSlice = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 110bfdce72ea4..125db6249b23d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2293,9 +2293,9 @@ void ExtractSliceOp::getAsmResultNames(
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<int64_t> staticSizes) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
@@ -2307,11 +2307,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
}
// TODO: This uses neither offsets nor strides!
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
@@ -2329,11 +2330,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> sizes) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+ inferResultType(sourceRankedTensorType, sizes));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
@@ -2352,16 +2352,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
- SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<int64_t> staticSizes;
+ SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
- staticStrides);
+ desiredResultRank, sourceRankedTensorType, staticSizes);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2380,8 +2376,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
- sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
+ resultType = llvm::cast<RankedTensorType>(
+ ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2456,8 +2452,8 @@ LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
// Verify result type against inferred type.
- RankedTensorType expectedType = ExtractSliceOp::inferResultType(
- sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
+ RankedTensorType expectedType =
+ ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
@@ -2697,8 +2693,7 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
- mixedStrides);
+ op.getType().getRank(), op.getSourceType(), mixedSizes);
}
};
@@ -2839,8 +2834,8 @@ static SliceVerificationResult verifyInsertSliceOp(
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
- RankedTensorType expected = ExtractSliceOp::inferResultType(
- dstType, staticOffsets, staticSizes, staticStrides);
+ RankedTensorType expected =
+ ExtractSliceOp::inferResultType(dstType, staticSizes);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2968,7 +2963,7 @@ class InsertSliceOpConstantArgumentFolder final
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
- mixedOffsets, mixedSizes, mixedStrides);
+ mixedSizes);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 7ec61c7df81cf..421f9ab7ceff7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
- srcType, extractSliceOp.getStaticOffsets(),
- extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+ srcType, extractSliceOp.getStaticSizes());
if (nonReducingExtractType != resultType)
return failure();
|
|
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/169313.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..ca2464f6272d3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -486,17 +486,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
- /// offsets, sizes and strides. Special sentinels encode the dynamic case.
+ /// sizes. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size 1 as needed to produce an inferred type
@@ -509,15 +505,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690daa4f9e1..9e6c1e6036cba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides));
+ reassociation->size(), sliceOp.getSourceType(), sizes));
Location loc = sliceOp.getLoc();
Value newSlice = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 110bfdce72ea4..125db6249b23d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2293,9 +2293,9 @@ void ExtractSliceOp::getAsmResultNames(
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<int64_t> staticSizes) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
@@ -2307,11 +2307,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
}
// TODO: This uses neither offsets nor strides!
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
@@ -2329,11 +2330,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> sizes) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+ inferResultType(sourceRankedTensorType, sizes));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
@@ -2352,16 +2352,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
- SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<int64_t> staticSizes;
+ SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
- staticStrides);
+ desiredResultRank, sourceRankedTensorType, staticSizes);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2380,8 +2376,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
- sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
+ resultType = llvm::cast<RankedTensorType>(
+ ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2456,8 +2452,8 @@ LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
// Verify result type against inferred type.
- RankedTensorType expectedType = ExtractSliceOp::inferResultType(
- sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
+ RankedTensorType expectedType =
+ ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
@@ -2697,8 +2693,7 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
- mixedStrides);
+ op.getType().getRank(), op.getSourceType(), mixedSizes);
}
};
@@ -2839,8 +2834,8 @@ static SliceVerificationResult verifyInsertSliceOp(
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
- RankedTensorType expected = ExtractSliceOp::inferResultType(
- dstType, staticOffsets, staticSizes, staticStrides);
+ RankedTensorType expected =
+ ExtractSliceOp::inferResultType(dstType, staticSizes);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2968,7 +2963,7 @@ class InsertSliceOpConstantArgumentFolder final
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
- mixedOffsets, mixedSizes, mixedStrides);
+ mixedSizes);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 7ec61c7df81cf..421f9ab7ceff7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
- srcType, extractSliceOp.getStaticOffsets(),
- extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+ srcType, extractSliceOp.getStaticSizes());
if (nonReducingExtractType != resultType)
return failure();
|
MaheshRavishankar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to me, that for tensors, only the size matters for the result type on not the offsets and strides.
Nit: I wouldnt call a change to signature of a publicly exposed method NFC.
API are not "functional" in LLVM/MLIR: if we're not changing the behavior exposed by any pass we're "NFC" (otherwise all the refactoring we're doing couldn't be). A litmus test is: "do we need a test for this patch?", if not it better be NFC. |
The
offsetsandstridesarguments are neither used nor required -removed them and simplify this hook.