diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index bd972df271ed6..52067f1b95ad4 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -548,7 +548,13 @@ class MoveBlockRewrite : public BlockRewrite { // Move the block back to its original position. Region::iterator before = insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); - region->getBlocks().splice(before, block->getParent()->getBlocks(), block); + if (Region *currentParent = block->getParent()) { + // Block is still in a region, use cheap splice to move it back. + region->getBlocks().splice(before, currentParent->getBlocks(), block); + return; + } + // Block was orphaned by a prior rollback, can't splice. + region->getBlocks().insert(before, block); } private: diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir index 4bcca6b7e5228..f6569201842b7 100644 --- a/mlir/test/Transforms/test-legalizer-rollback.mlir +++ b/mlir/test/Transforms/test-legalizer-rollback.mlir @@ -138,6 +138,18 @@ func.func @test_properties_rollback() { // ----- +// CHECK-LABEL: func @test_undo_block_move_detached +func.func @test_undo_block_move_detached() { + // expected-remark @below{{op 'test.undo_detached_block_move' is not legalizable}} + "test.undo_detached_block_move"() ({ + ^bb0(%arg0: i64): + "test.return"() : () -> () + }) : () -> () + "test.return"() : () -> () +} + +// ----- + // expected-remark@+1 {{applyPartialConversion failed}} builtin.module { // Test that region cloning can be properly undone. diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 6c564a6592c11..6c44ace831e96 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1042,6 +1042,38 @@ struct TestUndoPropertiesModification : public ConversionPattern { } }; +/// A pattern that tests the undo mechanism for a block move if the block was +/// moved to a detached region. The block is first moved to a detached region +/// and then a new operation is created with that region. During rollback, first +/// the `CreateOperationRewrite` is rolled back, causing the block to be +/// orphaned, i.e., removed from the region. Only then the `MoveBlockRewrite` is +/// rolled back, which now can't access the region anymore. The test ensures +/// that the rollback still works and doesn't try to access the orphaned block's +/// containing region, leading to segfault. +struct TestUndoMoveDetachedBlock : public ConversionPattern { + TestUndoMoveDetachedBlock(MLIRContext *ctx) + : ConversionPattern("test.undo_detached_block_move", /*benefit=*/1, ctx) { + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + if (op->getNumRegions() != 1) + return failure(); + // Create an illegal operation to trigger rollback. + OperationState state(op->getLoc(), "test.illegal_op_created_after_move", + operands, op->getResultTypes(), {}, BlockRange()); + // Create detached region. + Region *newRegion = state.addRegion(); + // Move blocks to the still detached region + rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, + newRegion->begin()); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Type-Conversion Rewrite Testing //===----------------------------------------------------------------------===// @@ -1548,7 +1580,7 @@ struct TestLegalizePatternDriver TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, - TestUndoPropertiesModification, TestEraseOp, + TestUndoPropertiesModification, TestUndoMoveDetachedBlock, TestEraseOp, TestReplaceWithValidProducer, TestReplaceWithValidConsumer, TestRepetitive1ToNConsumer>(&getContext()); patterns.add