Skip to content

Commit c582688

Browse files
authored
[MLIR][tensor] Simplify ExtractSliceOp::inferResultType (nfc) (#169313)
The `offsets` and `strides` arguments are neither used nor required - removed them and simplify this hook.
1 parent b8ef25a commit c582688

File tree

4 files changed

+28
-43
lines changed

4 files changed

+28
-43
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -490,17 +490,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
490490

491491
/// An extract_slice result type can be inferred, when it is not
492492
/// rank-reduced, from the source type and the static representation of
493-
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
493+
/// sizes. Special sentinels encode the dynamic case.
494494
static RankedTensorType inferResultType(
495495
RankedTensorType sourceTensorType,
496-
ArrayRef<int64_t> staticOffsets,
497-
ArrayRef<int64_t> staticSizes,
498-
ArrayRef<int64_t> staticStrides);
496+
ArrayRef<int64_t> staticSizes);
499497
static RankedTensorType inferResultType(
500498
RankedTensorType sourceTensorType,
501-
ArrayRef<OpFoldResult> staticOffsets,
502-
ArrayRef<OpFoldResult> staticSizes,
503-
ArrayRef<OpFoldResult> staticStrides);
499+
ArrayRef<OpFoldResult> staticSizes);
504500

505501
/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
506502
/// number of sizes), drop as many size 1 as needed to produce an inferred type
@@ -513,15 +509,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
513509
static RankedTensorType inferCanonicalRankReducedResultType(
514510
unsigned resultRank,
515511
RankedTensorType sourceRankedTensorType,
516-
ArrayRef<int64_t> staticOffsets,
517-
ArrayRef<int64_t> staticSizes,
518-
ArrayRef<int64_t> staticStrides);
512+
ArrayRef<int64_t> staticSizes);
519513
static RankedTensorType inferCanonicalRankReducedResultType(
520514
unsigned resultRank,
521515
RankedTensorType sourceRankedTensorType,
522-
ArrayRef<OpFoldResult> staticOffsets,
523-
ArrayRef<OpFoldResult> staticSizes,
524-
ArrayRef<OpFoldResult> staticStrides);
516+
ArrayRef<OpFoldResult> staticSizes);
525517

526518
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
527519
/// and `static_strides` attributes.

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
747747
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
748748
auto rankReducedType = cast<RankedTensorType>(
749749
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
750-
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
751-
strides));
750+
reassociation->size(), sliceOp.getSourceType(), sizes));
752751

753752
Location loc = sliceOp.getLoc();
754753
Value newSlice = tensor::ExtractSliceOp::create(

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2291,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames(
22912291
/// An extract_slice result type can be inferred, when it is not
22922292
/// rank-reduced, from the source type and the static representation of
22932293
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2294-
RankedTensorType ExtractSliceOp::inferResultType(
2295-
RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2296-
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2294+
RankedTensorType
2295+
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2296+
ArrayRef<int64_t> staticSizes) {
22972297
// An extract_slice op may specify only a leading subset of offset/sizes/
22982298
// strides in which case we complete with offset=0, sizes from memref type
22992299
// and strides=1.
@@ -2305,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
23052305
}
23062306

23072307
// TODO: This uses neither offsets nor strides!
2308-
RankedTensorType ExtractSliceOp::inferResultType(
2309-
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2310-
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2308+
RankedTensorType
2309+
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2310+
ArrayRef<OpFoldResult> sizes) {
23112311
SmallVector<int64_t> staticSizes;
23122312
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2313+
23132314
assert(static_cast<int64_t>(staticSizes.size()) ==
23142315
sourceTensorType.getRank() &&
23152316
"unexpected staticSizes not equal to rank of source");
@@ -2327,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
23272328
/// To disambiguate, this function always drops the first 1 sizes occurrences.
23282329
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
23292330
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2330-
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2331-
ArrayRef<int64_t> strides) {
2331+
ArrayRef<int64_t> sizes) {
23322332
// Type inferred in the absence of rank-reducing behavior.
23332333
auto inferredType = llvm::cast<RankedTensorType>(
2334-
inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2334+
inferResultType(sourceRankedTensorType, sizes));
23352335
int rankDiff = inferredType.getRank() - desiredResultRank;
23362336
if (rankDiff > 0) {
23372337
auto shape = inferredType.getShape();
@@ -2350,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
23502350

23512351
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
23522352
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2353-
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2354-
ArrayRef<OpFoldResult> strides) {
2355-
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2356-
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2357-
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2353+
ArrayRef<OpFoldResult> sizes) {
2354+
SmallVector<int64_t> staticSizes;
2355+
SmallVector<Value> dynamicSizes;
23582356
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2359-
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
23602357
return ExtractSliceOp::inferCanonicalRankReducedResultType(
2361-
desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2362-
staticStrides);
2358+
desiredResultRank, sourceRankedTensorType, staticSizes);
23632359
}
23642360

23652361
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2378,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
23782374
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
23792375
// Structuring implementation this way avoids duplication between builders.
23802376
if (!resultType) {
2381-
resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2382-
sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2377+
resultType = llvm::cast<RankedTensorType>(
2378+
ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
23832379
}
23842380
result.addAttributes(attrs);
23852381
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2454,8 +2450,8 @@ LogicalResult ExtractSliceOp::verify() {
24542450
RankedTensorType sourceType = getSourceType();
24552451

24562452
// Verify result type against inferred type.
2457-
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2458-
sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2453+
RankedTensorType expectedType =
2454+
ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
24592455
SliceVerificationResult result = isRankReducedType(expectedType, getType());
24602456
if (result != SliceVerificationResult::Success)
24612457
return produceSliceErrorMsg(result, *this, expectedType);
@@ -2695,8 +2691,7 @@ struct SliceReturnTypeCanonicalizer {
26952691
ArrayRef<OpFoldResult> mixedSizes,
26962692
ArrayRef<OpFoldResult> mixedStrides) {
26972693
return ExtractSliceOp::inferCanonicalRankReducedResultType(
2698-
op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2699-
mixedStrides);
2694+
op.getType().getRank(), op.getSourceType(), mixedSizes);
27002695
}
27012696
};
27022697

@@ -2837,8 +2832,8 @@ static SliceVerificationResult verifyInsertSliceOp(
28372832
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
28382833
// insert_slice is the inverse of extract_slice, use the same type
28392834
// inference.
2840-
RankedTensorType expected = ExtractSliceOp::inferResultType(
2841-
dstType, staticOffsets, staticSizes, staticStrides);
2835+
RankedTensorType expected =
2836+
ExtractSliceOp::inferResultType(dstType, staticSizes);
28422837
if (expectedType)
28432838
*expectedType = expected;
28442839
return isRankReducedType(expected, srcType);
@@ -2966,7 +2961,7 @@ class InsertSliceOpConstantArgumentFolder final
29662961
// Create the new op in canonical form.
29672962
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
29682963
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2969-
mixedOffsets, mixedSizes, mixedStrides);
2964+
mixedSizes);
29702965
Value toInsert = insertSliceOp.getSource();
29712966
if (sourceType != insertSliceOp.getSourceType()) {
29722967
OpBuilder::InsertionGuard g(rewriter);

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
3737
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
3838
// has no rank-reduction anymore are supported at the moment.
3939
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40-
srcType, extractSliceOp.getStaticOffsets(),
41-
extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
40+
srcType, extractSliceOp.getStaticSizes());
4241
if (nonReducingExtractType != resultType)
4342
return failure();
4443

0 commit comments

Comments
 (0)