-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Turn memref/tensor.dim
reification into canonicalization pattern
#70897
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Turn memref/tensor.dim
reification into canonicalization pattern
#70897
Conversation
…tern Instead of having a dedicated pass to fold `memref/tensor.dim` of ops that implement `ReifyRankedShapedTypeOpInterface`, turn the respective patterns into canonicalization patterns. This allows us to delete canonicalization patterns that do the same for specific ops. (Some of these canonicalization patterns do not have proper error checking; e.g., they crash when the dimension index is out-of-bounds.) This change also decouples the tensor/memref transforms build units a bit: there is now one fewer dependency on `tensor.dim` in `MemRef/Transforms/ResolveShapedTypeResultDims.cpp`. The canonicalization pattern is now part of `mlir/Interfaces/InferTypeOpInterface.h`. Also add a new `transform.tensor.resolve_ranked_shaped_type_result_dims` transform op. (`transform.memref.resolve_ranked_shaped_type_result_dims` no longer applies to `tensor.dim` ops.)
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) ChangesInstead of having a dedicated pass to fold This change also decouples the tensor/memref transforms build units a bit: there is now one fewer dependency on Also add a new Patch is 22.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70897.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index d7bd8410e360a76..7dd2a95f0e621a5 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -131,9 +131,10 @@ def ApplyFoldMemrefAliasOpsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
-def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect,
- "apply_patterns.memref.resolve_ranked_shaped_type_result_dims",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+def ApplyMemrefResolveRankedShapedTypeResultDimsPatternsOp
+ : Op<Transform_Dialect,
+ "apply_patterns.memref.resolve_ranked_shaped_type_result_dims",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collects patterns that resolve `memref.dim` operations with values that are
defined by operations that implement the `ReifyRankedShapedTypeOpInterface`,
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index f502aac79927094..5cc0b818de4c20f 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -57,12 +57,6 @@ std::unique_ptr<Pass> createFoldMemRefAliasOpsPass();
/// (identity) layout map.
std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
-/// Creates an operation pass to resolve `memref.dim` operations with values
-/// that are defined by operations that implement the
-/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
-/// operands.
-std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
-
/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
/// `InferShapedTypeOpInterface` or the `ReifyRankedShapeTypeShapeOpInterface`,
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d7ee492b9e990e0..07bf42deabb26b8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -159,21 +159,6 @@ def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
let dependentDialects = ["affine::AffineDialect"];
}
-def ResolveRankedShapeTypeResultDims :
- Pass<"resolve-ranked-shaped-type-result-dims"> {
- let summary = "Resolve memref.dim of result values of ranked shape type";
- let description = [{
- The pass resolves memref.dim of result of operations that
- implement the `ReifyRankedShapedTypeOpInterface` in terms of
- shapes of its operands.
- }];
- let constructor =
- "mlir::memref::createResolveRankedShapeTypeResultDimsPass()";
- let dependentDialects = [
- "memref::MemRefDialect", "tensor::TensorDialect"
- ];
-}
-
def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values";
let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index a918f62cbc8db8f..50000691a2928de 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -45,13 +45,6 @@ void populateExpandOpsPatterns(RewritePatternSet &patterns);
/// ops into `patterns`.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
-/// Appends patterns that resolve `memref.dim` operations with values that are
-/// defined by operations that implement the
-/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
-/// operands.
-void populateResolveRankedShapedTypeResultDimsPatterns(
- RewritePatternSet &patterns);
-
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of shapes of its input operands.
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 66c6021418b471c..af598a5b35fab32 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -99,6 +99,19 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyTensorResolveRankedShapedTypeResultDimsPatternsOp
+ : Op<Transform_Dialect,
+ "apply_patterns.tensor.resolve_ranked_shaped_type_result_dims",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collects patterns that resolve `tensor.dim` operations with values that are
+ defined by operations that implement the `ReifyRankedShapedTypeOpInterface`,
+ in terms of shapes of its input operands.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
"apply_patterns.tensor.rewrite_as_constant",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 67de05b0cb4ff34..79720807a2e7b8d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
@@ -277,6 +278,57 @@ template <typename ConcreteType>
class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {};
} // namespace OpTrait
+
+namespace {
+/// Fold dim of an operation that implements ReifyRankedShapedTypeOpInterface.
+template <typename OpTy>
+struct FoldDimOfReifyRankedShapedTypeOp : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
+
+ LogicalResult matchAndRewrite(OpTy dimOp,
+ PatternRewriter &rewriter) const override {
+ OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
+ if (!dimValue)
+ return failure();
+ // Can fold only if the dimension is a constant.
+ std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
+ if (!dimIndex)
+ return failure();
+ // Reify result dimensions.
+ ReifiedRankedShapedTypeDims reifiedResultShapes;
+ if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
+ reifiedResultShapes)))
+ return rewriter.notifyMatchFailure(dimOp,
+ "failed to reify result shapes");
+ unsigned resultNumber = dimValue.getResultNumber();
+ // Do not apply pattern if the IR is invalid (dim out of bounds).
+ if (*dimIndex >= reifiedResultShapes[resultNumber].size())
+ return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
+ OpFoldResult dimSize = reifiedResultShapes[resultNumber][*dimIndex];
+ // If the dim size is a value, replace the op directly.
+ if (auto value = dimSize.dyn_cast<Value>()) {
+ rewriter.replaceOp(dimOp, value);
+ return success();
+ }
+ // Otherwise, materialize a constant value.
+ rewriter.replaceOp(dimOp, dimOp->getDialect()->materializeConstant(
+ rewriter, dimSize.get<Attribute>(),
+ rewriter.getIndexType(), dimOp->getLoc()));
+ return success();
+ }
+};
+} // namespace
+
+/// Populate `patterns` with a pattern that dim ops of type OpTy that operate
+/// on ops that implement ReifyRankedShapedTypeOpInterface.
+template <typename OpTy>
+void populateResolveRankedShapedTypeResultDimsPattern(
+ RewritePatternSet &patterns) {
+ patterns.insert<FoldDimOfReifyRankedShapedTypeOp<OpTy>>(
+ patterns.getContext());
+}
} // namespace mlir
#endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 8f19245efdba6c8..1d6897ebf3437aa 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -329,28 +329,11 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
return success();
}
};
-
-struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
- using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::DimOp dimOp,
- PatternRewriter &rewriter) const override {
- std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
- auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
- if (!allocTensorOp || !maybeConstantIndex)
- return failure();
- if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
- return failure();
- rewriter.replaceOp(
- dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
- return success();
- }
-};
} // namespace
void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *ctx) {
- results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
+ results.add<ReplaceStaticShapeDims>(ctx);
}
LogicalResult AllocTensorOp::reifyResultShapes(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 2e3610b7c08d9da..ea6239e78a66656 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -647,7 +647,8 @@ populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::populateFoldTensorEmptyPatterns(patterns);
- memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+ populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
+ populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
@@ -662,7 +663,8 @@ populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::populateFoldTensorEmptyPatterns(patterns);
- memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+ populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
+ populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..749802837186227 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1136,6 +1136,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfMemRefReshape>(context);
+ populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(results);
}
// ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index eed29efcaaada88..d56fa102451c366 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -121,9 +121,9 @@ void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
memref::populateFoldMemRefAliasOpPatterns(patterns);
}
-void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
+void transform::ApplyMemrefResolveRankedShapedTypeResultDimsPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
- memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+ populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 0cb5931ce6bf9b9..9f3f33aadf93fb8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -24,7 +24,6 @@
namespace mlir {
namespace memref {
-#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
@@ -72,37 +71,6 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
return success();
}
};
-
-/// Fold dim of an operation that implements the InferShapedTypeOpInterface
-template <typename OpTy>
-struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
-
- LogicalResult matchAndRewrite(OpTy dimOp,
- PatternRewriter &rewriter) const override {
- OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
- if (!dimValue)
- return failure();
- std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
- if (!dimIndex)
- return failure();
-
- ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
- reifiedResultShapes)))
- return failure();
- unsigned resultNumber = dimValue.getResultNumber();
- // Do not apply pattern if the IR is invalid (dim out of bounds).
- if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
- return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
- Value replacement = getValueOrCreateConstantIndexOp(
- rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
- rewriter.replaceOp(dimOp, replacement);
- return success();
- }
-};
} // namespace
//===----------------------------------------------------------------------===//
@@ -110,11 +78,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
//===----------------------------------------------------------------------===//
namespace {
-struct ResolveRankedShapeTypeResultDimsPass final
- : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
- ResolveRankedShapeTypeResultDimsPass> {
- void runOnOperation() override;
-};
struct ResolveShapedTypeResultDimsPass final
: public memref::impl::ResolveShapedTypeResultDimsBase<
@@ -124,13 +87,6 @@ struct ResolveShapedTypeResultDimsPass final
} // namespace
-void memref::populateResolveRankedShapedTypeResultDimsPatterns(
- RewritePatternSet &patterns) {
- patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
- DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
- patterns.getContext());
-}
-
void memref::populateResolveShapedTypeResultDimsPatterns(
RewritePatternSet &patterns) {
// TODO: Move tensor::DimOp pattern to the Tensor dialect.
@@ -139,17 +95,13 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
patterns.getContext());
}
-void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
- return signalPassFailure();
-}
-
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
+ // TODO: `populateResolveRankedShapedTypeResultDimsPattern` does not really
+ // belong here.
+ populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
+ populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
@@ -157,7 +109,3 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
return std::make_unique<ResolveShapedTypeResultDimsPass>();
}
-
-std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
- return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
-}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f719cfed6b6dd30..29e623a06933cce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -606,6 +606,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+ populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(results);
}
//===----------------------------------------------------------------------===//
@@ -737,23 +738,6 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
}
};
-struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::DimOp dimOp,
- PatternRewriter &rewriter) const override {
- std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
- auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
- if (!emptyTensorOp || !maybeConstantIndex)
- return failure();
- if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
- return failure();
- rewriter.replaceOp(dimOp,
- emptyTensorOp.getDynamicSize(*maybeConstantIndex));
- return success();
- }
-};
-
/// Canonicalize
///
/// ```mlir
@@ -830,8 +814,8 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
- ReplaceEmptyTensorStaticShapeDims>(context);
+ results.add<FoldEmptyTensorWithCastOp, ReplaceEmptyTensorStaticShapeDims>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 3cec91389392246..d92f68712a9972b 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -118,6 +118,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
}
+void transform::ApplyTensorResolveRankedShapedTypeResultDimsPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
+}
+
void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
tensor::populateRewriteAsConstantPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemR...
[truncated]
|
I think we decided to keep this as an opt in long time back. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was moved from canonicalization to opt in a while ago. Is there a reason to move it back
Just a general cleanup. The main point was moving the pattern to ‘Interfaces’. Whether it’s a canonicalization pattern or not is actually not important for this cleanup. I’m curious though, do you remember why this is no longer a canonization pattern? |
Basically canonicalization are meant to be local changes, but resolving dimensions of ops is a full program change. You are essentially solving for shape constraints in the whole program. So it is a fixed point, but isn't a canonicalization in the narrow sense of it. It was an effort to move things out of canonicalization to allow for more control of program transformations |
Instead of having a dedicated pass to fold
memref/tensor.dim
of ops that implementReifyRankedShapedTypeOpInterface
, turn the respective patterns into canonicalization patterns. This allows us to delete canonicalization patterns that do the same for specific ops. (Some of these canonicalization patterns do not have proper error checking; e.g., they crash when the dimension index is out-of-bounds.)This change also decouples the tensor/memref transforms build units a bit: there is now one fewer dependency on
tensor.dim
inMemRef/Transforms/ResolveShapedTypeResultDims.cpp
. The canonicalization pattern is now part ofmlir/Interfaces/InferTypeOpInterface.h
.Also add a new
transform.tensor.resolve_ranked_shaped_type_result_dims
transform op. (transform.memref.resolve_ranked_shaped_type_result_dims
no longer applies totensor.dim
ops.)