diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index e1a136c7e65bd..71dbe9fb24cd8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -19,7 +19,7 @@ namespace mlir { class BufferizeTypeConverter; -class FrozenRewritePatternList; +class FrozenRewritePatternSet; namespace linalg { @@ -964,8 +964,8 @@ class ConvOpVectorization : public OpRewritePattern { //===----------------------------------------------------------------------===// /// Helper function to allow applying rewrite patterns, interleaved with more /// global transformations, in a staged fashion: -/// 1. the first stage consists of a list of FrozenRewritePatternList. Each -/// FrozenRewritePatternList in this list is applied once, in order. +/// 1. the first stage consists of a list of FrozenRewritePatternSet. Each +/// FrozenRewritePatternSet in this list is applied once, in order. /// 2. the second stage consists of a single OwningRewritePattern that is /// applied greedily until convergence. /// 3. the third stage consists of applying a lambda, generally used for @@ -973,8 +973,8 @@ class ConvOpVectorization : public OpRewritePattern { /// transformations where patterns can be ordered and applied at a finer /// granularity than a sequence of traditional compiler passes. LogicalResult applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const FrozenRewritePatternList &stage2Patterns, + Operation *op, ArrayRef stage1Patterns, + const FrozenRewritePatternSet &stage2Patterns, function_ref stage3Lambda = nullptr); //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 514b7ae06938d..115ad5f039bc0 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -894,7 +894,7 @@ class RewritePatternSet { PDLPatternModule pdlPatterns; }; -// TODO: RewritePatternSet is soft-deprecated and will be removed in the +// TODO: OwningRewritePatternList is soft-deprecated and will be removed in the // future. using OwningRewritePatternList = RewritePatternSet; diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h similarity index 71% rename from mlir/include/mlir/Rewrite/FrozenRewritePatternList.h rename to mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h index a20030cd08da1..554bfd217534f 100644 --- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h @@ -1,4 +1,4 @@ -//===- FrozenRewritePatternList.h - FrozenRewritePatternList ----*- C++ -*-===// +//===- FrozenRewritePatternSet.h --------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H -#define MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H +#ifndef MLIR_REWRITE_FROZENREWRITEPATTERNSET_H +#define MLIR_REWRITE_FROZENREWRITEPATTERNSET_H #include "mlir/IR/PatternMatch.h" @@ -21,20 +21,20 @@ class PDLByteCode; /// such that they need not be continuously recomputed. Note that all copies of /// this class share the same compiled pattern list, allowing for a reduction in /// the number of duplicated patterns that need to be created. -class FrozenRewritePatternList { +class FrozenRewritePatternSet { using NativePatternListT = std::vector>; public: /// Freeze the patterns held in `patterns`, and take ownership. - FrozenRewritePatternList(); - FrozenRewritePatternList(RewritePatternSet &&patterns); - FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default; - FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default; - FrozenRewritePatternList & - operator=(const FrozenRewritePatternList &patterns) = default; - FrozenRewritePatternList & - operator=(FrozenRewritePatternList &&patterns) = default; - ~FrozenRewritePatternList(); + FrozenRewritePatternSet(); + FrozenRewritePatternSet(RewritePatternSet &&patterns); + FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default; + FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default; + FrozenRewritePatternSet & + operator=(const FrozenRewritePatternSet &patterns) = default; + FrozenRewritePatternSet & + operator=(FrozenRewritePatternSet &&patterns) = default; + ~FrozenRewritePatternSet(); /// Return the native patterns held by this list. iterator_range> @@ -66,6 +66,10 @@ class FrozenRewritePatternList { std::shared_ptr impl; }; +// TODO: FrozenRewritePatternList is soft-deprecated and will be removed in the +// future. +using FrozenRewritePatternList = FrozenRewritePatternSet; + } // end namespace mlir -#endif // MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H +#endif // MLIR_REWRITE_FROZENREWRITEPATTERNSET_H diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h index 9d197175b47d7..9314496ecda15 100644 --- a/mlir/include/mlir/Rewrite/PatternApplicator.h +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -14,7 +14,7 @@ #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H #define MLIR_REWRITE_PATTERNAPPLICATOR_H -#include "mlir/Rewrite/FrozenRewritePatternList.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" namespace mlir { class PatternRewriter; @@ -33,7 +33,7 @@ class PatternApplicator { /// `impossibleToMatch`. using CostModel = function_ref; - explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList); + explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList); ~PatternApplicator(); /// Attempt to match and rewrite the given op with any pattern, allowing a @@ -65,7 +65,7 @@ class PatternApplicator { private: /// The list that owns the patterns used within this applicator. - const FrozenRewritePatternList &frozenPatternList; + const FrozenRewritePatternSet &frozenPatternList; /// The set of patterns to match for each operation, stable sorted by benefit. DenseMap> patterns; /// The set of patterns that may match against any operation type, stable diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index ae86b2679eb3c..7ebd07d8cb421 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -13,7 +13,7 @@ #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_ #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ -#include "mlir/Rewrite/FrozenRewritePatternList.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringMap.h" @@ -842,11 +842,11 @@ class ConversionTarget { /// the `unconvertedOps` set will not necessarily be complete.) LogicalResult applyPartialConversion(ArrayRef ops, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet *unconvertedOps = nullptr); LogicalResult applyPartialConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet *unconvertedOps = nullptr); /// Apply a complete conversion on the given operations, and all nested @@ -855,9 +855,9 @@ applyPartialConversion(Operation *op, ConversionTarget &target, /// within 'ops'. LogicalResult applyFullConversion(ArrayRef ops, ConversionTarget &target, - const FrozenRewritePatternList &patterns); + const FrozenRewritePatternSet &patterns); LogicalResult applyFullConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternList &patterns); + const FrozenRewritePatternSet &patterns); /// Apply an analysis conversion on the given operations, and all nested /// operations. This method analyzes which operations would be successfully @@ -869,10 +869,10 @@ LogicalResult applyFullConversion(Operation *op, ConversionTarget &target, /// the regions nested within 'ops'. LogicalResult applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet &convertedOps); LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet &convertedOps); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index aa06c02c2b9ed..cbbe5c10948db 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -14,7 +14,7 @@ #ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ #define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ -#include "mlir/Rewrite/FrozenRewritePatternList.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" namespace mlir { @@ -35,25 +35,25 @@ namespace mlir { /// before attempting to match any of the provided patterns. LogicalResult applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, bool useTopDownTraversal = true); /// Rewrite the regions of the specified operation, with a user-provided limit /// on iterations to attempt before reaching convergence. LogicalResult applyPatternsAndFoldGreedily( - Operation *op, const FrozenRewritePatternList &patterns, + Operation *op, const FrozenRewritePatternSet &patterns, unsigned maxIterations, bool useTopDownTraversal = true); /// Rewrite the given regions, which must be isolated from above. LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef regions, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, bool useTopDownTraversal = true); /// Rewrite the given regions, with a user-provided limit on iterations to /// attempt before reaching convergence. LogicalResult applyPatternsAndFoldGreedily( - MutableArrayRef regions, const FrozenRewritePatternList &patterns, + MutableArrayRef regions, const FrozenRewritePatternSet &patterns, unsigned maxIterations, bool useTopDownTraversal = true); /// Applies the specified patterns on `op` alone while also trying to fold it, @@ -62,7 +62,7 @@ LogicalResult applyPatternsAndFoldGreedily( /// was folded away or erased as a result of becoming dead. Note: This does not /// apply any patterns recursively to the regions of `op`. LogicalResult applyOpPatternsAndFold(Operation *op, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, bool *erased = nullptr); } // end namespace mlir diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 1d4d9fe84cf1e..1dedc2c39d8f6 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -92,7 +92,7 @@ static LogicalResult applyPatterns(FuncOp func) { RewritePatternSet patterns(func.getContext()); patterns.add(func.getContext()); - FrozenRewritePatternList frozen(std::move(patterns)); + FrozenRewritePatternSet frozen(std::move(patterns)); return applyPartialConversion(func, target, frozen); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index cd966d404a47a..851ec5051a6be 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -230,7 +230,7 @@ void AffineDataCopyGeneration::runOnFunction() { RewritePatternSet patterns(&getContext()); AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (Operation *op : copyOps) (void)applyOpPatternsAndFold(op, frozenPatterns); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index c3ec72f51b3fa..8f59074e6b791 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -83,7 +83,7 @@ void SimplifyAffineStructures::runOnFunction() { AffineForOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext()); - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); func.walk([&](Operation *op) { for (auto attr : op->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 8e2645a2d44ae..522cfd7fca950 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -191,7 +191,7 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { RewritePatternSet patterns(ifOp.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); bool erased; - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); if (erased) { if (folded) diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp index cd4f65953c0a1..e31a6b5210e3c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -31,7 +31,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { // Emplace patterns one at a time while also maintaining a simple chained // state transition. unsigned stepCount = 0; - SmallVector stage1Patterns; + SmallVector stage1Patterns; auto zeroState = Identifier::get(std::to_string(stepCount), context); auto currentState = zeroState; for (const std::unique_ptr &t : transformationSequence) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 965275dc2bcc3..4202cb2685764 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -479,8 +479,8 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( } LogicalResult mlir::linalg::applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const FrozenRewritePatternList &stage2Patterns, + Operation *op, ArrayRef stage1Patterns, + const FrozenRewritePatternSet &stage2Patterns, function_ref stage3Lambda) { unsigned iteration = 0; (void)iteration; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp index 87aa623b7abc9..372295a986afc 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -106,7 +106,7 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() { // TODO: Change the type for the indirect users such as spv.Load, spv.Store, // spv.FunctionCall and so on. - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (auto spirvModule : module.getOps()) if (failed(applyFullConversion(spirvModule, target, frozenPatterns))) signalPassFailure(); diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt index 5822789cc9162..76bf64944d50c 100644 --- a/mlir/lib/Rewrite/CMakeLists.txt +++ b/mlir/lib/Rewrite/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_library(MLIRRewrite ByteCode.cpp - FrozenRewritePatternList.cpp + FrozenRewritePatternSet.cpp PatternApplicator.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp similarity index 87% rename from mlir/lib/Rewrite/FrozenRewritePatternList.cpp rename to mlir/lib/Rewrite/FrozenRewritePatternSet.cpp index b61307b81b9f9..9c81363f13f24 100644 --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp @@ -1,4 +1,4 @@ -//===- FrozenRewritePatternList.cpp - Frozen Pattern List -------*- C++ -*-===// +//===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Rewrite/FrozenRewritePatternList.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "ByteCode.h" #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" @@ -47,13 +47,13 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { } //===----------------------------------------------------------------------===// -// FrozenRewritePatternList +// FrozenRewritePatternSet //===----------------------------------------------------------------------===// -FrozenRewritePatternList::FrozenRewritePatternList() +FrozenRewritePatternSet::FrozenRewritePatternSet() : impl(std::make_shared()) {} -FrozenRewritePatternList::FrozenRewritePatternList(RewritePatternSet &&patterns) +FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns) : impl(std::make_shared()) { impl->nativePatterns = std::move(patterns.getNativePatterns()); @@ -72,4 +72,4 @@ FrozenRewritePatternList::FrozenRewritePatternList(RewritePatternSet &&patterns) pdlPatterns.takeRewriteFunctions()); } -FrozenRewritePatternList::~FrozenRewritePatternList() {} +FrozenRewritePatternSet::~FrozenRewritePatternSet() {} diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index 5032f02032571..3db598883360a 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -19,7 +19,7 @@ using namespace mlir; using namespace mlir::detail; PatternApplicator::PatternApplicator( - const FrozenRewritePatternList &frozenPatternList) + const FrozenRewritePatternSet &frozenPatternList) : frozenPatternList(frozenPatternList) { if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { mutableByteCodeState = std::make_unique(); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 5b6edf9894ab3..2d987778f22f2 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -35,7 +35,7 @@ struct Canonicalizer : public CanonicalizerBase { (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); } - FrozenRewritePatternList patterns; + FrozenRewritePatternSet patterns; }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d6037f563f874..41d3eabb07eaf 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1506,7 +1506,7 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(ConversionTarget &targetInfo, - const FrozenRewritePatternList &patterns); + const FrozenRewritePatternSet &patterns); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -1602,7 +1602,7 @@ class OperationLegalizer { } // namespace OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo, - const FrozenRewritePatternList &patterns) + const FrozenRewritePatternSet &patterns) : target(targetInfo), applicator(patterns) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. @@ -2125,7 +2125,7 @@ enum OpConversionMode { // conversion mode. struct OperationConverter { explicit OperationConverter(ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, OpConversionMode mode, DenseSet *trackedOps = nullptr) : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} @@ -2755,7 +2755,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const LogicalResult mlir::applyPartialConversion(ArrayRef ops, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet *unconvertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Partial, unconvertedOps); @@ -2763,7 +2763,7 @@ mlir::applyPartialConversion(ArrayRef ops, } LogicalResult mlir::applyPartialConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet *unconvertedOps) { return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, unconvertedOps); @@ -2774,13 +2774,13 @@ mlir::applyPartialConversion(Operation *op, ConversionTarget &target, /// operation fails. LogicalResult mlir::applyFullConversion(ArrayRef ops, ConversionTarget &target, - const FrozenRewritePatternList &patterns) { + const FrozenRewritePatternSet &patterns) { OperationConverter opConverter(target, patterns, OpConversionMode::Full); return opConverter.convertOperations(ops); } LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternList &patterns) { + const FrozenRewritePatternSet &patterns) { return applyFullConversion(llvm::makeArrayRef(op), target, patterns); } @@ -2793,7 +2793,7 @@ mlir::applyFullConversion(Operation *op, ConversionTarget &target, LogicalResult mlir::applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet &convertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, &convertedOps); @@ -2801,7 +2801,7 @@ mlir::applyAnalysisConversion(ArrayRef ops, } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, DenseSet &convertedOps) { return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, convertedOps); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index c4b5fe043e489..f28f228737a84 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -37,7 +37,7 @@ namespace { class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, bool useTopDownTraversal) : PatternRewriter(ctx), matcher(patterns), folder(ctx), useTopDownTraversal(useTopDownTraversal) { @@ -242,13 +242,13 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, /// LogicalResult mlir::applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, bool useTopDownTraversal) { return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations, useTopDownTraversal); } LogicalResult mlir::applyPatternsAndFoldGreedily( - Operation *op, const FrozenRewritePatternList &patterns, + Operation *op, const FrozenRewritePatternSet &patterns, unsigned maxIterations, bool useTopDownTraversal) { return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations, useTopDownTraversal); @@ -256,13 +256,13 @@ LogicalResult mlir::applyPatternsAndFoldGreedily( /// Rewrite the given regions, which must be isolated from above. LogicalResult mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, - const FrozenRewritePatternList &patterns, + const FrozenRewritePatternSet &patterns, bool useTopDownTraversal) { return applyPatternsAndFoldGreedily( regions, patterns, maxPatternMatchIterations, useTopDownTraversal); } LogicalResult mlir::applyPatternsAndFoldGreedily( - MutableArrayRef regions, const FrozenRewritePatternList &patterns, + MutableArrayRef regions, const FrozenRewritePatternSet &patterns, unsigned maxIterations, bool useTopDownTraversal) { if (regions.empty()) return success(); @@ -298,7 +298,7 @@ namespace { class OpPatternRewriteDriver : public PatternRewriter { public: explicit OpPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternList &patterns) + const FrozenRewritePatternSet &patterns) : PatternRewriter(ctx), matcher(patterns), folder(ctx) { // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); @@ -382,7 +382,7 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, /// folding. `erased` is set to true if the op is erased as a result of being /// folded, replaced, or dead. LogicalResult mlir::applyOpPatternsAndFold( - Operation *op, const FrozenRewritePatternList &patterns, bool *erased) { + Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { // Start the pattern driver. OpPatternRewriteDriver driver(op->getContext(), patterns); bool opErased; diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp index 55464283ff7de..7bf298904780a 100644 --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -61,7 +61,7 @@ void TestConvVectorization::runOnOperation() { SmallVector stage1Patterns; linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); - SmallVector frozenStage1Patterns; + SmallVector frozenStage1Patterns; llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); RewritePatternSet stage2Patterns = diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp index 23e6e0056627e..e752c46ecea91 100644 --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -184,7 +184,7 @@ struct TestLinalgGreedyFusion RewritePatternSet patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); patterns.add(context); - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); while (succeeded(fuseLinalgOpsGreedily(getFunction()))) { (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); PassManager pm(context); diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index a9765ce8c9a46..0bb46455b9ca9 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -478,9 +478,9 @@ applyMatmulToVectorPatterns(FuncOp funcOp, fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), stage1Patterns); } - SmallVector frozenStage1Patterns; + SmallVector frozenStage1Patterns; llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); - FrozenRewritePatternList stage2Patterns = + FrozenRewritePatternSet stage2Patterns = getLinalgTilingCanonicalizationPatterns(ctx); (void)applyStagedPatterns(funcOp, frozenStage1Patterns, std::move(stage2Patterns)); @@ -505,7 +505,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) { static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) { RewritePatternSet foldPattern(funcOp.getContext()); foldPattern.add(funcOp.getContext()); - FrozenRewritePatternList frozenPatterns(std::move(foldPattern)); + FrozenRewritePatternSet frozenPatterns(std::move(foldPattern)); // Explicitly walk and apply the pattern locally to avoid more general folding // on the rest of the IR. diff --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp index 9461e2f0ff8b3..0d2f74ae4890a 100644 --- a/mlir/unittests/Rewrite/PatternBenefit.cpp +++ b/mlir/unittests/Rewrite/PatternBenefit.cpp @@ -60,7 +60,7 @@ TEST(PatternBenefitTest, BenefitOrder) { patterns.add(&context, &called1); patterns.add(&called2); - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); PatternApplicator pa(frozenPatterns); pa.applyDefaultCostModel();