diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index 92fd246a135886..6daccaef7cd3bf 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -302,7 +302,8 @@ void ToyToAffineLoweringPass::runOnFunction() { // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index f3857f35e25c95..14f68e6176b5c4 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -301,7 +301,8 @@ void ToyToAffineLoweringPass::runOnFunction() { // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 19bf27e1864d18..a04b3ecd4daee4 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -200,7 +200,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index 92fd246a135886..6daccaef7cd3bf 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -302,7 +302,8 @@ void ToyToAffineLoweringPass::runOnFunction() { // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 19bf27e1864d18..a04b3ecd4daee4 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -200,7 +200,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index a3e369438c8a88..566ae79d2515a6 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -17,8 +17,8 @@ #include "llvm/ADT/SmallBitVector.h" namespace mlir { - class BufferizeTypeConverter; +class FrozenRewritePatternList; namespace linalg { @@ -844,8 +844,8 @@ class TensorCastOpConverter //===----------------------------------------------------------------------===// /// Helper function to allow applying rewrite patterns, interleaved with more /// global transformations, in a staged fashion: -/// 1. the first stage consists of a list of OwningRewritePatternList. Each -/// OwningRewritePatternList in this list is applied once, in order. +/// 1. the first stage consists of a list of FrozenRewritePatternList. Each +/// FrozenRewritePatternList in this list is applied once, in order. /// 2. the second stage consists of a single OwningRewritePattern that is /// applied greedily until convergence. /// 3. the third stage consists of applying a lambda, generally used for @@ -853,8 +853,8 @@ class TensorCastOpConverter /// transformations where patterns can be ordered and applied at a finer /// granularity than a sequence of traditional compiler passes. LogicalResult applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const OwningRewritePatternList &stage2Patterns, + Operation *op, ArrayRef stage1Patterns, + const FrozenRewritePatternList &stage2Patterns, function_ref stage3Lambda = nullptr); } // namespace linalg diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 74300f2a18828c..1295988050dfb0 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -440,6 +440,11 @@ class OwningRewritePatternList { PatternListT::size_type size() const { return patterns.size(); } void clear() { patterns.clear(); } + /// Take ownership of the patterns held by this list. + std::vector> takePatterns() { + return std::move(patterns); + } + //===--------------------------------------------------------------------===// // Pattern Insertion //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h new file mode 100644 index 00000000000000..fb2657d9923282 --- /dev/null +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h @@ -0,0 +1,38 @@ +//===- FrozenRewritePatternList.h - FrozenRewritePatternList ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H +#define MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +/// This class represents a frozen set of patterns that can be processed by a +/// pattern applicator. This class is designed to enable caching pattern lists +/// such that they need not be continuously recomputed. +class FrozenRewritePatternList { + using PatternListT = std::vector>; + +public: + /// Freeze the patterns held in `patterns`, and take ownership. + FrozenRewritePatternList(OwningRewritePatternList &&patterns); + + /// Return the patterns held by this list. + iterator_range> + getPatterns() const { + return llvm::make_pointee_range(patterns); + } + +private: + /// The patterns held by this list. + std::vector> patterns; +}; + +} // end namespace mlir + +#endif // MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h index be5911966e06d2..c1920b4ab2ae20 100644 --- a/mlir/include/mlir/Rewrite/PatternApplicator.h +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -14,7 +14,7 @@ #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H #define MLIR_REWRITE_PATTERNAPPLICATOR_H -#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternList.h" namespace mlir { class PatternRewriter; @@ -29,8 +29,8 @@ class PatternApplicator { /// `impossibleToMatch`. using CostModel = function_ref; - explicit PatternApplicator(const OwningRewritePatternList &owningPatternList) - : owningPatternList(owningPatternList) {} + explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList) + : frozenPatternList(frozenPatternList) {} /// Attempt to match and rewrite the given op with any pattern, allowing a /// predicate to decide if a pattern can be applied or not, and hooks for if @@ -71,13 +71,12 @@ class PatternApplicator { function_ref onSuccess); /// The list that owns the patterns used within this applicator. - const OwningRewritePatternList &owningPatternList; - + const FrozenRewritePatternList &frozenPatternList; /// The set of patterns to match for each operation, stable sorted by benefit. - DenseMap> patterns; + DenseMap> patterns; /// The set of patterns that may match against any operation type, stable /// sorted by benefit. - SmallVector anyOpPatterns; + SmallVector anyOpPatterns; }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 8bffb9649d1f84..e02cf8fe4c0a28 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -13,9 +13,7 @@ #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_ #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Rewrite/FrozenRewritePatternList.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringMap.h" @@ -805,11 +803,11 @@ class ConversionTarget { /// the `unconvertedOps` set will not necessarily be complete.) LLVM_NODISCARD LogicalResult applyPartialConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps = nullptr); LLVM_NODISCARD LogicalResult applyPartialConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps = nullptr); /// Apply a complete conversion on the given operations, and all nested @@ -818,10 +816,10 @@ applyPartialConversion(Operation *op, ConversionTarget &target, /// within 'ops'. LLVM_NODISCARD LogicalResult applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); LLVM_NODISCARD LogicalResult applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Apply an analysis conversion on the given operations, and all nested /// operations. This method analyzes which operations would be successfully @@ -833,11 +831,11 @@ applyFullConversion(Operation *op, ConversionTarget &target, /// the regions nested within 'ops'. LLVM_NODISCARD LogicalResult applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps); LLVM_NODISCARD LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index ab88f6a1e87147..1306f25b229842 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -14,7 +14,7 @@ #ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ #define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ -#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternList.h" namespace mlir { @@ -32,11 +32,11 @@ namespace mlir { /// LogicalResult applyPatternsAndFoldGreedily(Operation *op, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Rewrite the given regions, which must be isolated from above. LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Applies the specified patterns on `op` alone while also trying to fold it, /// by selecting the highest benefits patterns in a greedy manner. Returns @@ -44,7 +44,7 @@ applyPatternsAndFoldGreedily(MutableArrayRef regions, /// was folded away or erased as a result of becoming dead. Note: This does not /// apply any patterns recursively to the regions of `op`. LogicalResult applyOpPatternsAndFold(Operation *op, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, bool *erased = nullptr); } // end namespace mlir diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp index a8c483430fce8f..ecda71cfb5e52c 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -179,7 +179,8 @@ void ConvertAVX512ToLLVMPass::runOnOperation() { target.addLegalDialect(); target.addLegalDialect(); target.addIllegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) { + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 8a12b1eaf96d72..58f44b6ed20786 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -679,7 +679,8 @@ class LowerAffinePass : public ConvertAffineToStandardBase { ConversionTarget target(getContext()); target .addLegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 13498e821132a2..9a99bf00c08c10 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -723,7 +723,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() { target.addDynamicallyLegalOp( [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } } // namespace diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp index 9d4c0c32dc82c6..ae112d516d5ab5 100644 --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -240,7 +240,8 @@ void GpuToLLVMConversionPass::runOnOperation() { populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 2837ec14ef611f..0dd1f982be96d0 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -124,17 +124,16 @@ struct LowerGpuOpsToNVVMOpsPass return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); }); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns, llvmPatterns; // Apply in-dialect lowering first. In-dialect lowering will replace ops // which need to be lowered further, which is not supported by a single // conversion pass. populateGpuRewritePatterns(m.getContext(), patterns); - applyPatternsAndFoldGreedily(m, patterns); - patterns.clear(); + applyPatternsAndFoldGreedily(m, std::move(patterns)); - populateStdToLLVMConversionPatterns(converter, patterns); - populateGpuToNVVMConversionPatterns(converter, patterns); + populateStdToLLVMConversionPatterns(converter, llvmPatterns); + populateGpuToNVVMConversionPatterns(converter, llvmPatterns); LLVMConversionTarget target(getContext()); target.addIllegalDialect(); target.addIllegalOp(); // TODO: Remove once we support replacing non-root ops. target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns))) + if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 62746663071f7b..489891e06fcb5a 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -59,16 +59,15 @@ struct LowerGpuOpsToROCDLOpsPass /*useAlignedAlloc =*/false}; LLVMTypeConverter converter(m.getContext(), options); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns, llvmPatterns; populateGpuRewritePatterns(m.getContext(), patterns); - applyPatternsAndFoldGreedily(m, patterns); - patterns.clear(); + applyPatternsAndFoldGreedily(m, std::move(patterns)); - populateVectorToLLVMConversionPatterns(converter, patterns); - populateVectorToROCDLConversionPatterns(converter, patterns); - populateStdToLLVMConversionPatterns(converter, patterns); - populateGpuToROCDLConversionPatterns(converter, patterns); + populateVectorToLLVMConversionPatterns(converter, llvmPatterns); + populateVectorToROCDLConversionPatterns(converter, llvmPatterns); + populateStdToLLVMConversionPatterns(converter, llvmPatterns); + populateGpuToROCDLConversionPatterns(converter, llvmPatterns); LLVMConversionTarget target(getContext()); target.addIllegalDialect(); target.addIllegalOp(); // TODO: Remove once we support replacing non-root ops. target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns))) + if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp index ccaecd9c918920..083f992bda95b9 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -64,7 +64,7 @@ void GPUToSPIRVPass::runOnOperation() { populateSCFToSPIRVPatterns(context, typeConverter,scfContext, patterns); populateStandardToSPIRVPatterns(context, typeConverter, patterns); - if (failed(applyFullConversion(kernelModules, *target, patterns))) + if (failed(applyFullConversion(kernelModules, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 4f83297ee03124..50d176418600c4 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -335,7 +335,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() { LLVMConversionTarget target(getContext()); target.addLegalOp(); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp index cc938c8c159400..d0b724b0484c39 100644 --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -41,7 +41,7 @@ void LinalgToSPIRVPass::runOnOperation() { typeConverter.isLegal(&op.getBody()); }); - if (failed(applyFullConversion(module, *target, patterns))) + if (failed(applyFullConversion(module, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index d64e6f9947c76f..5037f50b1f7c32 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -211,7 +211,7 @@ void ConvertLinalgToStandardPass::runOnOperation() { target.addLegalOp(); OwningRewritePatternList patterns; populateLinalgToStandardConversionPatterns(patterns, &getContext()); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 419d4bea30bd36..cfb553da407ce7 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -67,7 +67,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() { [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); }); target.addLegalOp(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp index ac07626abcddf8..d04a773939b78d 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -54,7 +54,8 @@ struct ParallelLoopToGpuPass target.addLegalDialect(); target.addLegalDialect(); target.addIllegalOp(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp index 14f365f95ee5a5..625077a28aaca7 100644 --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -412,7 +412,8 @@ void SCFToStandardPass::runOnOperation() { ConversionTarget target(getContext()); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index a850c9badc8d50..bb5580ac37ff2d 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -290,7 +290,7 @@ class LowerHostCodeToLLVM ConversionTarget target(*context); target.addLegalDialect(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); // Finally, modify the kernel function in SPIR-V modules to avoid symbolic diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp index 329989b9e79552..cea0e76221e527 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp @@ -52,7 +52,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() { // conversion. target.addLegalOp(); target.addLegalOp(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index 7f04d5f4c0b249..a63322522e642f 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -508,7 +508,7 @@ void ConvertShapeToStandardPass::runOnOperation() { // Apply conversion. auto module = getOperation(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 5a6c50f2a54954..2fcc8eb8f6fef1 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3788,7 +3788,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase { populateStdToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(m, target, patterns))) + if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), StringAttr::get(this->dataLayout, m.getContext())); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index 71208c719a191f..aea78523ca01ff 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -40,9 +40,8 @@ void ConvertStandardToSPIRVPass::runOnOperation() { populateStandardToSPIRVPatterns(context, typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); - if (failed(applyPartialConversion(module, *target, patterns))) { + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) return signalPassFailure(); - } } std::unique_ptr> diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 13d74088ee9c9d..75b6ce368e6f4a 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -203,7 +203,8 @@ void SPIRVLegalization::runOnOperation() { OwningRewritePatternList patterns; auto *context = &getContext(); populateStdLegalizationPatternsForSPIRVLowering(context, patterns); - applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); + applyPatternsAndFoldGreedily(getOperation()->getRegions(), + std::move(patterns)); } std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index cab9526b5c7f8e..5e4272e4ea3f50 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1619,7 +1619,7 @@ void LowerVectorToLLVMPass::runOnOperation() { populateVectorToVectorCanonicalizationPatterns(patterns, &getContext()); populateVectorSlicesLoweringPatterns(patterns, &getContext()); populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsAndFoldGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } // Convert to the LLVM IR dialect. @@ -1632,7 +1632,8 @@ void LowerVectorToLLVMPass::runOnOperation() { populateStdToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp index f699ef054d1aeb..f61f896feae265 100644 --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -173,7 +173,8 @@ void LowerVectorToROCDLPass::runOnOperation() { LLVMConversionTarget target(getContext()); target.addLegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 70bacf39710950..e9d05298e1683f 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -710,7 +710,7 @@ struct ConvertVectorToSCFPass auto *context = getFunction().getContext(); populateVectorToSCFConversionPatterns( patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll)); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 05949fb5991048..ad26118dc64322 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -109,7 +109,7 @@ void LowerVectorToSPIRVPass::runOnOperation() { target->addLegalOp(); target->addLegalOp(); - if (failed(applyFullConversion(module, *target, patterns))) + if (failed(applyFullConversion(module, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index 20d3b3e9ccb9bc..1fce7d9519c00d 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -228,6 +228,7 @@ void AffineDataCopyGeneration::runOnFunction() { OwningRewritePatternList patterns; AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); + FrozenRewritePatternList frozenPatterns(std::move(patterns)); for (Operation *op : copyOps) - applyOpPatternsAndFold(op, std::move(patterns)); + applyOpPatternsAndFold(op, frozenPatterns); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index 4d50aa1ba1e5d2..eb6d7b2cb1a8fc 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -83,6 +83,7 @@ void SimplifyAffineStructures::runOnFunction() { AffineForOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext()); + FrozenRewritePatternList frozenPatterns(std::move(patterns)); func.walk([&](Operation *op) { for (auto attr : op->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) @@ -94,6 +95,6 @@ void SimplifyAffineStructures::runOnFunction() { // The simplification of the attribute will likely simplify the op. Try to // fold / apply canonicalization patterns when we have affine dialect ops. if (isa(op)) - applyOpPatternsAndFold(op, patterns); + applyOpPatternsAndFold(op, frozenPatterns); }); } diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 8f8cac9fe57e3c..3b1be4dd5eac44 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -159,7 +159,8 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { OwningRewritePatternList patterns; AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); bool erased; - applyOpPatternsAndFold(ifOp, patterns, &erased); + FrozenRewritePatternList frozenPatterns(std::move(patterns)); + applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); if (erased) { if (folded) *folded = true; @@ -189,7 +190,7 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { // a sequence of affine.fors that are all perfectly nested). applyPatternsAndFoldGreedily( hoistedIfOp.getParentWithTrait(), - std::move(patterns)); + frozenPatterns); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp index 7b679755e2a779..2f70063957b872 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -350,7 +350,8 @@ struct LinalgBufferizePass : public LinalgBufferizeBase { populateWithBufferizeOpConversionPatterns( &context, converter, patterns); - if (failed(applyFullConversion(this->getOperation(), target, patterns))) + if (failed(applyFullConversion(this->getOperation(), target, + std::move(patterns)))) this->signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp index 5cf17855e3560a..bc86dcd9e05011 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -31,7 +31,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { // Emplace patterns one at a time while also maintaining a simple chained // state transition. unsigned stepCount = 0; - SmallVector stage1Patterns; + SmallVector stage1Patterns; auto zeroState = Identifier::get(std::to_string(stepCount), context); auto currentState = zeroState; for (const std::unique_ptr &t : transformationSequence) { @@ -60,7 +60,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { hoistRedundantVectorTransfers(cast(op)); return success(); }; - linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns, + linalg::applyStagedPatterns(func, stage1Patterns, std::move(stage2Patterns), stage3Transforms); //===--------------------------------------------------------------------===// @@ -73,7 +73,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { OwningRewritePatternList patterns; patterns.insert( context, vectorTransformsOptions); - applyPatternsAndFoldGreedily(module, patterns); + applyPatternsAndFoldGreedily(module, std::move(patterns)); // Programmatic controlled lowering of vector.contract only. OwningRewritePatternList vectorContractLoweringPatterns; @@ -81,13 +81,14 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { .insert( vectorTransformsOptions, context); - applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns); + applyPatternsAndFoldGreedily(module, + std::move(vectorContractLoweringPatterns)); // Programmatic controlled lowering of vector.transfer only. OwningRewritePatternList vectorToLoopsPatterns; populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, vectorToSCFOptions); - applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns); + applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns)); // Ensure we drop the marker in the end. module.walk([](LinalgOp op) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 4104da6965f29f..b2c649be9d921a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -518,7 +518,7 @@ struct LinalgFoldUnitExtentDimsPass FoldUnitDimLoops>(context); else populateLinalgFoldUnitExtentDimsPatterns(context, patterns); - applyPatternsAndFoldGreedily(funcOp.getBody(), patterns); + applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 139de8cc46517b..f2a3fb7d7766f5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -913,7 +913,7 @@ struct FusionOfTensorOpsPass OwningRewritePatternList patterns; Operation *op = getOperation(); populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns); - applyPatternsAndFoldGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; @@ -926,7 +926,7 @@ struct FoldReshapeOpsByLinearizationPass OwningRewritePatternList patterns; Operation *op = getOperation(); populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns); - applyPatternsAndFoldGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 368f4f2c66dd2e..6c46dbf07accf1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -593,7 +593,7 @@ static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) { AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert(context); // Just apply the patterns greedily. - applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 0cfffa79f73c94..b4809ac4d7c463 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -595,7 +595,7 @@ static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType, MLIRContext *ctx = funcOp.getContext(); OwningRewritePatternList patterns; insertTilingPatterns(patterns, options, ctx); - applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); applyPatternsAndFoldGreedily(funcOp, getLinalgTilingCanonicalizationPatterns(ctx)); // Drop the marker. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 8638d705d1323e..836cc28e0a47f3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -257,8 +257,8 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( } LogicalResult mlir::linalg::applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const OwningRewritePatternList &stage2Patterns, + Operation *op, ArrayRef stage1Patterns, + const FrozenRewritePatternList &stage2Patterns, function_ref stage3Lambda) { unsigned iteration = 0; (void)iteration; diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp index 0879b73846d7a6..d98116d41cfc9b 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -96,7 +96,7 @@ void ConvertConstPass::runOnFunction() { auto func = getFunction(); auto *context = &getContext(); patterns.insert(context); - applyPatternsAndFoldGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, std::move(patterns)); } std::unique_ptr> mlir::quant::createConvertConstPass() { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp index 055e4759b87b59..2eb1a90556ff97 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -129,7 +129,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { auto ctx = func.getContext(); patterns.insert( ctx, &hadFailure); - applyPatternsAndFoldGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp index 23cf72f6ed2a00..7cf0dfabd91740 100644 --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -30,7 +30,7 @@ struct SCFBufferizePass : public SCFBufferizeBase { populateBufferizeMaterializationLegality(target); populateSCFStructuralTypeConversionsAndLegality(context, typeConverter, patterns, target); - if (failed(applyPartialConversion(func, target, patterns))) + if (failed(applyPartialConversion(func, target, std::move(patterns)))) return signalPassFailure(); }; }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index 5051b54f532fa8..53160427cf39c2 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -107,12 +107,10 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() { // TODO: Change the type for the indirect users such as spv.Load, spv.Store, // spv.FunctionCall and so on. - - for (auto spirvModule : module.getOps()) { - if (failed(applyFullConversion(spirvModule, target, patterns))) { + FrozenRewritePatternList frozenPatterns(std::move(patterns)); + for (auto spirvModule : module.getOps()) + if (failed(applyFullConversion(spirvModule, target, frozenPatterns))) signalPassFailure(); - } - } } std::unique_ptr> diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 000fa9dd2d8fc6..24679e4d523094 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -262,7 +262,7 @@ void LowerABIAttributesPass::runOnOperation() { return op->getDialect()->getNamespace() == spirv::SPIRVDialect::getDialectNamespace(); }); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) return signalPassFailure(); // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp index 20cd960e040fe0..1b51f330be6676 100644 --- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp @@ -26,7 +26,8 @@ struct ShapeBufferizePass : public ShapeBufferizeBase { populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter, patterns, target); - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp index 41c372c67f6a87..4337909e12e9ca 100644 --- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp @@ -49,7 +49,7 @@ class RemoveShapeConstraintsPass OwningRewritePatternList patterns; populateRemoveShapeConstraintsPatterns(patterns, &ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp index ff74ce069e407e..49af5d7ce9a203 100644 --- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp @@ -67,7 +67,8 @@ void ShapeToShapeLowering::runOnFunction() { ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); - if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) + if (failed(mlir::applyPartialConversion(getFunction(), target, + std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index e5b71f0fce7540..a1b1f0a6499210 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -148,8 +148,8 @@ struct StdBufferizePass : public StdBufferizeBase { populateStdBufferizePatterns(context, typeConverter, patterns); target.addIllegalOp(); - - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp index e603c76ad71d20..8513880be2e187 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp @@ -81,7 +81,8 @@ struct ExpandAtomic : public ExpandAtomicBase { return op.kind() != AtomicRMWKind::maxf && op.kind() != AtomicRMWKind::minf; }); - if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) + if (failed(mlir::applyPartialConversion(getFunction(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt index aa6632b63f0929..e37b9c31dab92c 100644 --- a/mlir/lib/Rewrite/CMakeLists.txt +++ b/mlir/lib/Rewrite/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRRewrite + FrozenRewritePatternList.cpp PatternApplicator.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp new file mode 100644 index 00000000000000..d0e45184ac28a0 --- /dev/null +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -0,0 +1,19 @@ +//===- FrozenRewritePatternList.cpp - Frozen Pattern List -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Rewrite/FrozenRewritePatternList.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// FrozenRewritePatternList +//===----------------------------------------------------------------------===// + +FrozenRewritePatternList::FrozenRewritePatternList( + OwningRewritePatternList &&patterns) + : patterns(patterns.takePatterns()) {} diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index f9c0cbfed88019..5d6ae51e8eeba5 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -22,28 +22,28 @@ void PatternApplicator::applyCostModel(CostModel model) { // Separate patterns by root kind to simplify lookup later on. patterns.clear(); anyOpPatterns.clear(); - for (const auto &pat : owningPatternList) { + for (const auto &pat : frozenPatternList.getPatterns()) { // If the pattern is always impossible to match, just ignore it. - if (pat->getBenefit().isImpossibleToMatch()) { + if (pat.getBenefit().isImpossibleToMatch()) { LLVM_DEBUG({ llvm::dbgs() - << "Ignoring pattern '" << pat->getRootKind() + << "Ignoring pattern '" << pat.getRootKind() << "' because it is impossible to match (by pattern benefit)\n"; }); continue; } - if (Optional opName = pat->getRootKind()) - patterns[*opName].push_back(pat.get()); + if (Optional opName = pat.getRootKind()) + patterns[*opName].push_back(&pat); else - anyOpPatterns.push_back(pat.get()); + anyOpPatterns.push_back(&pat); } // Sort the patterns using the provided cost model. - llvm::SmallDenseMap benefits; - auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) { + llvm::SmallDenseMap benefits; + auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { return benefits[lhs] > benefits[rhs]; }; - auto processPatternList = [&](SmallVectorImpl &list) { + auto processPatternList = [&](SmallVectorImpl &list) { // Special case for one pattern in the list, which is the most common case. if (list.size() == 1) { if (model(*list.front()).isImpossibleToMatch()) { @@ -59,7 +59,7 @@ void PatternApplicator::applyCostModel(CostModel model) { // Collect the dynamic benefits for the current pattern list. benefits.clear(); - for (RewritePattern *pat : list) + for (const Pattern *pat : list) benefits.try_emplace(pat, model(*pat)); // Sort patterns with highest benefit first, and remove those that are @@ -81,8 +81,8 @@ void PatternApplicator::applyCostModel(CostModel model) { void PatternApplicator::walkAllPatterns( function_ref walk) { - for (auto &it : owningPatternList) - walk(*it); + for (auto &it : frozenPatternList.getPatterns()) + walk(it); } LogicalResult PatternApplicator::matchAndRewrite( @@ -91,7 +91,7 @@ LogicalResult PatternApplicator::matchAndRewrite( function_ref onFailure, function_ref onSuccess) { // Check to see if there are patterns matching this specific operation type. - MutableArrayRef opPatterns; + MutableArrayRef opPatterns; auto patternIt = patterns.find(op->getName()); if (patternIt != patterns.end()) opPatterns = patternIt->second; @@ -104,7 +104,7 @@ LogicalResult PatternApplicator::matchAndRewrite( auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); while (opIt != opE && anyIt != anyE) { // Try to match the pattern providing the most benefit. - RewritePattern *pattern; + const RewritePattern *pattern; if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit()) pattern = *(opIt++); else @@ -118,7 +118,7 @@ LogicalResult PatternApplicator::matchAndRewrite( // If we break from the loop, then only one of the ranges can still have // elements. Loop over both without checking given that we don't need to // interleave anymore. - for (RewritePattern *pattern : llvm::concat( + for (const RewritePattern *pattern : llvm::concat( llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) { if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, onSuccess))) diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index de435440607f52..70208a89debf25 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -32,7 +32,7 @@ struct Canonicalizer : public CanonicalizerBase { op->getCanonicalizationPatterns(patterns, context); Operation *op = getOperation(); - applyPatternsAndFoldGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index ff104071f444f9..66b29b6dcd4b71 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -503,7 +503,7 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList, /// canonicalization patterns. static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, CallGraphSCC ¤tSCC, MLIRContext *context, - const OwningRewritePatternList &canonPatterns) { + const FrozenRewritePatternList &canonPatterns) { // Collect the sets of nodes to canonicalize. SmallVector nodesToCanonicalize; for (auto *node : currentSCC) { @@ -574,7 +574,7 @@ struct InlinerPass : public InlinerBase { /// the inlining of newly devirtualized calls. void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC ¤tSCC, MLIRContext *context, - const OwningRewritePatternList &canonPatterns); + const FrozenRewritePatternList &canonPatterns); }; } // end anonymous namespace @@ -596,13 +596,14 @@ void InlinerPass::runOnOperation() { OwningRewritePatternList canonPatterns; for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(canonPatterns, context); + FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns)); // Run the inline transform in post-order over the SCCs in the callgraph. SymbolTableCollection symbolTable; Inliner inliner(context, cg, symbolTable); CGUseList useList(getOperation(), cg, symbolTable); runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) { - inlineSCC(inliner, useList, scc, context, canonPatterns); + inlineSCC(inliner, useList, scc, context, frozenCanonPatterns); }); // After inlining, make sure to erase any callables proven to be dead. @@ -611,7 +612,7 @@ void InlinerPass::runOnOperation() { void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC ¤tSCC, MLIRContext *context, - const OwningRewritePatternList &canonPatterns) { + const FrozenRewritePatternList &canonPatterns) { // If we successfully inlined any calls, run some simplifications on the // nodes of the scc. Continue attempting to inline until we reach a fixed // point, or a maximum iteration count. We canonicalize here as it may diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 1d649fd3a02ec7..d1a42fe0720db1 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1459,7 +1459,7 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(ConversionTarget &targetInfo, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -1555,7 +1555,7 @@ class OperationLegalizer { } // namespace OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo, - const OwningRewritePatternList &patterns) + const FrozenRewritePatternList &patterns) : target(targetInfo), applicator(patterns) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. @@ -2078,7 +2078,7 @@ enum OpConversionMode { // conversion mode. struct OperationConverter { explicit OperationConverter(ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, OpConversionMode mode, DenseSet *trackedOps = nullptr) : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} @@ -2672,7 +2672,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const LogicalResult mlir::applyPartialConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Partial, unconvertedOps); @@ -2680,7 +2680,7 @@ mlir::applyPartialConversion(ArrayRef ops, } LogicalResult mlir::applyPartialConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps) { return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, unconvertedOps); @@ -2691,13 +2691,13 @@ mlir::applyPartialConversion(Operation *op, ConversionTarget &target, /// operation fails. LogicalResult mlir::applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { OperationConverter opConverter(target, patterns, OpConversionMode::Full); return opConverter.convertOperations(ops); } LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { return applyFullConversion(llvm::makeArrayRef(op), target, patterns); } @@ -2710,7 +2710,7 @@ mlir::applyFullConversion(Operation *op, ConversionTarget &target, LogicalResult mlir::applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, &convertedOps); @@ -2718,7 +2718,7 @@ mlir::applyAnalysisConversion(ArrayRef ops, } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps) { return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, convertedOps); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 199cfbf3d1f097..bbe3ac57d91c30 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -37,7 +37,7 @@ namespace { class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - const OwningRewritePatternList &patterns) + const FrozenRewritePatternList &patterns) : PatternRewriter(ctx), matcher(patterns), folder(ctx) { worklist.reserve(64); @@ -219,13 +219,13 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, /// LogicalResult mlir::applyPatternsAndFoldGreedily(Operation *op, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { return applyPatternsAndFoldGreedily(op->getRegions(), patterns); } /// Rewrite the given regions, which must be isolated from above. LogicalResult mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { if (regions.empty()) return success(); @@ -259,7 +259,7 @@ namespace { class OpPatternRewriteDriver : public PatternRewriter { public: explicit OpPatternRewriteDriver(MLIRContext *ctx, - const OwningRewritePatternList &patterns) + const FrozenRewritePatternList &patterns) : PatternRewriter(ctx), matcher(patterns), folder(ctx) { // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); @@ -343,7 +343,7 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, /// folding. `erased` is set to true if the op is erased as a result of being /// folded, replaced, or dead. LogicalResult mlir::applyOpPatternsAndFold( - Operation *op, const OwningRewritePatternList &patterns, bool *erased) { + Operation *op, const FrozenRewritePatternList &patterns, bool *erased) { // Start the pattern driver. OpPatternRewriteDriver driver(op->getContext(), patterns); bool opErased; diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp index 03c425d6d9062e..e211daff725a21 100644 --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -144,7 +144,7 @@ void ConvertToTargetEnv::runOnFunction() { ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot>(context); - if (failed(applyPartialConversion(fn, *target, patterns))) + if (failed(applyPartialConversion(fn, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 2099b368eddb5d..5958105090a44f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -83,7 +83,7 @@ struct TestPatternDriver : public PassWrapper { // Verify named pattern is generated with expected name. patterns.insert(&getContext()); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; } // end anonymous namespace @@ -601,7 +601,7 @@ struct TestLegalizePatternDriver // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, patterns, + (void)applyPartialConversion(getOperation(), target, std::move(patterns), &unlegalizedOps); // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) @@ -616,7 +616,7 @@ struct TestLegalizePatternDriver return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, patterns); + (void)applyFullConversion(getOperation(), target, std::move(patterns)); return; } @@ -625,8 +625,8 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet legalizedOps; - if (failed(applyAnalysisConversion(getOperation(), target, patterns, - legalizedOps))) + if (failed(applyAnalysisConversion(getOperation(), target, + std::move(patterns), legalizedOps))) return signalPassFailure(); // Emit remarks for each legalizable operation. @@ -704,7 +704,8 @@ struct TestRemappedValue return std::distance(op->operand_begin(), op->operand_end()) > 1; }); - if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { + if (failed(mlir::applyFullConversion(getFunction(), target, + std::move(patterns)))) { signalPassFailure(); } } @@ -737,7 +738,8 @@ struct TestUnknownRootOpDriver mlir::ConversionTarget target(getContext()); target.addIllegalDialect(); - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; @@ -833,7 +835,8 @@ struct TestTypeConversionDriver mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; @@ -939,7 +942,7 @@ struct TestMergeBlocksPatternDriver }); DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, patterns, + (void)applyPartialConversion(getOperation(), target, std::move(patterns), &unlegalizedOps); for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp index b78884a5479d6a..dbd05aee7ac1b1 100644 --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -32,7 +32,7 @@ OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold( namespace { struct TestTraitFolder : public PassWrapper { void runOnFunction() override { - applyPatternsAndFoldGreedily(getFunction(), {}); + applyPatternsAndFoldGreedily(getFunction(), OwningRewritePatternList()); } }; } // end anonymous namespace diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp index f4931822c65b3c..9592ac3ba48a85 100644 --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -232,7 +232,8 @@ struct TestBufferPlacementPreparationPass OwningRewritePatternList patterns; populateTensorLinalgToBufferLinalgConversionPattern(&context, converter, patterns); - if (failed(applyFullConversion(this->getOperation(), target, patterns))) + if (failed(applyFullConversion(this->getOperation(), target, + std::move(patterns)))) this->signalPassFailure(); }; }; diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp index 5c16677c884991..e5aa49ec931dcf 100644 --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -60,6 +60,8 @@ void TestConvVectorization::runOnOperation() { SmallVector stage1Patterns; linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); + SmallVector frozenStage1Patterns; + llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); OwningRewritePatternList stage2Patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); @@ -78,8 +80,8 @@ void TestConvVectorization::runOnOperation() { return success(); }; - linalg::applyStagedPatterns(module, stage1Patterns, stage2Patterns, - stage3Transforms); + linalg::applyStagedPatterns(module, frozenStage1Patterns, + std::move(stage2Patterns), stage3Transforms); //===--------------------------------------------------------------------===// // Post staged patterns transforms diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp index 67c45f1c28b7d8..7beba01ca9aebc 100644 --- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp +++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp @@ -58,9 +58,8 @@ class TestConvertCallOp target.addIllegalDialect(); target.addIllegalDialect(); - if (failed(applyPartialConversion(m, target, patterns))) { + if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); - } } }; diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp index ab485e58e895c4..c33625bf24af02 100644 --- a/mlir/test/lib/Transforms/TestExpandTanh.cpp +++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp @@ -26,7 +26,7 @@ struct TestExpandTanhPass void TestExpandTanhPass::runOnFunction() { OwningRewritePatternList patterns; populateExpandTanhPattern(patterns, &getContext()); - applyPatternsAndFoldGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Transforms/TestGpuRewrite.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp index eaa7149fa99409..a7155d4ff3ebc8 100644 --- a/mlir/test/lib/Transforms/TestGpuRewrite.cpp +++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp @@ -26,7 +26,7 @@ struct TestGpuRewritePass void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); - applyPatternsAndFoldGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp index 0fa4e22ebf37ba..4dfb653ac85852 100644 --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -97,7 +97,7 @@ static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) { LinalgDependenceGraph dependenceGraph = LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); fillFusionPatterns(context, dependenceGraph, fusionPatterns); - applyPatternsAndFoldGreedily(funcOp, fusionPatterns); + applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); } void TestLinalgFusionTransforms::runOnFunction() { diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index cc8c09d89a03ab..a861b193602cce 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -208,7 +208,7 @@ static void applyPatterns(FuncOp funcOp) { LinalgMarker(Identifier::get("_promote_views_aligned_", ctx), Identifier::get("_views_aligned_promoted_", ctx))); - applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the marker. funcOp.walk([](LinalgOp op) { @@ -431,16 +431,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp, fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), stage1Patterns); } - OwningRewritePatternList stage2Patterns = + SmallVector frozenStage1Patterns; + llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); + FrozenRewritePatternList stage2Patterns = getLinalgTilingCanonicalizationPatterns(ctx); - applyStagedPatterns(funcOp, stage1Patterns, stage2Patterns); + applyStagedPatterns(funcOp, frozenStage1Patterns, std::move(stage2Patterns)); } static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { OwningRewritePatternList forwardPattern; forwardPattern.insert(funcOp.getContext()); forwardPattern.insert(funcOp.getContext()); - applyPatternsAndFoldGreedily(funcOp, forwardPattern); + applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); } static void applyContractionToVectorPatterns(FuncOp funcOp) { @@ -451,16 +453,18 @@ static void applyContractionToVectorPatterns(FuncOp funcOp) { LinalgVectorizationPattern, LinalgVectorizationPattern, LinalgVectorizationPattern>(funcOp.getContext()); - applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) { OwningRewritePatternList foldPattern; foldPattern.insert(funcOp.getContext()); + FrozenRewritePatternList frozenPatterns(std::move(foldPattern)); + // Explicitly walk and apply the pattern locally to avoid more general folding // on the rest of the IR. - funcOp.walk([&foldPattern](AffineMinOp minOp) { - applyOpPatternsAndFold(minOp, foldPattern); + funcOp.walk([&frozenPatterns](AffineMinOp minOp) { + applyOpPatternsAndFold(minOp, frozenPatterns); }); } /// Apply transformations specified as patterns. @@ -475,13 +479,13 @@ void TestLinalgTransforms::runOnFunction() { if (testPromotionOptions) { OwningRewritePatternList patterns; fillPromotionCallBackPatterns(&getContext(), patterns); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testTileAndDistributionOptions) { OwningRewritePatternList patterns; fillTileAndDistributePatterns(&getContext(), patterns); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testPatterns) diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 989e8cda34f96d..20903b30648069 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -32,7 +32,7 @@ struct TestVectorToVectorConversion ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2, 2})); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -41,7 +41,7 @@ struct TestVectorSlicesConversion void runOnFunction() override { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -78,7 +78,7 @@ struct TestVectorContractionConversion VectorTransformsOptions options{lowering}; patterns.insert(options, &getContext()); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } @@ -94,7 +94,7 @@ struct TestVectorContractionConversion return failure(); return success(); }); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } @@ -108,7 +108,7 @@ struct TestVectorContractionConversion transposeLowering = VectorTransposeLowering::Flat; VectorTransformsOptions options{contractLowering, transposeLowering}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -145,7 +145,7 @@ struct TestVectorUnrollingPatterns } populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } Option unrollBasedOnType{ @@ -181,7 +181,7 @@ struct TestVectorDistributePatterns }); patterns.insert(ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -199,7 +199,7 @@ struct TestVectorTransferUnrollingPatterns ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -228,7 +228,7 @@ struct TestVectorTransferFullPartialSplitPatterns else options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); patterns.insert(ctx, options); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };