From ade7071aa0138d9991f698c701ce7cad0dcfb084 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 19 Sep 2025 10:39:14 +0000 Subject: [PATCH] [mlir] Add optional PatternBenefit to function / SCF conversion patterns --- .../include/mlir/Dialect/SCF/Transforms/Patterns.h | 5 +++-- mlir/include/mlir/Transforms/DialectConversion.h | 12 +++++++----- .../SCF/Transforms/StructuralTypeConversions.cpp | 9 +++++---- mlir/lib/Transforms/Utils/DialectConversion.cpp | 14 ++++++++------ 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h index 00c8a5c0c517b..cfe68f61a7a42 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h @@ -51,12 +51,13 @@ class ForLoopPipeliningPattern : public OpRewritePattern { /// TypeConverter, but otherwise don't care what type conversions are happening. void populateSCFStructuralTypeConversionsAndLegality( const TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target); + ConversionTarget &target, PatternBenefit benefit = 1); /// Similar to `populateSCFStructuralTypeConversionsAndLegality` but does not /// populate the conversion target. void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Updates the ConversionTarget with dynamic legality of SCF operations based /// on the provided type converter. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 6ef649e8fc13a..ed7e2a08ebfd9 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -811,17 +811,19 @@ convertOpResultTypes(Operation *op, ValueRange operands, /// ops which use FunctionType to represent their type. void populateFunctionOpInterfaceTypeConversionPattern( StringRef functionLikeOpName, RewritePatternSet &patterns, - const TypeConverter &converter); + const TypeConverter &converter, PatternBenefit benefit = 1); template void populateFunctionOpInterfaceTypeConversionPattern( - RewritePatternSet &patterns, const TypeConverter &converter) { - populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(), - patterns, converter); + RewritePatternSet &patterns, const TypeConverter &converter, + PatternBenefit benefit = 1) { + populateFunctionOpInterfaceTypeConversionPattern( + FuncOpT::getOperationName(), patterns, converter, benefit); } void populateAnyFunctionOpInterfaceTypeConversionPattern( - RewritePatternSet &patterns, const TypeConverter &converter); + RewritePatternSet &patterns, const TypeConverter &converter, + PatternBenefit benefit = 1); //===----------------------------------------------------------------------===// // Conversion PatternRewriter diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 072bc501aa5c6..b0c781c7aff11 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -217,10 +217,11 @@ class ConvertConditionOpTypes : public OpConversionPattern { } // namespace void mlir::scf::populateSCFStructuralTypeConversions( - const TypeConverter &typeConverter, RewritePatternSet &patterns) { + const TypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { patterns.add( - typeConverter, patterns.getContext()); + typeConverter, patterns.getContext(), benefit); } void mlir::scf::populateSCFStructuralTypeConversionTarget( @@ -240,7 +241,7 @@ void mlir::scf::populateSCFStructuralTypeConversionTarget( void mlir::scf::populateSCFStructuralTypeConversionsAndLegality( const TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { - populateSCFStructuralTypeConversions(typeConverter, patterns); + ConversionTarget &target, PatternBenefit benefit) { + populateSCFStructuralTypeConversions(typeConverter, patterns, benefit); populateSCFStructuralTypeConversionTarget(typeConverter, target); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f7565cfb0e45e..ff1e31536cea7 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3802,8 +3802,9 @@ namespace { struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, MLIRContext *ctx, - const TypeConverter &converter) - : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} + const TypeConverter &converter, + PatternBenefit benefit) + : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef /*operands*/, @@ -3848,15 +3849,16 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands, void mlir::populateFunctionOpInterfaceTypeConversionPattern( StringRef functionLikeOpName, RewritePatternSet &patterns, - const TypeConverter &converter) { + const TypeConverter &converter, PatternBenefit benefit) { patterns.add( - functionLikeOpName, patterns.getContext(), converter); + functionLikeOpName, patterns.getContext(), converter, benefit); } void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern( - RewritePatternSet &patterns, const TypeConverter &converter) { + RewritePatternSet &patterns, const TypeConverter &converter, + PatternBenefit benefit) { patterns.add( - converter, patterns.getContext()); + converter, patterns.getContext(), benefit); } //===----------------------------------------------------------------------===//