diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 457a4b11d816d..4679d98719224 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -450,11 +450,11 @@ class RewritePatternMatcher { /// Note: These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. /// -bool applyPatternsGreedily(Operation *op, - const OwningRewritePatternList &patterns); +bool applyPatternsAndFoldGreedily(Operation *op, + const OwningRewritePatternList &patterns); /// Rewrite the given regions, which must be isolated from above. -bool applyPatternsGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns); +bool applyPatternsAndFoldGreedily(MutableArrayRef regions, + const OwningRewritePatternList &patterns); } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index c3f3c04d91967..9391466ffdc92 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -268,7 +268,7 @@ class LowerGpuOpsToNVVMOpsPass // which need to be lowered further, which is not supported by a single // conversion pass. populateGpuRewritePatterns(m.getContext(), patterns); - applyPatternsGreedily(m, patterns); + applyPatternsAndFoldGreedily(m, patterns); patterns.clear(); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 1312e0c36f809..6d9974233a9f0 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -170,7 +170,7 @@ void SPIRVLegalization::runOnOperation() { OwningRewritePatternList patterns; auto *context = &getContext(); populateStdLegalizationPatternsForSPIRVLowering(context, patterns); - applyPatternsGreedily(getOperation()->getRegions(), patterns); + applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); } std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 38822fa79458c..b7c4a57a78ba3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1244,7 +1244,7 @@ void LowerVectorToLLVMPass::runOnOperation() { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), patterns); } // Convert to the LLVM IR dialect. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index bff60f417082f..c861b214d3b34 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -225,6 +225,6 @@ void AffineDataCopyGeneration::runOnFunction() { OwningRewritePatternList patterns; AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); - applyPatternsGreedily(f, std::move(patterns)); + applyPatternsAndFoldGreedily(f, std::move(patterns)); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index ee9996b287084..97f684cd16ab0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -572,7 +572,7 @@ struct FusionOfTensorOpsPass OwningRewritePatternList patterns; Operation *op = getOperation(); patterns.insert(op->getContext()); - applyPatternsGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), patterns); }; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index e85a67a9f7eb1..48df0ac3ea2ac 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -688,7 +688,7 @@ static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) { AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert(context); // Just apply the patterns greedily. - applyPatternsGreedily(op, patterns); + applyPatternsAndFoldGreedily(op, patterns); } namespace { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp index 8f9f55d175a9c..2ff23123b474e 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -98,7 +98,7 @@ void ConvertConstPass::runOnFunction() { auto func = getFunction(); auto *context = &getContext(); patterns.insert(context); - applyPatternsGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, patterns); } std::unique_ptr> mlir::quant::createConvertConstPass() { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp index 2cb077a25bb1e..bafc48e11ff06 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -131,7 +131,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { auto ctx = func.getContext(); patterns.insert( ctx, &hadFailure); - applyPatternsGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, patterns); if (hadFailure) signalPassFailure(); } diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index c46a8b9fa31e4..9b028bfa2525d 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -32,7 +32,7 @@ struct Canonicalizer : public CanonicalizerBase { op->getCanonicalizationPatterns(patterns, context); Operation *op = getOperation(); - applyPatternsGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), patterns); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index 8ee4996bd03f8..f797ca8f64847 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -541,7 +541,7 @@ static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, // Apply the canonicalization patterns to this region. auto *node = nodesToCanonicalize[index]; - applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); + applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns); // Make sure to reset the order ID for the diagnostic handler, as this // thread may be used in a different context. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 258df35d9b4d0..80ad143ce0d3c 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -37,8 +37,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter { worklist.reserve(64); } - /// Perform the rewrites. Return true if the rewrite converges in - /// `maxIterations`. + /// Perform the rewrites while folding and erasing any dead ops. Return true + /// if the rewrite converges in `maxIterations`. bool simplify(MutableArrayRef regions, int maxIterations); void addToWorklist(Operation *op) { @@ -133,7 +133,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { }; } // end anonymous namespace -/// Perform the rewrites. +/// Perform the rewrites while folding and erasing any dead ops. bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, int maxIterations) { // Add the given operation to the worklist. @@ -213,14 +213,14 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, /// the result operation regions. /// Note: This does not apply patterns to the top-level operation itself. /// -bool mlir::applyPatternsGreedily(Operation *op, - const OwningRewritePatternList &patterns) { - return applyPatternsGreedily(op->getRegions(), patterns); +bool mlir::applyPatternsAndFoldGreedily( + Operation *op, const OwningRewritePatternList &patterns) { + return applyPatternsAndFoldGreedily(op->getRegions(), patterns); } /// Rewrite the given regions, which must be isolated from above. -bool mlir::applyPatternsGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns) { +bool mlir::applyPatternsAndFoldGreedily( + MutableArrayRef regions, const OwningRewritePatternList &patterns) { if (regions.empty()) return true; diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp index 7c0052dd9e688..6c8b546e6aef3 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -97,7 +97,7 @@ void TestAffineDataCopy::runOnFunction() { OwningRewritePatternList patterns; AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); - applyPatternsGreedily(getFunction(), std::move(patterns)); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 90b34d9fe70f5..d36eb985512ac 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -46,7 +46,7 @@ struct TestPatternDriver : public PassWrapper { // Verify named pattern is generated with expected name. patterns.insert(&getContext()); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; } // end anonymous namespace diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp index 9a2bcc2923799..c043d0f02f8d0 100644 --- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp +++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp @@ -22,7 +22,7 @@ struct TestAllReduceLoweringPass void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); - applyPatternsGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), patterns); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp index 6f49fabc192a8..e32f4d3dd6c50 100644 --- a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp +++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp @@ -36,7 +36,7 @@ struct DeclarativeTransforms SubViewOp::getCanonicalizationPatterns(patterns, context); ViewOp::getCanonicalizationPatterns(patterns, context); populateWithGenerated(context, &patterns); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; } // end anonymous namespace diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index 85300f981f1e1..7fc1138ff8d4d 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -40,7 +40,7 @@ void TestLinalgTransforms::runOnFunction() { // Add the generated patterns to the list. linalg::populateWithGenerated(&getContext(), &patterns); - applyPatternsGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, patterns); // Drop the marker. funcOp.walk([](LinalgOp op) { diff --git a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp index b1c02bdd0adf1..dc9b5c8d66cd7 100644 --- a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp +++ b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp @@ -23,7 +23,7 @@ struct TestVectorToLoopsPass OwningRewritePatternList patterns; auto *context = &getContext(); populateVectorToAffineLoopsConversionPatterns(context, patterns); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 808fcd21d3316..c57540bc2ef70 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -28,7 +28,7 @@ struct TestVectorToVectorConversion populateWithGenerated(context, &patterns); populateVectorToVectorCanonicalizationPatterns(patterns, context); populateVectorToVectorTransformationPatterns(patterns, context); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; @@ -37,7 +37,7 @@ struct TestVectorSlicesConversion void runOnFunction() override { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; @@ -57,7 +57,7 @@ struct TestVectorContractionConversion VectorTransformsOptions options{ /*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } };