Skip to content

Commit

Permalink
[mlir] DialectConversion: avoid double-free when rolling back op crea…
Browse files Browse the repository at this point in the history
…tion

Dialect conversion infrastructure may roll back op creation by erasing the
operations in the reverse order of their creation. While this guarantees uses
of values will be deleted before their definitions, this does not guarantee
that a parent operation will not be deleted before its child. (This may happen
in case of block inlining or if child operations, such as terminators, are
created in the parent's `build` function before the parent itself.) Handle the
parent/child relationship between ops by removing all child ops from the blocks
before erasing the parent. The child ops remain live, detached from a block,
and will be safely destroyed in their turn, which may come later than that of
the parent.

Differential Revision: https://reviews.llvm.org/D80134
  • Loading branch information
ftynse committed May 20, 2020
1 parent a655144 commit 5d5df06
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 4 deletions.
20 changes: 18 additions & 2 deletions mlir/lib/Transforms/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,22 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
ignoredOps.size(), rootUpdates.size());
}

/// Detach any operations nested in the given operation from their parent
/// blocks, and erase the given operation. This can be used when the nested
/// operations are scheduled for erasure themselves, so deleting the regions of
/// the given operation together with their content would result in double-free.
/// This happens, for example, when rolling back op creation in the reverse
/// order and if the nested ops were created before the parent op. This function
/// does not need to collect nested ops recursively because it is expected to
/// also be called for each nested op when it is about to be deleted.
static void detachNestedAndErase(Operation *op) {
for (Region &region : op->getRegions())
for (Block &block : region.getBlocks())
while (!block.getOperations().empty())
block.getOperations().remove(block.getOperations().begin());
op->erase();
}

void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Reset any operations that were updated in place.
for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
Expand All @@ -686,7 +702,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {

// Pop all of the newly created operations.
while (createdOps.size() != state.numCreatedOps) {
createdOps.back()->erase();
detachNestedAndErase(createdOps.back());
createdOps.pop_back();
}

Expand Down Expand Up @@ -746,7 +762,7 @@ void ConversionPatternRewriterImpl::discardRewrites() {

// Remove any newly created ops.
for (auto *op : llvm::reverse(createdOps))
op->erase();
detachNestedAndErase(op);
}

void ConversionPatternRewriterImpl::applyRewrites() {
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,18 @@ func @undo_block_arg_replace() {
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}

// -----

// The op in this function is attempted to be rewritten to another illegal op
// with an attached region containing an invalid terminator. The terminator is
// created before the parent op. The deletion should not crash when deleting
// created ops in the inverse order, i.e. deleting the parent op and then the
// child op.
// CHECK-LABEL: @undo_child_created_before_parent
func @undo_child_created_before_parent() {
// expected-remark@+1 {{is not legalizable}}
"test.illegal_op_with_region_anchor"() : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
16 changes: 16 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,22 @@ def LegalOpA : TEST_Op<"legal_op_a">,
Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;

// Check that the conversion infrastructure can properly undo the creation of
// operations where an operation was created before its parent, in this case,
// in the parent's builder.
def IllegalOpTerminator : TEST_Op<"illegal_op_terminator", [Terminator]>;
def IllegalOpWithRegion : TEST_Op<"illegal_op_with_region"> {
let skipDefaultBuilders = 1;
let builders = [OpBuilder<"OpBuilder &builder, OperationState &state",
[{ Region *bodyRegion = state.addRegion();
OpBuilder::InsertionGuard g(builder);
Block *body = builder.createBlock(bodyRegion);
builder.setInsertionPointToEnd(body);
builder.create<IllegalOpTerminator>(state.location);
}]>];
}
def IllegalOpWithRegionAnchor : TEST_Op<"illegal_op_with_region_anchor">;

// Check that smaller pattern depths are chosen, i.e. prioritize more direct
// mappings.
def : Pat<(ILLegalOpA), (LegalOpA Test_LegalizerEnum_Success)>;
Expand Down
16 changes: 14 additions & 2 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,18 @@ struct TestBoundedRecursiveRewrite
/// The conversion target handles bounding the recursion of this pattern.
bool hasBoundedRewriteRecursion() const final { return true; }
};

struct TestNestedOpCreationUndoRewrite
: public OpRewritePattern<IllegalOpWithRegionAnchor> {
using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;

LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
PatternRewriter &rewriter) const final {
// rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
return success();
};
};
} // namespace

namespace {
Expand Down Expand Up @@ -498,8 +510,8 @@ struct TestLegalizePatternDriver
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite>(
&getContext());
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite>(&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
Expand Down

0 comments on commit 5d5df06

Please sign in to comment.