diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e3500b3f9446d..3e11e00b9d4b4 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -409,9 +409,9 @@ class RewriterBase : public OpBuilder { /// Notify the listener that the specified operation was modified in-place. virtual void notifyOperationModified(Operation *op) {} - /// Notify the listener that the specified operation is about to be replaced - /// with another operation. This is called before the uses of the old - /// operation have been changed. + /// Notify the listener that all uses of the specified operation's results + /// are about to be replaced with the results of another operation. This is + /// called before the uses of the old operation have been changed. /// /// By default, this function calls the "operation replaced with values" /// notification. @@ -420,9 +420,10 @@ class RewriterBase : public OpBuilder { notifyOperationReplaced(op, replacement->getResults()); } - /// Notify the listener that the specified operation is about to be replaced - /// with the a range of values, potentially produced by other operations. - /// This is called before the uses of the operation have been changed. + /// Notify the listener that all uses of the specified operation's results + /// are about to be replaced with the a range of values, potentially + /// produced by other operations. This is called before the uses of the + /// operation have been changed. virtual void notifyOperationReplaced(Operation *op, ValueRange replacement) {} @@ -613,12 +614,14 @@ class RewriterBase : public OpBuilder { /// Find uses of `from` and replace them with `to`. Also notify the listener /// about every in-place op modification (for every use that was replaced). - void replaceAllUsesWith(Value from, Value to) { - return replaceAllUsesWith(from.getImpl(), to); + virtual void replaceAllUsesWith(Value from, Value to) { + for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { + Operation *op = operand.getOwner(); + modifyOpInPlace(op, [&]() { operand.set(to); }); + } } - template - void replaceAllUsesWith(IRObjectWithUseList *from, ValueT &&to) { - for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) { + void replaceAllUsesWith(Block *from, Block *to) { + for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); } @@ -628,9 +631,16 @@ class RewriterBase : public OpBuilder { for (auto it : llvm::zip(from, to)) replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); } - void replaceAllUsesWith(Operation *from, ValueRange to) { - replaceAllUsesWith(from->getResults(), to); - } + + /// Find uses of `from` and replace them with `to`. Also notify the listener + /// about every in-place op modification (for every use that was replaced) + /// and that the `from` operation is about to be replaced. + /// + /// Note: This function cannot be called `replaceAllUsesWith` because the + /// overload resolution, when called with an op that can be implicitly + /// converted to a Value, would be ambiguous. + void replaceAllOpUsesWith(Operation *from, ValueRange to); + void replaceAllOpUsesWith(Operation *from, Operation *to); /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. Also notify the listener about every in-place op modification (for @@ -642,9 +652,12 @@ class RewriterBase : public OpBuilder { void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref functor, bool *allUsesReplaced = nullptr); - void replaceUsesWithIf(Operation *from, ValueRange to, - function_ref functor, - bool *allUsesReplaced = nullptr) { + // Note: This function cannot be called `replaceOpUsesWithIf` because the + // overload resolution, when called with an op that can be implicitly + // converted to a Value, would be ambiguous. + void replaceOpUsesWithIf(Operation *from, ValueRange to, + function_ref functor, + bool *allUsesReplaced = nullptr) { replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced); } @@ -652,9 +665,9 @@ class RewriterBase : public OpBuilder { /// the listener about every in-place op modification (for every use that was /// replaced). The optional `allUsesReplaced` flag is set to "true" if all /// uses were replaced. - void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, - bool *allUsesReplaced = nullptr) { - replaceUsesWithIf( + void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, + Block *block, bool *allUsesReplaced = nullptr) { + replaceOpUsesWithIf( op, newValues, [block](OpOperand &use) { return block->getParentOp()->isProperAncestor(use.getOwner()); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 83198c9b0db54..1797ee0876e43 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -697,9 +697,6 @@ class ConversionPatternRewriter final : public PatternRewriter { Region *region, const TypeConverter &converter, ArrayRef blockConversions); - /// Replace all the uses of the block argument `from` with value `to`. - void replaceUsesOfBlockArgument(BlockArgument from, Value to); - /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case /// of failure, the remapped value otherwise. @@ -720,6 +717,11 @@ class ConversionPatternRewriter final : public PatternRewriter { /// patterns even if a failure is encountered during the rewrite step. bool canRecoverFromRewriteFailure() const override { return true; } + /// Find uses of `from` and replace them with `to`. + /// + /// Note: This function does not convert types. + void replaceAllUsesWith(Value from, Value to) override; + /// PatternRewriter hook for replacing an operation. void replaceOp(Operation *op, ValueRange newValues) override; diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 53b44aa3241bb..d7ed9a196e893 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -310,7 +310,7 @@ static void modifyFuncOpToUseBarePtrCallingConv( Location loc = funcOp.getLoc(); auto placeholder = rewriter.create( loc, typeConverter.convertType(memrefTy)); - rewriter.replaceUsesOfBlockArgument(arg, placeholder); + rewriter.replaceAllUsesWith(arg, placeholder); Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter, memrefTy, arg); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 73d418cb84132..c6d2ddac9dbb1 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -201,7 +201,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, llvmFuncOp.getBody().getArgument(remapping->inputNo); auto placeholder = rewriter.create( loc, getTypeConverter()->convertType(memrefTy)); - rewriter.replaceUsesOfBlockArgument(newArg, placeholder); + rewriter.replaceAllUsesWith(newArg, placeholder); Value desc = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), memrefTy, newArg); rewriter.replaceOp(placeholder, {desc}); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp index 1658ea67a4607..999359c7fa872 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, scalarReplacements.push_back( residualGenericOpBody->getArgument(num + origNumInputs)); bool allUsesReplaced = false; - rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements, - residualGenericOpBody, &allUsesReplaced); + rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements, + residualGenericOpBody, &allUsesReplaced); assert(!allUsesReplaced && "peeled scalar operation is erased when it wasnt expected to be"); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 0a88e40f73ec6..5944a0ea46a14 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() { // Out of line to provide a vtable anchor for the class. } +void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) { + // Notify the listener that we're about to replace this op. + if (auto *rewriteListener = dyn_cast_if_present(listener)) + rewriteListener->notifyOperationReplaced(from, to); + + replaceAllUsesWith(from->getResults(), to); +} + +void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) { + // Notify the listener that we're about to replace this op. + if (auto *rewriteListener = dyn_cast_if_present(listener)) + rewriteListener->notifyOperationReplaced(from, to); + + replaceAllUsesWith(from->getResults(), to->getResults()); +} + /// This method replaces the results of the operation with the specified list of /// values. The number of provided values must match the number of results of /// the operation. The replaced op is erased. @@ -117,12 +133,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - // Notify the listener that we're about to replace this op. - if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationReplaced(op, newValues); - // Replace all result uses. Also notifies the listener of modifications. - replaceAllUsesWith(op, newValues); + replaceAllOpUsesWith(op, newValues); // Erase op and notify listener. eraseOp(op); @@ -136,12 +148,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { assert(op->getNumResults() == newOp->getNumResults() && "ops have different number of results"); - // Notify the listener that we're about to replace this op. - if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationReplaced(op, newOp); - // Replace all result uses. Also notifies the listener of modifications. - replaceAllUsesWith(op, newOp->getResults()); + replaceAllOpUsesWith(op, newOp->getResults()); // Erase op and notify listener. eraseOp(op); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c1a261eab8487..dbdfaeeeb28d4 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -153,9 +153,9 @@ namespace { /// This is useful when saving and undoing a set of rewrites. struct RewriterState { RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, - unsigned numReplacedOps) + unsigned numErasedOps) : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), - numReplacedOps(numReplacedOps) {} + numErasedOps(numErasedOps) {} /// The current number of rewrites performed. unsigned numRewrites; @@ -163,8 +163,8 @@ struct RewriterState { /// The current number of ignored operations. unsigned numIgnoredOperations; - /// The current number of replaced ops that are scheduled for erasure. - unsigned numReplacedOps; + /// The current number of ops that are scheduled for erasure. + unsigned numErasedOps; }; //===----------------------------------------------------------------------===// @@ -190,13 +190,14 @@ class IRRewrite { InlineBlock, MoveBlock, BlockTypeConversion, - ReplaceBlockArg, // Operation rewrites MoveOperation, ModifyOperation, ReplaceOperation, CreateOperation, - UnresolvedMaterialization + UnresolvedMaterialization, + // Value rewrites + ReplaceAllUses }; virtual ~IRRewrite() = default; @@ -231,6 +232,9 @@ class IRRewrite { const ConversionConfig &getConfig() const; + ConversionValueMapping &getMapping(); + +private: const Kind kind; ConversionPatternRewriterImpl &rewriterImpl; }; @@ -243,7 +247,7 @@ class BlockRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::CreateBlock && - rewrite->getKind() <= Kind::ReplaceBlockArg; + rewrite->getKind() <= Kind::BlockTypeConversion; } protected: @@ -469,7 +473,8 @@ class BlockTypeConversionRewrite : public BlockRewrite { /// live users, using the provided `findLiveUser` to search for a user that /// survives the conversion process. LogicalResult - materializeLiveConversions(function_ref findLiveUser); + materializeLiveConversions(OpBuilder &builder, + function_ref findLiveUser); void commit(RewriterBase &rewriter) override; @@ -487,27 +492,6 @@ class BlockTypeConversionRewrite : public BlockRewrite { const TypeConverter *converter; }; -/// Replacing a block argument. This rewrite is not immediately reflected in the -/// IR. An internal IR mapping is updated, but the actual replacement is delayed -/// until the rewrite is committed. -class ReplaceBlockArgRewrite : public BlockRewrite { -public: - ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, BlockArgument arg) - : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {} - - static bool classof(const IRRewrite *rewrite) { - return rewrite->getKind() == Kind::ReplaceBlockArg; - } - - void commit(RewriterBase &rewriter) override; - - void rollback() override; - -private: - BlockArgument arg; -}; - /// An operation rewrite. class OperationRewrite : public IRRewrite { public: @@ -751,6 +735,44 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { /// The original output type. This is only used for argument conversions. Type origOutputType; }; + +/// A value rewrite. +class ValueRewrite : public IRRewrite { +public: + /// Return the operation that this rewrite operates on. + Value getValue() const { return value; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() >= Kind::ReplaceAllUses && + rewrite->getKind() <= Kind::ReplaceAllUses; + } + +protected: + ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Value value) + : IRRewrite(kind, rewriterImpl), value(value) {} + + // The value that this rewrite operates on. + Value value; +}; + +/// Replacing a value. This rewrite is not immediately reflected in the IR. An +/// internal IR mapping is updated, but the actual replacement is delayed until +/// the rewrite is committed. +class ReplaceAllUsesRewrite : public ValueRewrite { +public: + ReplaceAllUsesRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Value value) + : ValueRewrite(Kind::ReplaceAllUses, rewriterImpl, value) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceAllUses; + } + + void commit(RewriterBase &rewriter) override; + + void rollback() override; +}; } // namespace /// Return "true" if there is an operation rewrite that matches the specified @@ -832,8 +854,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// converted. bool isOpIgnored(Operation *op) const; - /// Return "true" if the given operation was replaced or erased. - bool wasOpReplaced(Operation *op) const; + /// Return "true" if the given operation is scheduled for erasure. (It may + /// still be visible in the IR, but should not be accessed.) + bool wasOpErased(Operation *op) const; //===--------------------------------------------------------------------===// // Type Conversion @@ -982,11 +1005,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// tracked separately. SetVector ignoredOps; - /// A set of operations that were replaced/erased. Such ops are not erased - /// immediately but only when the dialect conversion succeeds. In the mean - /// time, they should no longer be considered for legalization and any attempt - /// to modify/access them is invalid rewriter API usage. - SetVector replacedOps; + /// A set of operations that were erased. Such ops are not erased immediately + /// but only when the dialect conversion succeeds. In the mean time, they + /// should no longer be considered for legalization and any attempt to + /// modify/access them is invalid rewriter API usage. + SetVector erasedOps; /// The current type converter, or nullptr if no type converter is currently /// active. @@ -1016,6 +1039,8 @@ const ConversionConfig &IRRewrite::getConfig() const { return rewriterImpl.config; } +ConversionValueMapping &IRRewrite::getMapping() { return rewriterImpl.mapping; } + void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { // Inform the listener about all IR modifications that have already taken // place: References to the original block have been replaced with the new @@ -1030,8 +1055,7 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { llvm::zip_equal(origBlock->getArguments(), argInfo)) { // Handle the case of a 1->0 value mapping. if (!info) { - if (Value newArg = - rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + if (Value newArg = getMapping().lookupOrNull(origArg, origArg.getType())) rewriter.replaceAllUsesWith(origArg, newArg); continue; } @@ -1042,8 +1066,8 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) { - rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault( - castValue, origArg.getType())); + rewriter.replaceAllUsesWith( + origArg, getMapping().lookupOrDefault(castValue, origArg.getType())); } } } @@ -1053,23 +1077,23 @@ void BlockTypeConversionRewrite::rollback() { } LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( - function_ref findLiveUser) { + OpBuilder &builder, function_ref findLiveUser) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(block); + // Process the remapping for each of the original arguments. for (auto it : llvm::enumerate(origBlock->getArguments())) { BlockArgument origArg = it.value(); - // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used. - OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl); - builder.setInsertionPointToStart(block); // If the type of this argument changed and the argument is still live, we // need to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + if (getMapping().lookupOrNull(origArg, origArg.getType())) continue; Operation *liveUser = findLiveUser(origArg); if (!liveUser) continue; - Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg); + Value replacementValue = getMapping().lookupOrDefault(origArg); bool isDroppedArg = replacementValue == origArg; if (!isDroppedArg) builder.setInsertionPointAfterValue(replacementValue); @@ -1094,18 +1118,17 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( << "see existing live user here: " << *liveUser; return failure(); } - rewriterImpl.mapping.map(origArg, newArg); + getMapping().map(origArg, newArg); } return success(); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); - if (!repl) - return; +void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) { + Value repl = getMapping().lookupOrNull(value); + assert(repl && "expected that value is mapped"); if (isa(repl)) { - rewriter.replaceAllUsesWith(arg, repl); + rewriter.replaceAllUsesWith(value, repl); return; } @@ -1114,13 +1137,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { // replacement value. Operation *replOp = cast(repl).getOwner(); Block *replBlock = replOp->getBlock(); - rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) { + rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); }); } -void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); } +void ReplaceAllUsesRewrite::rollback() { getMapping().erase(value); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { auto *listener = dyn_cast_or_null( @@ -1129,7 +1152,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { // Compute replacement values. SmallVector replacements = llvm::map_to_vector(op->getResults(), [&](OpResult result) { - return rewriterImpl.mapping.lookupOrNull(result, result.getType()); + return getMapping().lookupOrNull(result, result.getType()); }); // Notify the listener that the operation is about to be replaced. @@ -1161,7 +1184,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { void ReplaceOperationRewrite::rollback() { for (auto result : op->getResults()) - rewriterImpl.mapping.erase(result); + getMapping().erase(result); } void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { @@ -1180,7 +1203,7 @@ void CreateOperationRewrite::rollback() { void UnresolvedMaterializationRewrite::rollback() { if (getMaterializationKind() == MaterializationKind::Target) { for (Value input : op->getOperands()) - rewriterImpl.mapping.erase(input); + getMapping().erase(input); } op->erase(); } @@ -1205,7 +1228,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management RewriterState ConversionPatternRewriterImpl::getCurrentState() { - return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); + return RewriterState(rewrites.size(), ignoredOps.size(), erasedOps.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -1216,8 +1239,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); - while (replacedOps.size() != state.numReplacedOps) - replacedOps.pop_back(); + while (erasedOps.size() != state.numErasedOps) + erasedOps.pop_back(); } void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { @@ -1282,13 +1305,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { - // Check to see if this operation is ignored or was replaced. - return replacedOps.count(op) || ignoredOps.count(op); + // Check to see if this operation is ignored or was erased. + return erasedOps.count(op) || ignoredOps.count(op); } -bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { - // Check to see if this operation was replaced. - return replacedOps.count(op); +bool ConversionPatternRewriterImpl::wasOpErased(Operation *op) const { + // Check to see if this operation was scheduled for erasure. + return erasedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1434,7 +1457,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( "invalid to provide a replacement value when the argument isn't " "dropped"); mapping.map(origArg, inputMap->replacementValue); - appendRewrite(block, origArg); + appendRewrite(origArg); continue; } @@ -1469,7 +1492,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( } mapping.map(origArg, newArg); - appendRewrite(block, origArg); + appendRewrite(origArg); argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } @@ -1535,8 +1558,8 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); - assert(!wasOpReplaced(op->getParentOp()) && - "attempting to insert into a block within a replaced/erased op"); + assert(!wasOpErased(op->getParentOp()) && + "attempting to insert into a block within an erased op"); if (!previous.isSet()) { // This is a newly created op. @@ -1571,8 +1594,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, appendRewrite(op, currentTypeConverter, resultChanged); - // Mark this operation and all nested ops as replaced. - op->walk([&](Operation *op) { replacedOps.insert(op); }); + // Mark this operation and all nested ops as erased. + op->walk([&](Operation *op) { erasedOps.insert(op); }); } void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { @@ -1583,8 +1606,8 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { - assert(!wasOpReplaced(block->getParentOp()) && - "attempting to insert into a region within a replaced/erased op"); + assert(!wasOpErased(block->getParentOp()) && + "attempting to insert into a region within an erased op"); LLVM_DEBUG( { Operation *parent = block->getParentOp(); @@ -1660,8 +1683,8 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { } void ConversionPatternRewriter::eraseBlock(Block *block) { - assert(!impl->wasOpReplaced(block->getParentOp()) && - "attempting to erase a block within a replaced/erased op"); + assert(!impl->wasOpErased(block->getParentOp()) && + "attempting to erase a block within an erased op"); // Mark all ops for erasure. for (Operation &op : *block) @@ -1678,41 +1701,59 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { Block *ConversionPatternRewriter::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter) { - assert(!impl->wasOpReplaced(region->getParentOp()) && - "attempting to apply a signature conversion to a block within a " - "replaced/erased op"); + assert(!impl->wasOpErased(region->getParentOp()) && + "attempting to apply a signature conversion to a block within an " + "erased op"); return impl->applySignatureConversion(*this, region, conversion, converter); } FailureOr ConversionPatternRewriter::convertRegionTypes( Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { - assert(!impl->wasOpReplaced(region->getParentOp()) && - "attempting to apply a signature conversion to a block within a " - "replaced/erased op"); + assert(!impl->wasOpErased(region->getParentOp()) && + "attempting to apply a signature conversion to a block within an " + "erased op"); return impl->convertRegionTypes(*this, region, converter, entryConversion); } LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( Region *region, const TypeConverter &converter, ArrayRef blockConversions) { - assert(!impl->wasOpReplaced(region->getParentOp()) && - "attempting to apply a signature conversion to a block within a " - "replaced/erased op"); + assert(!impl->wasOpErased(region->getParentOp()) && + "attempting to apply a signature conversion to a block within an " + "erased op"); return impl->convertNonEntryRegionTypes(*this, region, converter, blockConversions); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - Value to) { +void ConversionPatternRewriter::replaceAllUsesWith(Value from, Value to) { +#ifndef NDEBUG LLVM_DEBUG({ - Operation *parentOp = from.getOwner()->getParentOp(); - impl->logger.startLine() << "** Replace Argument : '" << from - << "'(in region of '" << parentOp->getName() - << "'(" << from.getOwner()->getParentOp() << ")\n"; + Block *parentBlock = from.getParentBlock(); + Operation *parentOp = parentBlock ? parentBlock->getParentOp() : nullptr; + impl->logger.startLine() << "** Replace value : '" << from; + if (parentOp) { + impl->logger.getOStream() << "' (in region of '" << parentOp->getName() + << "'(" << parentOp << ")\n"; + } else { + impl->logger.getOStream() << "' (detached)\n"; + } }); - impl->appendRewrite(from.getOwner(), from); - impl->mapping.map(impl->mapping.lookupOrDefault(from), to); + if (OpResult opResult = dyn_cast(from)) { + assert(!impl->wasOpErased(opResult.getDefiningOp()) && + "attempting to replace an OpResult defined by an erased op"); + } + if (OpResult opResult = dyn_cast(to)) { + assert(!impl->wasOpErased(opResult.getDefiningOp()) && + "attempting to replace with an OpResult defined by an erased op"); + } + // A value cannot be replaced multiple times. That would likely require a more + // fine-grained tracking of replacements (i.e., each use must be tracked). + assert(!impl->mapping.lookupOrNull(from) && "value was already replaced"); +#endif // NDEBUG + + impl->appendRewrite(from); + impl->mapping.map(from, to); } Value ConversionPatternRewriter::getRemappedValue(Value key) { @@ -1738,10 +1779,10 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, #ifndef NDEBUG assert(argValues.size() == source->getNumArguments() && "incorrect # of argument replacement values"); - assert(!impl->wasOpReplaced(source->getParentOp()) && - "attempting to inline a block from a replaced/erased op"); - assert(!impl->wasOpReplaced(dest->getParentOp()) && - "attempting to inline a block into a replaced/erased op"); + assert(!impl->wasOpErased(source->getParentOp()) && + "attempting to inline a block from an erased op"); + assert(!impl->wasOpErased(dest->getParentOp()) && + "attempting to inline a block into an erased op"); auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; // The source block will be deleted, so it should not have any users (i.e., // there should be no predecessors). @@ -1762,7 +1803,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // Replace all uses of block arguments. for (auto it : llvm::zip(source->getArguments(), argValues)) - replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); if (fastPath) { // Move all ops at once. @@ -1778,8 +1819,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); + assert(!impl->wasOpErased(op) && "attempting to modify an erased op"); #ifndef NDEBUG impl->pendingRootUpdates.insert(op); #endif @@ -1787,8 +1827,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); + assert(!impl->wasOpErased(op) && "attempting to modify an erased op"); PatternRewriter::finalizeOpModification(op); // There is nothing to do here, we only need to track the operation at the // start of the update. @@ -2204,8 +2243,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( if (!rewrite) continue; Block *block = rewrite->getBlock(); - if (isa(rewrite)) + if (isa(rewrite)) continue; // Only check blocks outside of the current operation. Operation *parentOp = block->getParentOp(); @@ -2688,7 +2726,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( if (auto *blockTypeConversionRewrite = dyn_cast(rewrite.get())) if (failed(blockTypeConversionRewrite->materializeLiveConversions( - findLiveUser))) + rewriter, findLiveUser))) return failure(); } return success(); diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index eff8acdfb33d2..e25867b527b71 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -161,7 +161,7 @@ SmallVector mlir::makeRegionIsolatedFromAbove( rewriter.setInsertionPointToStart(newEntryBlock); for (auto *clonedOp : clonedOperations) { Operation *newOp = rewriter.clone(*clonedOp, map); - rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn); + rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn); } rewriter.mergeBlocks( entryBlock, newEntryBlock, diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index d552f0346644b..78dc3f988a45a 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -427,3 +427,21 @@ func.func @use_of_replaced_bbarg(%arg0: i64) { }) : (i64) -> (i64) "test.invalid"(%0) : (i64) -> () } + +// ----- + +// CHECK: notifyOperationInserted: test.legal_op_b, was unlinked +// CHECK: notifyOperationModified: test.valid +// CHECK: notifyOperationModified: test.illegal_op_h + +// CHECK-LABEL: func @replace_all_uses_with() +func.func @replace_all_uses_with() { + // CHECK: %[[legal:.*]] = "test.legal_op_b"() : () -> i32 + // CHECK: %[[illegal:.*]] = "test.illegal_op_h"() {not_illegal} : () -> i64 + %result = "test.illegal_op_h"() : () -> (i64) + + // replaceAllUsesWith does not perform any type conversion. The uses are + // directly updated during the commit phase. + // CHECK: "test.valid"(%[[legal]]) : (i32) -> () + "test.valid"(%result) : (i64) -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index dfd2f21a5ea24..c19b0d2bc43c8 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1856,6 +1856,7 @@ def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>; def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>; def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>; +def ILLegalOpH : TEST_Op<"illegal_op_h">, Results<(outs I64)>; def LegalOpA : TEST_Op<"legal_op_a">, Arguments<(ins StrAttr:$status)>, Results<(outs I32)>; def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 2da184bc3d85b..718fbf10f5988 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -489,7 +489,10 @@ struct TestStrictPatternDriver OperationName("test.new_op", op->getContext()).getIdentifier(), op->getOperands(), op->getResultTypes()); } - rewriter.replaceOp(op, newOp->getResults()); + // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". + // A "notifyOperationReplaced" callback is triggered in either case. + rewriter.replaceAllOpUsesWith(op, newOp->getResults()); + rewriter.eraseOp(op); return success(); } }; @@ -782,8 +785,8 @@ struct TestUndoBlockArgReplace : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { auto illegalOp = rewriter.create(op->getLoc(), rewriter.getF32Type()); - rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), - illegalOp->getResult(0)); + rewriter.replaceAllUsesWith(op->getRegion(0).getArgument(0), + illegalOp->getResult(0)); rewriter.modifyOpInPlace(op, [] {}); return success(); } @@ -837,6 +840,24 @@ struct TestUndoPropertiesModification : public ConversionPattern { } }; +/// A pattern that replaces all uses of illegal_op_h with a newly created op +/// that has one i32 result. The old op is marked as "legal". +struct ReplaceAllUsesOfIllegalOp : public ConversionPattern { + ReplaceAllUsesOfIllegalOp(MLIRContext *context) + : ConversionPattern("test.illegal_op_h", /*benefit=*/1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Operation *legalOp = + rewriter.create(op->getLoc(), rewriter.getIntegerType(32)); + rewriter.replaceAllOpUsesWith(op, legalOp); + rewriter.modifyOpInPlace( + op, [&] { op->setAttr("not_illegal", rewriter.getUnitAttr()); }); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Type-Conversion Rewrite Testing @@ -1117,7 +1138,8 @@ struct TestLegalizePatternDriver TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, - TestUndoPropertiesModification>(&getContext()); + TestUndoPropertiesModification, ReplaceAllUsesOfIllegalOp>( + &getContext()); patterns.add(&getContext(), converter); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1130,6 +1152,8 @@ struct TestLegalizePatternDriver TerminatorOp, OneRegionOp>(); target .addIllegalOp(); + target.addDynamicallyLegalOp( + [](ILLegalOpH op) { return op->hasAttr("not_illegal"); }); target.addDynamicallyLegalOp([](TestReturnOp op) { // Don't allow F32 operands. return llvm::none_of(op.getOperandTypes(),