Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -486,17 +486,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",

/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
/// sizes. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
ArrayRef<int64_t> staticSizes);
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
ArrayRef<OpFoldResult> staticSizes);

/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size 1 as needed to produce an inferred type
Expand All @@ -509,15 +505,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
ArrayRef<int64_t> staticSizes);
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
ArrayRef<OpFoldResult> staticSizes);

/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
strides));
reassociation->size(), sliceOp.getSourceType(), sizes));

Location loc = sliceOp.getLoc();
Value newSlice = tensor::ExtractSliceOp::create(
Expand Down
47 changes: 21 additions & 26 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2293,9 +2293,9 @@ void ExtractSliceOp::getAsmResultNames(
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
ArrayRef<int64_t> staticSizes) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
Expand All @@ -2307,11 +2307,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
}

// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);

assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
Expand All @@ -2329,11 +2330,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
ArrayRef<int64_t> sizes) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
inferResultType(sourceRankedTensorType, offsets, sizes, strides));
inferResultType(sourceRankedTensorType, sizes));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
Expand All @@ -2352,16 +2352,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
staticStrides);
desiredResultRank, sourceRankedTensorType, staticSizes);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
Expand All @@ -2380,8 +2376,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
resultType = llvm::cast<RankedTensorType>(
ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
Expand Down Expand Up @@ -2456,8 +2452,8 @@ LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();

// Verify result type against inferred type.
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
RankedTensorType expectedType =
ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
Expand Down Expand Up @@ -2697,8 +2693,7 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
mixedStrides);
op.getType().getRank(), op.getSourceType(), mixedSizes);
}
};

Expand Down Expand Up @@ -2839,8 +2834,8 @@ static SliceVerificationResult verifyInsertSliceOp(
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
RankedTensorType expected = ExtractSliceOp::inferResultType(
dstType, staticOffsets, staticSizes, staticStrides);
RankedTensorType expected =
ExtractSliceOp::inferResultType(dstType, staticSizes);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
Expand Down Expand Up @@ -2968,7 +2963,7 @@ class InsertSliceOpConstantArgumentFolder final
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
mixedOffsets, mixedSizes, mixedStrides);
mixedSizes);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
srcType, extractSliceOp.getStaticOffsets(),
extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
srcType, extractSliceOp.getStaticSizes());
if (nonReducingExtractType != resultType)
return failure();

Expand Down