diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index dbf5bf50d60e7..9875f8668b65a 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -154,13 +154,12 @@ namespace { struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, unsigned numReplacements, unsigned numArgReplacements, - unsigned numBlockActions, unsigned numIgnoredOperations, + unsigned numRewrites, unsigned numIgnoredOperations, unsigned numRootUpdates) : numCreatedOps(numCreatedOps), numUnresolvedMaterializations(numUnresolvedMaterializations), numReplacements(numReplacements), - numArgReplacements(numArgReplacements), - numBlockActions(numBlockActions), + numArgReplacements(numArgReplacements), numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), numRootUpdates(numRootUpdates) {} @@ -176,8 +175,8 @@ struct RewriterState { /// The current number of argument replacements queued. unsigned numArgReplacements; - /// The current number of block actions performed. - unsigned numBlockActions; + /// The current number of rewrites performed. + unsigned numRewrites; /// The current number of ignored operations. unsigned numIgnoredOperations; @@ -235,86 +234,6 @@ struct OpReplacement { const TypeConverter *converter; }; -//===----------------------------------------------------------------------===// -// BlockAction - -/// The kind of the block action performed during the rewrite. Actions can be -/// undone if the conversion fails. -enum class BlockActionKind { - Create, - Erase, - Inline, - Move, - Split, - TypeConversion -}; - -/// Original position of the given block in its parent region. During undo -/// actions, the block needs to be placed before `insertBeforeBlock`. -struct BlockPosition { - Region *region; - Block *insertBeforeBlock; -}; - -/// Information needed to undo inlining actions. -/// - the source block -/// - the first inlined operation (could be null if the source block was empty) -/// - the last inlined operation (could be null if the source block was empty) -struct InlineInfo { - Block *sourceBlock; - Operation *firstInlinedInst; - Operation *lastInlinedInst; -}; - -/// The storage class for an undoable block action (one of BlockActionKind), -/// contains the information necessary to undo this action. -struct BlockAction { - static BlockAction getCreate(Block *block) { - return {BlockActionKind::Create, block, {}}; - } - static BlockAction getErase(Block *block, BlockPosition originalPosition) { - return {BlockActionKind::Erase, block, {originalPosition}}; - } - static BlockAction getInline(Block *block, Block *srcBlock, - Block::iterator before) { - BlockAction action{BlockActionKind::Inline, block, {}}; - action.inlineInfo = {srcBlock, - srcBlock->empty() ? nullptr : &srcBlock->front(), - srcBlock->empty() ? nullptr : &srcBlock->back()}; - return action; - } - static BlockAction getMove(Block *block, BlockPosition originalPosition) { - return {BlockActionKind::Move, block, {originalPosition}}; - } - static BlockAction getSplit(Block *block, Block *originalBlock) { - BlockAction action{BlockActionKind::Split, block, {}}; - action.originalBlock = originalBlock; - return action; - } - static BlockAction getTypeConversion(Block *block) { - return BlockAction{BlockActionKind::TypeConversion, block, {}}; - } - - // The action kind. - BlockActionKind kind; - - // A pointer to the block that was created by the action. - Block *block; - - union { - // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and - // contains a pointer to the region that originally contained the block as - // well as the position of the block in that region. - BlockPosition originalPosition; - // In use if kind == BlockActionKind::Split and contains a pointer to the - // block that was split into two parts. - Block *originalBlock; - // In use if kind == BlockActionKind::Inline, and contains the information - // needed to undo the inlining. - InlineInfo inlineInfo; - }; -}; - //===----------------------------------------------------------------------===// // UnresolvedMaterialization @@ -820,6 +739,251 @@ void ArgConverter::insertConversion(Block *newBlock, conversionInfo.insert({newBlock, std::move(info)}); } +//===----------------------------------------------------------------------===// +// IR rewrites +//===----------------------------------------------------------------------===// + +namespace { +/// An IR rewrite that can be committed (upon success) or rolled back (upon +/// failure). +/// +/// The dialect conversion keeps track of IR modifications (requested by the +/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites +/// are directly applied to the IR as the rewriter API is used, some are applied +/// partially, and some are delayed until the `IRRewrite` objects are committed. +class IRRewrite { +public: + /// The kind of the rewrite. Rewrites can be undone if the conversion fails. + enum class Kind { + CreateBlock, + EraseBlock, + InlineBlock, + MoveBlock, + SplitBlock, + BlockTypeConversion + }; + + virtual ~IRRewrite() = default; + + /// Roll back the rewrite. + virtual void rollback() = 0; + + /// Commit the rewrite. + virtual void commit() {} + + Kind getKind() const { return kind; } + + static bool classof(const IRRewrite *rewrite) { return true; } + +protected: + IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) + : kind(kind), rewriterImpl(rewriterImpl) {} + + const Kind kind; + ConversionPatternRewriterImpl &rewriterImpl; +}; + +/// A block rewrite. +class BlockRewrite : public IRRewrite { +public: + /// Return the block that this rewrite operates on. + Block *getBlock() const { return block; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() >= Kind::CreateBlock && + rewrite->getKind() <= Kind::BlockTypeConversion; + } + +protected: + BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Block *block) + : IRRewrite(kind, rewriterImpl), block(block) {} + + // The block that this rewrite operates on. + Block *block; +}; + +/// Creation of a block. Block creations are immediately reflected in the IR. +/// There is no extra work to commit the rewrite. During rollback, the newly +/// created block is erased. +class CreateBlockRewrite : public BlockRewrite { +public: + CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) + : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::CreateBlock; + } + + void rollback() override { + // Unlink all of the operations within this block, they will be deleted + // separately. + auto &blockOps = block->getOperations(); + while (!blockOps.empty()) + blockOps.remove(blockOps.begin()); + block->dropAllDefinedValueUses(); + block->erase(); + } +}; + +/// Erasure of a block. Block erasures are partially reflected in the IR. Erased +/// blocks are immediately unlinked, but only erased when the rewrite is +/// committed. This makes it easier to rollback a block erasure: the block is +/// simply inserted into its original location. +class EraseBlockRewrite : public BlockRewrite { +public: + EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Region *region, Block *insertBeforeBlock) + : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region), + insertBeforeBlock(insertBeforeBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::EraseBlock; + } + + ~EraseBlockRewrite() override { + assert(!block && "rewrite was neither rolled back nor committed"); + } + + void rollback() override { + // The block (owned by this rewrite) was not actually erased yet. It was + // just unlinked. Put it back into its original position. + assert(block && "expected block"); + auto &blockList = region->getBlocks(); + Region::iterator before = insertBeforeBlock + ? Region::iterator(insertBeforeBlock) + : blockList.end(); + blockList.insert(before, block); + block = nullptr; + } + + void commit() override { + // Erase the block. + assert(block && "expected block"); + delete block; + block = nullptr; + } + +private: + // The region in which this block was previously contained. + Region *region; + + // The original successor of this block before it was unlinked. "nullptr" if + // this block was the only block in the region. + Block *insertBeforeBlock; +}; + +/// Inlining of a block. This rewrite is immediately reflected in the IR. +/// Note: This rewrite represents only the inlining of the operations. The +/// erasure of the inlined block is a separate rewrite. +class InlineBlockRewrite : public BlockRewrite { +public: + InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Block *sourceBlock, Block::iterator before) + : BlockRewrite(Kind::InlineBlock, rewriterImpl, block), + sourceBlock(sourceBlock), + firstInlinedInst(sourceBlock->empty() ? nullptr + : &sourceBlock->front()), + lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) { + } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::InlineBlock; + } + + void rollback() override { + // Put the operations from the destination block (owned by the rewrite) + // back into the source block. + if (firstInlinedInst) { + assert(lastInlinedInst && "expected operation"); + sourceBlock->getOperations().splice(sourceBlock->begin(), + block->getOperations(), + Block::iterator(firstInlinedInst), + ++Block::iterator(lastInlinedInst)); + } + } + +private: + // The block that originally contained the operations. + Block *sourceBlock; + + // The first inlined operation. + Operation *firstInlinedInst; + + // The last inlined operation. + Operation *lastInlinedInst; +}; + +/// Moving of a block. This rewrite is immediately reflected in the IR. +class MoveBlockRewrite : public BlockRewrite { +public: + MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Region *region, Block *insertBeforeBlock) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), + insertBeforeBlock(insertBeforeBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::MoveBlock; + } + + void rollback() override { + // Move the block back to its original position. + Region::iterator before = + insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); + region->getBlocks().splice(before, block->getParent()->getBlocks(), block); + } + +private: + // The region in which this block was previously contained. + Region *region; + + // The original successor of this block before it was moved. "nullptr" if + // this block was the only block in the region. + Block *insertBeforeBlock; +}; + +/// Splitting of a block. This rewrite is immediately reflected in the IR. +class SplitBlockRewrite : public BlockRewrite { +public: + SplitBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Block *originalBlock) + : BlockRewrite(Kind::SplitBlock, rewriterImpl, block), + originalBlock(originalBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::SplitBlock; + } + + void rollback() override { + // Merge back the block that was split out. + originalBlock->getOperations().splice(originalBlock->end(), + block->getOperations()); + block->dropAllDefinedValueUses(); + block->erase(); + } + +private: + // The original block from which this block was split. + Block *originalBlock; +}; + +/// Block type conversion. This rewrite is partially reflected in the IR. +class BlockTypeConversionRewrite : public BlockRewrite { +public: + BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Block *block) + : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::BlockTypeConversion; + } + + // TODO: Block type conversions are currently committed in + // `ArgConverter::applyRewrites`. This should be done in the "commit" method. + void rollback() override; +}; +} // namespace + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -848,13 +1012,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Reset the state of the rewriter to a previously saved point. void resetState(RewriterState state); - /// Erase any blocks that were unlinked from their regions and stored in block - /// actions. - void eraseDanglingBlocks(); + /// Append a rewrite. Rewrites are committed upon success and rolled back upon + /// failure. + template + void appendRewrite(Args &&...args) { + rewrites.push_back( + std::make_unique(*this, std::forward(args)...)); + } - /// Undo the block actions (motions, splits) one by one in reverse order until - /// "numActionsToKeep" actions remains. - void undoBlockActions(unsigned numActionsToKeep = 0); + /// Undo the rewrites (motions, splits) one by one in reverse order until + /// "numRewritesToKeep" rewrites remains. + void undoRewrites(unsigned numRewritesToKeep = 0); /// Remap the given values to those with potentially different types. Returns /// success if the values could be remapped, failure otherwise. `valueDiagTag` @@ -954,7 +1122,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { SmallVector argReplacements; /// Ordered list of block operations (creations, splits, motions). - SmallVector blockActions; + SmallVector> rewrites; /// A set of operations that should no longer be considered for legalization, /// but were not directly replace/erased/etc. by a pattern. These are @@ -995,6 +1163,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { } // namespace detail } // namespace mlir +void BlockTypeConversionRewrite::rollback() { + // Undo the type conversion. + rewriterImpl.argConverter.discardRewrites(block); +} + /// Detach any operations nested in the given operation from their parent /// blocks, and erase the given operation. This can be used when the nested /// operations are scheduled for erasure themselves, so deleting the regions of @@ -1020,7 +1193,7 @@ void ConversionPatternRewriterImpl::discardRewrites() { for (auto &state : rootUpdates) state.resetOperation(); - undoBlockActions(); + undoRewrites(); // Remove any newly created ops. for (UnresolvedMaterialization &materialization : unresolvedMaterializations) @@ -1083,8 +1256,9 @@ void ConversionPatternRewriterImpl::applyRewrites() { argConverter.applyRewrites(mapping); - // Now that the ops have been erased, also erase dangling blocks. - eraseDanglingBlocks(); + // Commit all rewrites. + for (auto &rewrite : rewrites) + rewrite->commit(); } //===----------------------------------------------------------------------===// @@ -1093,8 +1267,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), unresolvedMaterializations.size(), replacements.size(), argReplacements.size(), - blockActions.size(), ignoredOps.size(), - rootUpdates.size()); + rewrites.size(), ignoredOps.size(), rootUpdates.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -1109,8 +1282,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { mapping.erase(replacedArg); argReplacements.resize(state.numArgReplacements); - // Undo any block actions. - undoBlockActions(state.numBlockActions); + // Undo any rewrites. + undoRewrites(state.numRewrites); // Reset any replaced operations and undo any saved mappings. for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) @@ -1149,76 +1322,11 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { operationsWithChangedResults.pop_back(); } -void ConversionPatternRewriterImpl::eraseDanglingBlocks() { - for (auto &action : blockActions) - if (action.kind == BlockActionKind::Erase) - delete action.block; -} - -void ConversionPatternRewriterImpl::undoBlockActions( - unsigned numActionsToKeep) { - for (auto &action : - llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { - switch (action.kind) { - // Delete the created block. - case BlockActionKind::Create: { - // Unlink all of the operations within this block, they will be deleted - // separately. - auto &blockOps = action.block->getOperations(); - while (!blockOps.empty()) - blockOps.remove(blockOps.begin()); - action.block->dropAllDefinedValueUses(); - action.block->erase(); - break; - } - // Put the block (owned by action) back into its original position. - case BlockActionKind::Erase: { - auto &blockList = action.originalPosition.region->getBlocks(); - Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock; - blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock) - : blockList.end()), - action.block); - break; - } - // Put the instructions from the destination block (owned by the action) - // back into the source block. - case BlockActionKind::Inline: { - Block *sourceBlock = action.inlineInfo.sourceBlock; - if (action.inlineInfo.firstInlinedInst) { - assert(action.inlineInfo.lastInlinedInst && "expected operation"); - sourceBlock->getOperations().splice( - sourceBlock->begin(), action.block->getOperations(), - Block::iterator(action.inlineInfo.firstInlinedInst), - ++Block::iterator(action.inlineInfo.lastInlinedInst)); - } - break; - } - // Move the block back to its original position. - case BlockActionKind::Move: { - Region *originalRegion = action.originalPosition.region; - Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock; - originalRegion->getBlocks().splice( - (insertBeforeBlock ? Region::iterator(insertBeforeBlock) - : originalRegion->end()), - action.block->getParent()->getBlocks(), action.block); - break; - } - // Merge back the block that was split out. - case BlockActionKind::Split: { - action.originalBlock->getOperations().splice( - action.originalBlock->end(), action.block->getOperations()); - action.block->dropAllDefinedValueUses(); - action.block->erase(); - break; - } - // Undo the type conversion. - case BlockActionKind::TypeConversion: { - argConverter.discardRewrites(action.block); - break; - } - } - } - blockActions.resize(numActionsToKeep); +void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { + for (auto &rewrite : + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) + rewrite->rollback(); + rewrites.resize(numRewritesToKeep); } LogicalResult ConversionPatternRewriterImpl::remapValues( @@ -1309,7 +1417,7 @@ FailureOr ConversionPatternRewriterImpl::convertBlockSignature( return failure(); if (Block *newBlock = *result) { if (newBlock != block) - blockActions.push_back(BlockAction::getTypeConversion(newBlock)); + appendRewrite(newBlock); } return result; } @@ -1410,28 +1518,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { Region *region = block->getParent(); Block *origNextBlock = block->getNextNode(); - blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock})); + appendRewrite(block, region, origNextBlock); } void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { if (!previous) { // This is a newly created block. - blockActions.push_back(BlockAction::getCreate(block)); + appendRewrite(block); return; } Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; - blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock})); + appendRewrite(block, previous, prevBlock); } void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, Block *continuation) { - blockActions.push_back(BlockAction::getSplit(continuation, block)); + appendRewrite(continuation, block); } void ConversionPatternRewriterImpl::notifyBlockBeingInlined( Block *block, Block *srcBlock, Block::iterator before) { - blockActions.push_back(BlockAction::getInline(block, srcBlock, before)); + appendRewrite(block, srcBlock, before); } void ConversionPatternRewriterImpl::notifyMatchFailure( @@ -1501,8 +1609,8 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { for (Operation &op : *block) eraseOp(&op); - // Unlink the block from its parent region. The block is kept in the block - // action and will be actually destroyed when rewrites are applied. This + // Unlink the block from its parent region. The block is kept in the rewrite + // object and will be actually destroyed when rewrites are applied. This // allows us to keep the operations in the block live and undo the removal by // re-inserting the block. block->getParent()->getBlocks().remove(block); @@ -1700,11 +1808,11 @@ class OperationLegalizer { RewriterState &curState); /// Legalizes the actions registered during the execution of a pattern. - LogicalResult legalizePatternBlockActions(Operation *op, - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - RewriterState &state, - RewriterState &newState); + LogicalResult + legalizePatternBlockRewrites(Operation *op, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState); LogicalResult legalizePatternCreatedOperations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState); @@ -1986,8 +2094,8 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, // Legalize each of the actions registered during application. RewriterState newState = impl.getCurrentState(); - if (failed(legalizePatternBlockActions(op, rewriter, impl, curState, - newState)) || + if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState, + newState)) || failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || failed(legalizePatternCreatedOperations(rewriter, impl, curState, newState))) { @@ -1998,7 +2106,7 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, return success(); } -LogicalResult OperationLegalizer::legalizePatternBlockActions( +LogicalResult OperationLegalizer::legalizePatternBlockRewrites( Operation *op, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState) { @@ -2006,22 +2114,22 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions( // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. - for (int i = state.numBlockActions, e = newState.numBlockActions; i != e; - ++i) { - auto &action = impl.blockActions[i]; - if (action.kind == BlockActionKind::TypeConversion || - action.kind == BlockActionKind::Erase) + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + BlockRewrite *rewrite = dyn_cast(impl.rewrites[i].get()); + if (!rewrite) + continue; + Block *block = rewrite->getBlock(); + if (isa(rewrite)) continue; // Only check blocks outside of the current operation. - Operation *parentOp = action.block->getParentOp(); - if (!parentOp || parentOp == op || action.block->getNumArguments() == 0) + Operation *parentOp = block->getParentOp(); + if (!parentOp || parentOp == op || block->getNumArguments() == 0) continue; // If the region of the block has a type converter, try to convert the block // directly. - if (auto *converter = - impl.argConverter.getConverter(action.block->getParent())) { - if (failed(impl.convertBlockSignature(action.block, converter))) { + if (auto *converter = impl.argConverter.getConverter(block->getParent())) { + if (failed(impl.convertBlockSignature(block, converter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); return failure(); @@ -2042,9 +2150,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions( // If this operation should be considered for re-legalization, try it. if (operationsToIgnore.insert(parentOp).second && failed(legalize(parentOp, rewriter))) { - LLVM_DEBUG(logFailure( - impl.logger, "operation '{0}'({1}) became illegal after block action", - parentOp->getName(), parentOp)); + LLVM_DEBUG(logFailure(impl.logger, + "operation '{0}'({1}) became illegal after rewrite", + parentOp->getName(), parentOp)); return failure(); } }