diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h index c2335b9dd5a1be..0e583aab3dc468 100644 --- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h @@ -18,34 +18,52 @@ class PDLByteCode; /// This class represents a frozen set of patterns that can be processed by a /// pattern applicator. This class is designed to enable caching pattern lists -/// such that they need not be continuously recomputed. +/// such that they need not be continuously recomputed. Note that all copies of +/// this class share the same compiled pattern list, allowing for a reduction in +/// the number of duplicated patterns that need to be created. class FrozenRewritePatternList { using NativePatternListT = std::vector>; public: /// Freeze the patterns held in `patterns`, and take ownership. + FrozenRewritePatternList(); FrozenRewritePatternList(OwningRewritePatternList &&patterns); - FrozenRewritePatternList(FrozenRewritePatternList &&patterns); + FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default; + FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default; + FrozenRewritePatternList & + operator=(const FrozenRewritePatternList &patterns) = default; + FrozenRewritePatternList & + operator=(FrozenRewritePatternList &&patterns) = default; ~FrozenRewritePatternList(); /// Return the native patterns held by this list. iterator_range> getNativePatterns() const { + const NativePatternListT &nativePatterns = impl->nativePatterns; return llvm::make_pointee_range(nativePatterns); } /// Return the compiled PDL bytecode held by this list. Returns null if /// there are no PDL patterns within the list. const detail::PDLByteCode *getPDLByteCode() const { - return pdlByteCode.get(); + return impl->pdlByteCode.get(); } private: - /// The set of. - std::vector> nativePatterns; + /// The internal implementation of the frozen pattern list. + struct Impl { + /// The set of native C++ rewrite patterns. + NativePatternListT nativePatterns; - /// The bytecode containing the compiled PDL patterns. - std::unique_ptr pdlByteCode; + /// The bytecode containing the compiled PDL patterns. + std::unique_ptr pdlByteCode; + }; + + /// A pointer to the internal pattern list. This uses a shared_ptr to avoid + /// the need to compile the same pattern list multiple times. For example, + /// during multi-threaded pass execution, all copies of a pass can share the + /// same pattern list. + std::shared_ptr impl; }; } // end namespace mlir diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp index 60f6dcea88f203..40d7fcde8f33d4 100644 --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -50,12 +50,16 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { // FrozenRewritePatternList //===----------------------------------------------------------------------===// +FrozenRewritePatternList::FrozenRewritePatternList() + : impl(std::make_shared()) {} + FrozenRewritePatternList::FrozenRewritePatternList( OwningRewritePatternList &&patterns) - : nativePatterns(std::move(patterns.getNativePatterns())) { - PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); + : impl(std::make_shared()) { + impl->nativePatterns = std::move(patterns.getNativePatterns()); // Generate the bytecode for the PDL patterns if any were provided. + PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); ModuleOp pdlModule = pdlPatterns.getModule(); if (!pdlModule) return; @@ -64,14 +68,9 @@ FrozenRewritePatternList::FrozenRewritePatternList( "failed to lower PDL pattern module to the PDL Interpreter"); // Generate the pdl bytecode. - pdlByteCode = std::make_unique( + impl->pdlByteCode = std::make_unique( pdlModule, pdlPatterns.takeConstraintFunctions(), pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions()); } -FrozenRewritePatternList::FrozenRewritePatternList( - FrozenRewritePatternList &&patterns) - : nativePatterns(std::move(patterns.nativePatterns)), - pdlByteCode(std::move(patterns.pdlByteCode)) {} - FrozenRewritePatternList::~FrozenRewritePatternList() {}