Skip to content

Commit

Permalink
[mlir] DialectConversion: support erasing blocks
Browse files Browse the repository at this point in the history
PatternRewriter has support for erasing a Block from its parent region, but
this feature has not been implemented for ConversionPatternRewriter that needs
to keep track of and be able to undo block actions. Introduce support for
undoing block erasure in the ConversionPatternRewriter by marking all the ops
it contains for erasure and by detaching the block from its parent region. The
detached block is stored in the action description and is not actually deleted
until the rewrites are applied.

Differential Revision: https://reviews.llvm.org/D80135
  • Loading branch information
ftynse committed May 20, 2020
1 parent 5d5df06 commit df48026
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
58 changes: 52 additions & 6 deletions mlir/lib/Transforms/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ struct ConversionPatternRewriterImpl {

/// The kind of the block action performed during the rewrite. Actions can be
/// undone if the conversion fails.
enum class BlockActionKind { Create, Move, Split, TypeConversion };
enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };

/// Original position of the given block in its parent region. We cannot use
/// a region iterator because it could have been invalidated by other region
Expand All @@ -525,6 +525,9 @@ struct ConversionPatternRewriterImpl {
static BlockAction getCreate(Block *block) {
return {BlockActionKind::Create, block, {}};
}
static BlockAction getErase(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Erase, block, {originalPos}};
}
static BlockAction getMove(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Move, block, {originalPos}};
}
Expand All @@ -544,9 +547,9 @@ struct ConversionPatternRewriterImpl {
Block *block;

union {
// In use if kind == BlockActionKind::Move and contains a pointer to the
// region that originally contained the block as well as the position of
// the block in that region.
// In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
// contains a pointer to the region that originally contained the block as
// well as the position of the block in that region.
BlockPosition originalPosition;
// In use if kind == BlockActionKind::Split and contains a pointer to the
// block that was split into two parts.
Expand All @@ -564,6 +567,10 @@ struct ConversionPatternRewriterImpl {
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);

/// Erase any blocks that were unlinked from their regions and stored in block
/// actions.
void eraseDanglingBlocks();

/// Undo the block actions (motions, splits) one by one in reverse order until
/// "numActionsToKeep" actions remains.
void undoBlockActions(unsigned numActionsToKeep = 0);
Expand All @@ -587,6 +594,9 @@ struct ConversionPatternRewriterImpl {
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues);

/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);

/// Notifies that a block was created.
void notifyCreatedBlock(Block *block);

Expand Down Expand Up @@ -711,6 +721,14 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
ignoredOps.pop_back();
}

void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
for (auto &action : blockActions) {
if (action.kind != BlockActionKind::Erase)
continue;
delete action.block;
}
}

void ConversionPatternRewriterImpl::undoBlockActions(
unsigned numActionsToKeep) {
for (auto &action :
Expand All @@ -727,6 +745,14 @@ void ConversionPatternRewriterImpl::undoBlockActions(
action.block->erase();
break;
}
// Put the block (owned by action) back into its original position.
case BlockActionKind::Erase: {
auto &blockList = action.originalPosition.region->getBlocks();
blockList.insert(
std::next(blockList.begin(), action.originalPosition.position),
action.block);
break;
}
// Move the block back to its original position.
case BlockActionKind::Move: {
Region *originalRegion = action.originalPosition.region;
Expand Down Expand Up @@ -806,6 +832,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
repl.op->erase();

argConverter.applyRewrites(mapping);

// Now that the ops have been erased, also erase dangling blocks.
eraseDanglingBlocks();
}

LogicalResult
Expand Down Expand Up @@ -853,6 +882,12 @@ void ConversionPatternRewriterImpl::replaceOp(Operation *op,
markNestedOpsIgnored(op);
}

void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
Region *region = block->getParent();
auto position = std::distance(region->begin(), Region::iterator(block));
blockActions.push_back(BlockAction::getErase(block, {region, position}));
}

void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
blockActions.push_back(BlockAction::getCreate(block));
}
Expand Down Expand Up @@ -942,7 +977,17 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}

void ConversionPatternRewriter::eraseBlock(Block *block) {
llvm_unreachable("erasing blocks for dialect conversion not implemented");
impl->notifyBlockIsBeingErased(block);

// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);

// Unlink the block from its parent region. The block is kept in the block
// action and will be actually destroyed when rewrites are applied. This
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
block->getParent()->getBlocks().remove(block);
}

/// Apply a signature conversion to the entry block of the given region.
Expand Down Expand Up @@ -1334,7 +1379,8 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
i != e; ++i) {
auto &action = rewriterImpl.blockActions[i];
if (action.kind ==
ConversionPatternRewriterImpl::BlockActionKind::TypeConversion)
ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
continue;

// Convert the block signature.
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,27 @@ func @undo_block_arg_replace() {

// -----

// The op in this function is rewritten to itself (and thus remains illegal) by
// a pattern that removes its second block after adding an operation into it.
// Check that we can undo block removal succesfully.
// CHECK-LABEL: @undo_block_erase
func @undo_block_erase() {
// CHECK: test.undo_block_erase
"test.undo_block_erase"() ({
// expected-remark@-1 {{not legalizable}}
// CHECK: "unregistered.return"()[^[[BB:.*]]]
"unregistered.return"()[^bb1] : () -> ()
// expected-remark@-1 {{not legalizable}}
// CHECK: ^[[BB]]
^bb1:
// CHECK: unregistered.return
"unregistered.return"() : () -> ()
// expected-remark@-1 {{not legalizable}}
}) : () -> ()
}

// -----

// The op in this function is attempted to be rewritten to another illegal op
// with an attached region containing an invalid terminator. The terminator is
// created before the parent op. The deletion should not crash when deleting
Expand Down
33 changes: 25 additions & 8 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,23 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};

/// A rewrite pattern that tests the undo mechanism when erasing a block.
struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase(MLIRContext *ctx)
: ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Block *secondBlock = &*std::next(op->getRegion(0).begin());
rewriter.setInsertionPointToStart(secondBlock);
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
rewriter.eraseBlock(secondBlock);
rewriter.updateRootInPlace(op, [] {});
return success();
}
};

//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing

Expand Down Expand Up @@ -504,14 +521,14 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockArgReplace, TestPassthroughInvalidOp,
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite>(&getContext());
patterns.insert<
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
TestPassthroughInvalidOp, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite>(&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
Expand Down

0 comments on commit df48026

Please sign in to comment.