diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h index b7fa8fc3848f2..7d816a8843371 100644 --- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h +++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h @@ -237,9 +237,7 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { virtual llvm::LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - llvm::SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + return dispatchTo1To1(*this, op, adaptor, rewriter); } private: diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 969154abe8830..19148a5d783f3 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -233,9 +233,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + return dispatchTo1To1(*this, op, adaptor, rewriter); } private: @@ -276,7 +274,7 @@ class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern { virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + return dispatchTo1To1(*this, op, operands, rewriter); } private: diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index e601c821e1e4e..220431e6ee2f1 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -521,8 +521,8 @@ class ConversionPattern : public RewritePattern { /// Hook for derived classes to implement combined matching and rewriting. /// This overload supports only 1:1 replacements. The 1:N overload is called - /// by the driver. By default, it calls this 1:1 overload or reports a fatal - /// error if 1:N replacements were found. + /// by the driver. By default, it calls this 1:1 overload or fails to match + /// if 1:N replacements were found. virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { @@ -534,7 +534,7 @@ class ConversionPattern : public RewritePattern { virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + return dispatchTo1To1(*this, op, operands, rewriter); } /// Attempt to match and rewrite the IR root at the specified operation. @@ -567,11 +567,26 @@ class ConversionPattern : public RewritePattern { /// try to extract the single value of each range to construct a the inputs /// for a 1:1 adaptor. /// - /// This function produces a fatal error if at least one range has 0 or - /// more than 1 value: "pattern 'name' does not support 1:N conversion" - SmallVector + /// Returns failure if at least one range has 0 or more than 1 value. + FailureOr> getOneToOneAdaptorOperands(ArrayRef operands) const; + /// Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method + /// if possible and emit diagnostic with a failure return value otherwise. + /// 'self' should be '*this' of the derived-pattern and is used to dispatch + /// to the correct 'matchAndRewrite' method in the derived pattern. + template + static LogicalResult dispatchTo1To1(const SelfPattern &self, SourceOp op, + ArrayRef operands, + ConversionPatternRewriter &rewriter); + + /// Same as above, but accepts an adaptor as operand. + template + static LogicalResult dispatchTo1To1( + const SelfPattern &self, SourceOp op, + typename SourceOp::template GenericAdaptor> adaptor, + ConversionPatternRewriter &rewriter); + protected: /// An optional type converter for use by this pattern. const TypeConverter *typeConverter = nullptr; @@ -620,9 +635,7 @@ class OpConversionPattern : public ConversionPattern { virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + return dispatchTo1To1(*this, op, adaptor, rewriter); } private: @@ -666,7 +679,7 @@ class OpInterfaceConversionPattern : public ConversionPattern { virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + return dispatchTo1To1(*this, op, operands, rewriter); } private: @@ -865,6 +878,35 @@ class ConversionPatternRewriter final : public PatternRewriter { std::unique_ptr impl; }; +template +LogicalResult +ConversionPattern::dispatchTo1To1(const SelfPattern &self, SourceOp op, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + FailureOr> oneToOneOperands = + self.getOneToOneAdaptorOperands(operands); + if (failed(oneToOneOperands)) + return rewriter.notifyMatchFailure(op, + "pattern '" + self.getDebugName() + + "' does not support 1:N conversion"); + return self.matchAndRewrite(op, *oneToOneOperands, rewriter); +} + +template +LogicalResult ConversionPattern::dispatchTo1To1( + const SelfPattern &self, SourceOp op, + typename SourceOp::template GenericAdaptor> adaptor, + ConversionPatternRewriter &rewriter) { + FailureOr> oneToOneOperands = + self.getOneToOneAdaptorOperands(adaptor.getOperands()); + if (failed(oneToOneOperands)) + return rewriter.notifyMatchFailure(op, + "pattern '" + self.getDebugName() + + "' does not support 1:N conversion"); + return self.matchAndRewrite( + op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter); +} + //===----------------------------------------------------------------------===// // ConversionTarget //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 001c13e1ab08c..ff34a58965763 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2244,17 +2244,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// -SmallVector ConversionPattern::getOneToOneAdaptorOperands( +FailureOr> ConversionPattern::getOneToOneAdaptorOperands( ArrayRef operands) const { SmallVector oneToOneOperands; oneToOneOperands.reserve(operands.size()); for (ValueRange operand : operands) { if (operand.size() != 1) - llvm::report_fatal_error("pattern '" + getDebugName() + - "' does not support 1:N conversion"); + return failure(); + oneToOneOperands.push_back(operand.front()); } - return oneToOneOperands; + return std::move(oneToOneOperands); } LogicalResult diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 9a04da7904863..55d153db7f4bb 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -439,3 +439,24 @@ func.func @test_lookup_without_converter() { // expected-remark@+1 {{op 'func.return' is not legalizable}} return } + +// ----- +// expected-remark@-1 {{applyPartialConversion failed}} + +func.func @test_skip_1to1_pattern(%arg0: f32) { + // expected-error@+1 {{failed to legalize operation 'test.type_consumer'}} + "test.type_consumer"(%arg0) : (f32) -> () + return +} + +// ----- + +// Demonstrate that the pattern generally works, but only for 1:1 type +// conversions. + +// CHECK-LABEL: @test_working_1to1_pattern( +func.func @test_working_1to1_pattern(%arg0: f16) { + // CHECK-NEXT: "test.return"() : () -> () + "test.type_consumer"(%arg0) : (f16) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 657dfd2bac6ec..6300c5b0ca21c 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1386,6 +1386,23 @@ class TestMultiple1ToNReplacement : public ConversionPattern { } }; +/// Pattern that erases 'test.type_consumers' iff the input operand is the +/// result of a 1:1 type conversion. +/// Used to test correct skipping of 1:1 patterns in the 1:N case. +class TestTypeConsumerOpPattern + : public OpConversionPattern { +public: + TestTypeConsumerOpPattern(MLIRContext *ctx, const TypeConverter &converter) + : OpConversionPattern(converter, ctx) {} + + LogicalResult + matchAndRewrite(TestTypeConsumerOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.eraseOp(op); + return success(); + } +}; + /// Test unambiguous overload resolution of replaceOpWithMultiple. This /// function is just to trigger compiler errors. It is never executed. [[maybe_unused]] void testReplaceOpWithMultipleOverloads( @@ -1497,8 +1514,8 @@ struct TestLegalizePatternDriver TestRepetitive1ToNConsumer>(&getContext()); patterns.add( - &getContext(), converter); + TestBlockArgReplace, TestReplaceWithValidConsumer, + TestTypeConsumerOpPattern>(&getContext(), converter); patterns.add(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter);