Skip to content

Commit

Permalink
[mlir][DialectConversion] Remove usage of std::distance to track posi…
Browse files Browse the repository at this point in the history
…tion.

Remove use of iterator::difference_type to know where to insert a
moved or erased block during undo actions.

Differential Revision: https://reviews.llvm.org/D85066
  • Loading branch information
MaheshRavishankar committed Aug 3, 2020
1 parent e888886 commit 32f3a9a
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions mlir/lib/Transforms/DialectConversion.cpp
Expand Up @@ -611,12 +611,11 @@ enum class BlockActionKind {
TypeConversion
};

/// Original position of the given block in its parent region. We cannot use
/// a region iterator because it could have been invalidated by other region
/// operations since the position was stored.
/// Original position of the given block in its parent region. During undo
/// actions, the block needs to be placed after `insertAfterBlock`.
struct BlockPosition {
Region *region;
Region::iterator::difference_type position;
Block *insertAfterBlock;
};

/// Information needed to undo the merge actions.
Expand All @@ -634,16 +633,16 @@ struct BlockAction {
static BlockAction getCreate(Block *block) {
return {BlockActionKind::Create, block, {}};
}
static BlockAction getErase(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Erase, block, {originalPos}};
static BlockAction getErase(Block *block, BlockPosition originalPosition) {
return {BlockActionKind::Erase, block, {originalPosition}};
}
static BlockAction getMerge(Block *block, Block *sourceBlock) {
BlockAction action{BlockActionKind::Merge, block, {}};
action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
return action;
}
static BlockAction getMove(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Move, block, {originalPos}};
static BlockAction getMove(Block *block, BlockPosition originalPosition) {
return {BlockActionKind::Move, block, {originalPosition}};
}
static BlockAction getSplit(Block *block, Block *originalBlock) {
BlockAction action{BlockActionKind::Split, block, {}};
Expand Down Expand Up @@ -988,9 +987,11 @@ void ConversionPatternRewriterImpl::undoBlockActions(
// Put the block (owned by action) back into its original position.
case BlockActionKind::Erase: {
auto &blockList = action.originalPosition.region->getBlocks();
blockList.insert(
std::next(blockList.begin(), action.originalPosition.position),
action.block);
Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
blockList.insert((insertAfterBlock
? std::next(Region::iterator(insertAfterBlock))
: blockList.end()),
action.block);
break;
}
// Split the block at the position which was originally the end of the
Expand All @@ -1010,8 +1011,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
// Move the block back to its original position.
case BlockActionKind::Move: {
Region *originalRegion = action.originalPosition.region;
Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
originalRegion->getBlocks().splice(
std::next(originalRegion->begin(), action.originalPosition.position),
(insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
: originalRegion->end()),
action.block->getParent()->getBlocks(), action.block);
break;
}
Expand Down Expand Up @@ -1189,8 +1192,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,

void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
Region *region = block->getParent();
auto position = std::distance(region->begin(), Region::iterator(block));
blockActions.push_back(BlockAction::getErase(block, {region, position}));
Block *origPrevBlock = block->getPrevNode();
blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
}

void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
Expand All @@ -1209,10 +1212,12 @@ void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,

void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
Region &region, Region &parent, Region::iterator before) {
Block *origPrevBlock = nullptr;
for (auto &pair : llvm::enumerate(region)) {
Block &block = pair.value();
Region::iterator::difference_type position = pair.index();
blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
blockActions.push_back(
BlockAction::getMove(&block, {&region, origPrevBlock}));
origPrevBlock = █
}
}

Expand Down

0 comments on commit 32f3a9a

Please sign in to comment.