diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index db41b9f19e7e8..dec68048dc1d3 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -153,14 +153,12 @@ namespace { /// This is useful when saving and undoing a set of rewrites. struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, - unsigned numReplacements, unsigned numArgReplacements, unsigned numRewrites, unsigned numIgnoredOperations, unsigned numErased) : numCreatedOps(numCreatedOps), numUnresolvedMaterializations(numUnresolvedMaterializations), - numReplacements(numReplacements), - numArgReplacements(numArgReplacements), numRewrites(numRewrites), - numIgnoredOperations(numIgnoredOperations), numErased(numErased) {} + numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), + numErased(numErased) {} /// The current number of created operations. unsigned numCreatedOps; @@ -168,12 +166,6 @@ struct RewriterState { /// The current number of unresolved materializations. unsigned numUnresolvedMaterializations; - /// The current number of replacements queued. - unsigned numReplacements; - - /// The current number of argument replacements queued. - unsigned numArgReplacements; - /// The current number of rewrites performed. unsigned numRewrites; @@ -184,20 +176,6 @@ struct RewriterState { unsigned numErased; }; -//===----------------------------------------------------------------------===// -// OpReplacement - -/// This class represents one requested operation replacement via 'replaceOp' or -/// 'eraseOp`. -struct OpReplacement { - OpReplacement(const TypeConverter *converter = nullptr) - : converter(converter) {} - - /// An optional type converter that can be used to materialize conversions - /// between the new and old values if necessary. - const TypeConverter *converter; -}; - //===----------------------------------------------------------------------===// // UnresolvedMaterialization @@ -321,19 +299,27 @@ class IRRewrite { MoveBlock, SplitBlock, BlockTypeConversion, + ReplaceBlockArg, // Operation rewrites MoveOperation, - ModifyOperation + ModifyOperation, + ReplaceOperation }; virtual ~IRRewrite() = default; - /// Roll back the rewrite. + /// Roll back the rewrite. Operations may be erased during rollback. virtual void rollback() = 0; - /// Commit the rewrite. + /// Commit the rewrite. Operations may be unlinked from their blocks during + /// the commit phase, but they must not be erased yet. This is because + /// internal dialect conversion state (such as `mapping`) may still be using + /// them. Operations must be erased during cleanup. virtual void commit() {} + /// Cleanup operations. Cleanup is called after commit. + virtual void cleanup() {} + Kind getKind() const { return kind; } static bool classof(const IRRewrite *rewrite) { return true; } @@ -360,7 +346,7 @@ class BlockRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::CreateBlock && - rewrite->getKind() <= Kind::BlockTypeConversion; + rewrite->getKind() <= Kind::ReplaceBlockArg; } protected: @@ -428,6 +414,8 @@ class EraseBlockRewrite : public BlockRewrite { void commit() override { // Erase the block. assert(block && "expected block"); + assert(block->empty() && "expected empty block"); + block->dropAllDefinedValueUses(); delete block; block = nullptr; } @@ -589,6 +577,27 @@ class BlockTypeConversionRewrite : public BlockRewrite { const TypeConverter *converter; }; +/// Replacing a block argument. This rewrite is not immediately reflected in the +/// IR. An internal IR mapping is updated, but the actual replacement is delayed +/// until the rewrite is committed. +class ReplaceBlockArgRewrite : public BlockRewrite { +public: + ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Block *block, BlockArgument arg) + : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceBlockArg; + } + + void commit() override; + + void rollback() override; + +private: + BlockArgument arg; +}; + /// An operation rewrite. class OperationRewrite : public IRRewrite { public: @@ -597,7 +606,7 @@ class OperationRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::MoveOperation && - rewrite->getKind() <= Kind::ModifyOperation; + rewrite->getKind() <= Kind::ReplaceOperation; } protected: @@ -698,6 +707,39 @@ class ModifyOperationRewrite : public OperationRewrite { SmallVector successors; void *propertiesStorage = nullptr; }; + +/// Replacing an operation. Erasing an operation is treated as a special case +/// with "null" replacements. This rewrite is not immediately reflected in the +/// IR. An internal IR mapping is updated, but values are not replaced and the +/// original op is not erased until the rewrite is committed. +class ReplaceOperationRewrite : public OperationRewrite { +public: + ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op, const TypeConverter *converter, + bool changedResults) + : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), + converter(converter), changedResults(changedResults) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceOperation; + } + + void commit() override; + + void rollback() override; + + void cleanup() override; + +private: + friend struct OperationConverter; + + /// An optional type converter that can be used to materialize conversions + /// between the new and old values if necessary. + const TypeConverter *converter; + + /// A boolean flag that indicates whether result types have changed or not. + bool changedResults; +}; } // namespace /// Return "true" if there is an operation rewrite that matches the specified @@ -890,6 +932,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void eraseBlock(Block *block) override { if (erased.contains(block)) return; + assert(block->empty() && "expected empty block"); block->dropAllDefinedValueUses(); RewriterBase::eraseBlock(block); } @@ -921,12 +964,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// conversion. SmallVector unresolvedMaterializations; - /// Ordered map of requested operation replacements. - llvm::MapVector replacements; - - /// Ordered vector of any requested block argument replacements. - SmallVector argReplacements; - /// Ordered list of block operations (creations, splits, motions). SmallVector> rewrites; @@ -941,11 +978,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// operation was ignored. SetVector ignoredOps; - /// A vector of indices into `replacements` of operations that were replaced - /// with values with different result types than the original operation, e.g. - /// 1->N conversion of some kind. - SmallVector operationsWithChangedResults; - /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -957,6 +989,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// This allows the user to collect the match failure message. function_ref notifyCallback; + /// A set of pre-existing operations. When mode == OpConversionMode::Analysis, + /// this is populated with ops found to be legalizable to the target. + /// When mode == OpConversionMode::Partial, this is populated with ops found + /// *not* to be legalizable to the target. + DenseSet *trackedOps = nullptr; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1001,6 +1039,8 @@ void BlockTypeConversionRewrite::commit() { } } + assert(origBlock->empty() && "expected empty block"); + origBlock->dropAllDefinedValueUses(); delete origBlock; origBlock = nullptr; } @@ -1063,6 +1103,47 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( return success(); } +void ReplaceBlockArgRewrite::commit() { + Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); + if (!repl) + return; + + if (isa(repl)) { + arg.replaceAllUsesWith(repl); + return; + } + + // If the replacement value is an operation, we check to make sure that we + // don't replace uses that are within the parent operation of the + // replacement value. + Operation *replOp = cast(repl).getOwner(); + Block *replBlock = replOp->getBlock(); + arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + }); +} + +void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); } + +void ReplaceOperationRewrite::commit() { + for (OpResult result : op->getResults()) + if (Value newValue = + rewriterImpl.mapping.lookupOrNull(result, result.getType())) + result.replaceAllUsesWith(newValue); + if (rewriterImpl.trackedOps) + rewriterImpl.trackedOps->erase(op); + // Do not erase the operation yet. It may still be referenced in `mapping`. + op->getBlock()->getOperations().remove(op); +} + +void ReplaceOperationRewrite::rollback() { + for (auto result : op->getResults()) + rewriterImpl.mapping.erase(result); +} + +void ReplaceOperationRewrite::cleanup() { eraseOp(op); } + void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { @@ -1085,51 +1166,16 @@ void ConversionPatternRewriterImpl::discardRewrites() { } void ConversionPatternRewriterImpl::applyRewrites() { - // Apply all of the rewrites replacements requested during conversion. - for (auto &repl : replacements) { - for (OpResult result : repl.first->getResults()) - if (Value newValue = mapping.lookupOrNull(result, result.getType())) - result.replaceAllUsesWith(newValue); - } - - // Apply all of the requested argument replacements. - for (BlockArgument arg : argReplacements) { - Value repl = mapping.lookupOrNull(arg, arg.getType()); - if (!repl) - continue; - - if (isa(repl)) { - arg.replaceAllUsesWith(repl); - continue; - } - - // If the replacement value is an operation, we check to make sure that we - // don't replace uses that are within the parent operation of the - // replacement value. - Operation *replOp = cast(repl).getOwner(); - Block *replBlock = replOp->getBlock(); - arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { - Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); - }); - } + // Commit all rewrites. + for (auto &rewrite : rewrites) + rewrite->commit(); + for (auto &rewrite : rewrites) + rewrite->cleanup(); // Drop all of the unresolved materialization operations created during // conversion. for (auto &mat : unresolvedMaterializations) eraseRewriter.eraseOp(mat.getOp()); - - // In a second pass, erase all of the replaced operations in reverse. This - // allows processing nested operations before their parent region is - // destroyed. Because we process in reverse order, producers may be deleted - // before their users (a pattern deleting a producer and then the consumer) - // so we first drop all uses explicitly. - for (auto &repl : llvm::reverse(replacements)) - eraseRewriter.eraseOp(repl.first); - - // Commit all rewrites. - for (auto &rewrite : rewrites) - rewrite->commit(); } //===----------------------------------------------------------------------===// @@ -1137,28 +1183,14 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), unresolvedMaterializations.size(), - replacements.size(), argReplacements.size(), rewrites.size(), ignoredOps.size(), eraseRewriter.erased.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { - // Reset any replaced arguments. - for (BlockArgument replacedArg : - llvm::drop_begin(argReplacements, state.numArgReplacements)) - mapping.erase(replacedArg); - argReplacements.resize(state.numArgReplacements); - // Undo any rewrites. undoRewrites(state.numRewrites); - // Reset any replaced operations and undo any saved mappings. - for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) - for (auto result : repl.first->getResults()) - mapping.erase(result); - while (replacements.size() != state.numReplacements) - replacements.pop_back(); - // Pop all of the newly inserted materializations. while (unresolvedMaterializations.size() != state.numUnresolvedMaterializations) { @@ -1183,11 +1215,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); - // Reset operations with changed results. - while (!operationsWithChangedResults.empty() && - operationsWithChangedResults.back() >= state.numReplacements) - operationsWithChangedResults.pop_back(); - while (eraseRewriter.erased.size() != state.numErased) eraseRewriter.erased.pop_back(); } @@ -1256,7 +1283,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation was replaced or its parent ignored. - return replacements.count(op) || ignoredOps.count(op->getParentOp()); + return ignoredOps.count(op->getParentOp()) || + hasRewrite(rewrites, op); } void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { @@ -1396,7 +1424,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( "invalid to provide a replacement value when the argument isn't " "dropped"); mapping.map(origArg, inputMap->replacementValue); - argReplacements.push_back(origArg); + appendRewrite(block, origArg); continue; } @@ -1430,7 +1458,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( } mapping.map(origArg, newArg); - argReplacements.push_back(origArg); + appendRewrite(block, origArg); argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } @@ -1462,7 +1490,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); - assert(!replacements.count(op) && "operation was already replaced"); +#ifndef NDEBUG + for (auto &rewrite : rewrites) + if (auto *opReplacement = dyn_cast(rewrite.get())) + assert(opReplacement->getOperation() != op && + "operation was already replaced"); +#endif // NDEBUG // Track if any of the results changed, e.g. erased and replaced with null. bool resultChanged = false; @@ -1477,11 +1510,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, mapping.map(result, newValue); resultChanged |= (newValue.getType() != result.getType()); } - if (resultChanged) - operationsWithChangedResults.push_back(replacements.size()); - // Record the requested operation replacement. - replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter))); + appendRewrite(op, currentTypeConverter, + resultChanged); // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. @@ -1576,8 +1607,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { } void ConversionPatternRewriter::eraseBlock(Block *block) { - impl->notifyBlockIsBeingErased(block); - // Mark all ops for erasure. for (Operation &op : *block) eraseOp(&op); @@ -1586,6 +1615,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { // object and will be actually destroyed when rewrites are applied. This // allows us to keep the operations in the block live and undo the removal by // re-inserting the block. + impl->notifyBlockIsBeingErased(block); block->getParent()->getBlocks().remove(block); } @@ -1615,7 +1645,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(in region of '" << parentOp->getName() << "'(" << from.getOwner()->getParentOp() << ")\n"; }); - impl->argReplacements.push_back(from); + impl->appendRewrite(from.getOwner(), from); impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } @@ -2039,16 +2069,13 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, #ifndef NDEBUG assert(impl.pendingRootUpdates.empty() && "dangling root updates"); - // Check that the root was either replaced or updated in place. + auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites); auto replacedRoot = [&] { - return llvm::any_of( - llvm::drop_begin(impl.replacements, curState.numReplacements), - [op](auto &it) { return it.first == op; }); + return hasRewrite(newRewrites, op); }; auto updatedRootInPlace = [&] { - return hasRewrite( - llvm::drop_begin(impl.rewrites, curState.numRewrites), op); + return hasRewrite(newRewrites, op); }; assert((replacedRoot() || updatedRootInPlace()) && "expected pattern to replace the root operation"); @@ -2081,7 +2108,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( if (!rewrite) continue; Block *block = rewrite->getBlock(); - if (isa(rewrite)) + if (isa(rewrite)) continue; // Only check blocks outside of the current operation. Operation *parentOp = block->getParentOp(); @@ -2476,6 +2504,7 @@ LogicalResult OperationConverter::convertOperations( ConversionPatternRewriter rewriter(ops.front()->getContext()); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); rewriterImpl.notifyCallback = notifyCallback; + rewriterImpl.trackedOps = trackedOps; for (auto *op : toConvert) if (failed(convert(rewriter, op))) @@ -2493,13 +2522,6 @@ LogicalResult OperationConverter::convertOperations( rewriterImpl.discardRewrites(); } else { rewriterImpl.applyRewrites(); - - // It is possible for a later pattern to erase an op that was originally - // identified as illegal and added to the trackedOps, remove it now after - // replacements have been computed. - if (trackedOps) - for (auto &repl : rewriterImpl.replacements) - trackedOps->erase(repl.first); } return success(); } @@ -2513,21 +2535,20 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) return failure(); - if (rewriterImpl.operationsWithChangedResults.empty()) - return success(); - // Process requested operation replacements. - for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); - i != e; ++i) { - unsigned replIdx = rewriterImpl.operationsWithChangedResults[i]; - auto &repl = *(rewriterImpl.replacements.begin() + replIdx); - for (OpResult result : repl.first->getResults()) { + for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { + auto *opReplacement = + dyn_cast(rewriterImpl.rewrites[i].get()); + if (!opReplacement || !opReplacement->changedResults) + continue; + Operation *op = opReplacement->getOperation(); + for (OpResult result : op->getResults()) { Value newValue = rewriterImpl.mapping.lookupOrNull(result); // If the operation result was replaced with null, all of the uses of this // value should be replaced. if (!newValue) { - if (failed(legalizeErasedResult(repl.first, result, rewriterImpl))) + if (failed(legalizeErasedResult(op, result, rewriterImpl))) return failure(); continue; } @@ -2541,15 +2562,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { inverseMapping = rewriterImpl.mapping.getInverse(); // Legalize this result. - rewriter.setInsertionPoint(repl.first); - if (failed(legalizeChangedResultType(repl.first, result, newValue, - repl.second.converter, rewriter, + rewriter.setInsertionPoint(op); + if (failed(legalizeChangedResultType(op, result, newValue, + opReplacement->converter, rewriter, rewriterImpl, *inverseMapping))) return failure(); - - // Update the end iterator for this loop in the case it was updated - // when legalizing generated conversion operations. - e = rewriterImpl.operationsWithChangedResults.size(); } } return success();