diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 5b605c165be60c..bf41f29749de3c 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -571,9 +571,11 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { &typeConverter.getContext(), typeConverter, benefit) {} -private: - /// Wrappers around the ConversionPattern methods that pass the derived op - /// type. + /// Wrappers around the RewritePattern methods that pass the derived op type. + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast(op), operands, rewriter); + } LogicalResult match(Operation *op) const final { return match(cast(op)); } @@ -582,10 +584,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), operands, rewriter); } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); - } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -605,6 +603,10 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { } return failure(); } + +private: + using ConvertToLLVMPattern::match; + using ConvertToLLVMPattern::matchAndRewrite; }; namespace LLVM { @@ -634,7 +636,6 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = OneToOneConvertToLLVMPattern; -private: /// Converts the type of the result to an LLVM type, pass operands as is, /// preserve attributes. LogicalResult @@ -654,7 +655,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern; -private: LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 1739cfa4a80c3e..d97b328cdc01a2 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -156,6 +156,17 @@ class RewritePattern : public Pattern { public: virtual ~RewritePattern() {} + /// Rewrite the IR rooted at the specified operation with the result of + /// this pattern, generating any new operations with the specified + /// builder. If an unexpected error is encountered (an internal + /// compiler error), it is emitted through the normal MLIR diagnostic + /// hooks and the IR is left in a valid state. + virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; + + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). + virtual LogicalResult match(Operation *op) const; + /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). If successful, this /// function will automatically perform the rewrite. @@ -172,18 +183,6 @@ class RewritePattern : public Pattern { /// Inherit the base constructors from `Pattern`. using Pattern::Pattern; - /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). - virtual LogicalResult match(Operation *op) const; - -private: - /// Rewrite the IR rooted at the specified operation with the result of - /// this pattern, generating any new operations with the specified - /// builder. If an unexpected error is encountered (an internal - /// compiler error), it is emitted through the normal MLIR diagnostic - /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; - /// An anchor for the virtual table. virtual void anchor(); }; @@ -192,14 +191,12 @@ class RewritePattern : public Pattern { /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. template -class OpRewritePattern : public RewritePattern { -public: +struct OpRewritePattern : public RewritePattern { /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching. OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) : RewritePattern(SourceOp::getOperationName(), benefit, context) {} -private: /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, PatternRewriter &rewriter) const final { rewrite(cast(op), rewriter); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index ecbb653f7ed9f8..e02cf8fe4c0a28 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -313,30 +313,6 @@ class TypeConverter { /// patterns of this type can only be used with the 'apply*' methods below. class ConversionPattern : public RewritePattern { public: - /// Return the type converter held by this pattern, or nullptr if the pattern - /// does not require type conversion. - TypeConverter *getTypeConverter() const { return typeConverter; } - -protected: - /// See `RewritePattern::RewritePattern` for information on the other - /// available constructors. - using RewritePattern::RewritePattern; - /// Construct a conversion pattern that matches an operation with the given - /// root name. This constructor allows for providing a type converter to use - /// within the pattern. - ConversionPattern(StringRef rootName, PatternBenefit benefit, - TypeConverter &typeConverter, MLIRContext *ctx) - : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} - /// Construct a conversion pattern that matches any operation type. This - /// constructor allows for providing a type converter to use within the - /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" - /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should - /// always be supplied here. - ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, - MatchAnyOpTypeTag tag) - : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} - -private: /// Hook for derived classes to implement rewriting. `op` is the (first) /// operation matched by the pattern, `operands` is a list of the rewritten /// operand values that are passed to `op`, `rewriter` can be used to emit the @@ -347,10 +323,6 @@ class ConversionPattern : public RewritePattern { llvm_unreachable("unimplemented rewrite"); } - void rewrite(Operation *op, PatternRewriter &rewriter) const final { - llvm_unreachable("never called"); - } - /// Hook for derived classes to implement combined matching and rewriting. virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -365,17 +337,42 @@ class ConversionPattern : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final; + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + TypeConverter *getTypeConverter() const { return typeConverter; } + +protected: + /// See `RewritePattern::RewritePattern` for information on the other + /// available constructors. + using RewritePattern::RewritePattern; + /// Construct a conversion pattern that matches an operation with the given + /// root name. This constructor allows for providing a type converter to use + /// within the pattern. + ConversionPattern(StringRef rootName, PatternBenefit benefit, + TypeConverter &typeConverter, MLIRContext *ctx) + : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} + /// Construct a conversion pattern that matches any operation type. This + /// constructor allows for providing a type converter to use within the + /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. + ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, + MatchAnyOpTypeTag tag) + : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} + protected: /// An optional type converter for use by this pattern. TypeConverter *typeConverter = nullptr; + +private: + using RewritePattern::rewrite; }; /// OpConversionPattern is a wrapper around ConversionPattern that allows for /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. template -class OpConversionPattern : public ConversionPattern { -public: +struct OpConversionPattern : public ConversionPattern { OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -383,7 +380,6 @@ class OpConversionPattern : public ConversionPattern { : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter, context) {} -private: /// Wrappers around the ConversionPattern methods that pass the derived op /// type. void rewrite(Operation *op, ArrayRef operands, @@ -413,6 +409,9 @@ class OpConversionPattern : public ConversionPattern { rewrite(op, operands, rewriter); return success(); } + +private: + using ConversionPattern::matchAndRewrite; }; /// Add a pattern to the given pattern list to convert the signature of a FuncOp