diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 576481a6e7215..35f7290a235c2 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -675,9 +675,9 @@ class RewriterBase : public OpBuilder { /// true. Also notify the listener about every in-place op modification (for /// every use that was replaced). The optional `allUsesReplaced` flag is set /// to "true" if all uses were replaced. - void replaceUsesWithIf(Value from, Value to, - function_ref functor, - bool *allUsesReplaced = nullptr); + virtual void replaceUsesWithIf(Value from, Value to, + function_ref functor, + bool *allUsesReplaced = nullptr); void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref functor, bool *allUsesReplaced = nullptr); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5ac9e26e8636d..9f449080b0f37 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -903,6 +903,27 @@ class ConversionPatternRewriter final : public PatternRewriter { replaceAllUsesWith(from, ValueRange{to}); } + /// Replace the uses of `from` with `to` for which the `functor` returns + /// "true". The conversion driver will try to reconcile all type mismatches + /// that still exist at the end of the conversion with materializations. + /// This function supports both 1:1 and 1:N replacements. + /// + /// Note: The functor is also applied to builtin.unrealized_conversion_cast + /// ops that may have been inserted by the conversion driver. Some uses may + /// have been wrapped in unrealized_conversion_cast ops due to type changes. + /// + /// Note: This function is not supported in rollback mode. Calling it in + /// rollback mode will trigger an assertion. Furthermore, the + /// `allUsesReplaced` flag is not supported yet. + void replaceUsesWithIf(Value from, Value to, + function_ref functor, + bool *allUsesReplaced = nullptr) override { + replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced); + } + void replaceUsesWithIf(Value from, ValueRange to, + function_ref functor, + bool *allUsesReplaced = nullptr); + /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case /// of failure, the remapped value otherwise. diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 2220f61ed8a07..ddd94f5d03042 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -283,8 +283,16 @@ static void restoreByValRefArgumentType( Type resTy = typeConverter.convertType( cast(byValRefAttr->getValue()).getValue()); - Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); - rewriter.replaceAllUsesWith(arg, valueArg); + auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); + if (!rewriter.getConfig().allowPatternRollback) { + rewriter.replaceAllUsesExcept(arg, loadOp, loadOp); + } else { + // replaceAllUsesExcept is not supported in rollback mode. The rollback + // mode implementation has a workaround: certain replacements that would + // cause a dominance violation are skipped. + // TODO: Remove workaround. + rewriter.replaceAllUsesWith(arg, loadOp); + } } } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 27e3ec6f64c8f..ccc5b7cb6f229 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -976,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void replaceOp(Operation *op, SmallVector> &&newValues); /// Replace the uses of the given value with the given values. The specified - /// converter is used to build materializations (if necessary). - void replaceAllUsesWith(Value from, ValueRange to, - const TypeConverter *converter); + /// converter is used to build materializations (if necessary). If `functor` + /// is specified, only the uses that the functor returns "true" for are + /// replaced. + void replaceValueUses(Value from, ValueRange to, + const TypeConverter *converter, + function_ref functor = nullptr); /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1202,12 +1205,14 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -/// Replace all uses of `from` with `repl`. -static void performReplaceValue(RewriterBase &rewriter, Value from, - Value repl) { +void ReplaceValueRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter); + if (!repl) + return; + if (isa(repl)) { // `repl` is a block argument. Directly replace all uses. - rewriter.replaceAllUsesWith(from, repl); + rewriter.replaceAllUsesWith(value, repl); return; } @@ -1236,19 +1241,14 @@ static void performReplaceValue(RewriterBase &rewriter, Value from, // `ConversionPatternRewriter` API with the normal `RewriterBase` API. Operation *replOp = repl.getDefiningOp(); Block *replBlock = replOp->getBlock(); - rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { + rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + bool result = + user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + return result; }); } -void ReplaceValueRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter); - if (!repl) - return; - performReplaceValue(rewriter, value, repl); -} - void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); #ifndef NDEBUG @@ -1646,7 +1646,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, /*isPureTypeConversion=*/false) .front(); - replaceAllUsesWith(origArg, mat, converter); + replaceValueUses(origArg, mat, converter); continue; } @@ -1655,14 +1655,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - replaceAllUsesWith(origArg, inputMap->replacementValues, converter); + replaceValueUses(origArg, inputMap->replacementValues, converter); continue; } // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - replaceAllUsesWith(origArg, replArgs, converter); + replaceValueUses(origArg, replArgs, converter); } if (config.allowPatternRollback) @@ -1962,8 +1962,24 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::replaceAllUsesWith( - Value from, ValueRange to, const TypeConverter *converter) { +void ConversionPatternRewriterImpl::replaceValueUses( + Value from, ValueRange to, const TypeConverter *converter, + function_ref functor) { + LLVM_DEBUG({ + logger.startLine() << "** Replace Value : '" << from << "'"; + if (auto blockArg = dyn_cast(from)) { + if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { + logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")"; + } else { + logger.getOStream() << " (unlinked block)"; + } + } + if (functor) { + logger.getOStream() << ", conditional replacement"; + } + }); + if (!config.allowPatternRollback) { SmallVector toConv = llvm::to_vector(to); SmallVector repls = @@ -1972,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( Value repl = repls.front(); if (!repl) return; - - performReplaceValue(r, from, repl); + if (functor) { + r.replaceUsesWithIf(from, repl, functor); + } else { + r.replaceAllUsesWith(from, repl); + } return; } @@ -1992,6 +2011,8 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( replacedValues.insert(from); #endif // NDEBUG + assert(!functor && + "conditional value replacement is not supported in rollback mode"); mapping.map(from, to); appendRewrite(from, converter); } @@ -2190,18 +2211,15 @@ FailureOr ConversionPatternRewriter::convertRegionTypes( } void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) { - LLVM_DEBUG({ - impl->logger.startLine() << "** Replace Value : '" << from << "'"; - if (auto blockArg = dyn_cast(from)) { - if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { - impl->logger.getOStream() << " (in region of '" << parentOp->getName() - << "' (" << parentOp << ")\n"; - } else { - impl->logger.getOStream() << " (unlinked block)\n"; - } - } - }); - impl->replaceAllUsesWith(from, to, impl->currentTypeConverter); + impl->replaceValueUses(from, to, impl->currentTypeConverter); +} + +void ConversionPatternRewriter::replaceUsesWithIf( + Value from, ValueRange to, function_ref functor, + bool *allUsesReplaced) { + assert(!allUsesReplaced && + "allUsesReplaced is not supported in a dialect conversion"); + impl->replaceValueUses(from, to, impl->currentTypeConverter, functor); } Value ConversionPatternRewriter::getRemappedValue(Value key) { diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir index 180f16a32991b..14c15ecbe77f0 100644 --- a/mlir/test/Transforms/test-convert-func-op.mlir +++ b/mlir/test/Transforms/test-convert-func-op.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" --split-input-file | FileCheck %s // CHECK-LABEL: llvm.func @add func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } { diff --git a/mlir/test/Transforms/test-legalizer-no-rollback.mlir b/mlir/test/Transforms/test-legalizer-no-rollback.mlir new file mode 100644 index 0000000000000..5f421a35d956b --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-rollback.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @conditional_replacement( +// CHECK-SAME: %[[arg0:.*]]: i43) +// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42 +// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42 +// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42 +// Uses were replaced for dummy_user_1. +// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> () +// Uses were also replaced for dummy_user_2, but not by value_replace. The uses +// were replaced due to the block signature conversion. +// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> () +// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> () +func.func @conditional_replacement(%arg0: i42) { + %repl = "test.legal_op"() : () -> (i42) + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> () + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_2"(%arg0) {} : (i42) -> () + // Perform a conditional 1:N replacement. + "test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp index 75168dde93130..897b11b65b6f2 100644 --- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp +++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp @@ -68,6 +68,9 @@ struct TestConvertFuncOp : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp) + TestConvertFuncOp() = default; + TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {} + void getDependentDialects(DialectRegistry ®istry) const final { registry.insert(); } @@ -92,10 +95,16 @@ struct TestConvertFuncOp patterns.add(typeConverter); LLVMConversionTarget target(getContext()); + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + std::move(patterns), config))) signalPassFailure(); } + + Option allowPatternRollback{*this, "allow-pattern-rollback", + llvm::cl::desc("Allow pattern rollback"), + llvm::cl::init(true)}; }; } // namespace diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 9b64bc691588d..7eabaaeb41500 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern { // Replace the first operand with 2x the second operand. Value from = op->getOperand(0); Value repl = op->getOperand(1); - rewriter.replaceAllUsesWith(from, {repl, repl}); + if (op->hasAttr("conditional")) { + rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) { + return use.getOwner()->hasAttr("replace_uses"); + }); + } else { + rewriter.replaceAllUsesWith(from, {repl, repl}); + } rewriter.modifyOpInPlace(op, [&] { // If the "trigger_rollback" attribute is set, keep the op illegal, so // that a rollback is triggered.