diff --git a/mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h index 642f704c14430..fa078140eb515 100644 --- a/mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_FUNC_TRANSFORMS_FUNCCONVERSIONS_H_ #define MLIR_DIALECT_FUNC_TRANSFORMS_FUNCCONVERSIONS_H_ +#include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" @@ -29,7 +30,8 @@ class RewritePatternSet; /// Add a pattern to the given pattern list to convert the operand and result /// types of a CallOp with the given type converter. void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, - const TypeConverter &converter); + const TypeConverter &converter, + PatternBenefit benefit = 1); /// Add a pattern to the given pattern list to rewrite branch operations to use /// operands that have been legalized by the conversion framework. This can only @@ -43,7 +45,8 @@ void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, void populateBranchOpInterfaceTypeConversionPattern( RewritePatternSet &patterns, const TypeConverter &converter, function_ref - shouldConvertBranchOperand = nullptr); + shouldConvertBranchOperand = nullptr, + PatternBenefit benefit = 1); /// Return true if op is a BranchOpInterface op whose operands are all legal /// according to converter. @@ -53,7 +56,8 @@ bool isLegalForBranchOpInterfaceTypeConversionPattern( /// Add a pattern to the given pattern list to rewrite `return` ops to use /// operands that have been legalized by the conversion framework. void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, - const TypeConverter &converter); + const TypeConverter &converter, + PatternBenefit benefit = 1); /// For ReturnLike ops (except `return`), return True. If op is a `return` && /// returnOpAlwaysLegal is false, legalize op according to converter. Otherwise, diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp index b6c8cdf2f495a..216401a80c9f8 100644 --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -65,8 +65,10 @@ struct CallOpSignatureConversion : public OpConversionPattern { } // namespace void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns, - const TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + const TypeConverter &converter, + PatternBenefit benefit) { + patterns.add(converter, patterns.getContext(), + benefit); } namespace { @@ -81,8 +83,9 @@ class BranchOpInterfaceTypeConversion BranchOpInterfaceTypeConversion( const TypeConverter &typeConverter, MLIRContext *ctx, - function_ref shouldConvertBranchOperand) - : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), + function_ref shouldConvertBranchOperand, + PatternBenefit benefit) + : OpInterfaceConversionPattern(typeConverter, ctx, benefit), shouldConvertBranchOperand(shouldConvertBranchOperand) {} LogicalResult @@ -135,9 +138,11 @@ class ReturnOpTypeConversion : public OpConversionPattern { void mlir::populateBranchOpInterfaceTypeConversionPattern( RewritePatternSet &patterns, const TypeConverter &typeConverter, - function_ref shouldConvertBranchOperand) { + function_ref shouldConvertBranchOperand, + PatternBenefit benefit) { patterns.add( - typeConverter, patterns.getContext(), shouldConvertBranchOperand); + typeConverter, patterns.getContext(), shouldConvertBranchOperand, + benefit); } bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( @@ -157,8 +162,10 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( } void mlir::populateReturnOpTypeConversionPattern( - RewritePatternSet &patterns, const TypeConverter &typeConverter) { - patterns.add(typeConverter, patterns.getContext()); + RewritePatternSet &patterns, const TypeConverter &typeConverter, + PatternBenefit benefit) { + patterns.add(typeConverter, patterns.getContext(), + benefit); } bool mlir::isLegalForReturnOpTypeConversionPattern(