From 3996584db555d3e3b62d658095072f8d44e7d79d Mon Sep 17 00:00:00 2001 From: makslevental Date: Fri, 5 Sep 2025 11:22:04 -0700 Subject: [PATCH] [MLIR][Python] bind InsertionPointAfter --- mlir/lib/Bindings/Python/IRCore.cpp | 15 +++++++- mlir/lib/Bindings/Python/IRModule.h | 5 ++- mlir/test/python/ir/insertion_point.py | 50 ++++++++++++++++++++++++-- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index bf4950fc1a070..ba00ef712084b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2019,7 +2019,7 @@ PyOpView::PyOpView(const nb::object &operationObject) // PyInsertionPoint. //------------------------------------------------------------------------------ -PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} +PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {} PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) : refOperation(beforeOperationBase.getOperation().getRef()), @@ -2073,6 +2073,17 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { return PyInsertionPoint{block, std::move(terminatorOpRef)}; } +PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) { + PyOperation &operation = op.getOperation(); + PyBlock block = operation.getBlock(); + MlirOperation nextOperation = mlirOperationGetNextInBlock(operation); + if (mlirOperationIsNull(nextOperation)) + return PyInsertionPoint(block); + PyOperationRef nextOpRef = PyOperation::forOperation( + block.getParentOperation()->getContext(), nextOperation); + return PyInsertionPoint{block, std::move(nextOpRef)}; +} + nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { return PyThreadContextEntry::pushInsertionPoint(insertPoint); } @@ -3861,6 +3872,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("block"), "Inserts at the beginning of the block.") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, nb::arg("block"), "Inserts before the block terminator.") + .def_static("after", &PyInsertionPoint::after, nb::arg("operation"), + "Inserts after the operation.") .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), "Inserts an operation.") .def_prop_ro( diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 0cc0459ebc9a0..1d1ff29533f98 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -821,7 +821,7 @@ class PyInsertionPoint { public: /// Creates an insertion point positioned after the last operation in the /// block, but still inside the block. - PyInsertionPoint(PyBlock &block); + PyInsertionPoint(const PyBlock &block); /// Creates an insertion point positioned before a reference operation. PyInsertionPoint(PyOperationBase &beforeOperationBase); @@ -829,6 +829,9 @@ class PyInsertionPoint { static PyInsertionPoint atBlockBegin(PyBlock &block); /// Shortcut to create an insertion point before the block terminator. static PyInsertionPoint atBlockTerminator(PyBlock &block); + /// Shortcut to create an insertion point to the node after the specified + /// operation. + static PyInsertionPoint after(PyOperationBase &op); /// Inserts an operation. void insert(PyOperationBase &operationBase); diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py index 5eb861a2c0891..f48beb25f04b2 100644 --- a/mlir/test/python/ir/insertion_point.py +++ b/mlir/test/python/ir/insertion_point.py @@ -63,6 +63,34 @@ def test_insert_before_operation(): run(test_insert_before_operation) +# CHECK-LABEL: TEST: test_insert_after_operation +def test_insert_after_operation(): + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.parse( + r""" + func.func @foo() -> () { + "custom.op1"() : () -> () + "custom.op2"() : () -> () + } + """ + ) + entry_block = module.body.operations[0].regions[0].blocks[0] + custom_op1 = entry_block.operations[0] + custom_op2 = entry_block.operations[1] + InsertionPoint.after(custom_op1).insert(Operation.create("custom.op3")) + InsertionPoint.after(custom_op2).insert(Operation.create("custom.op4")) + # CHECK: "custom.op1" + # CHECK: "custom.op3" + # CHECK: "custom.op2" + # CHECK: "custom.op4" + module.operation.print() + + +run(test_insert_after_operation) + + # CHECK-LABEL: TEST: test_insert_at_block_begin def test_insert_at_block_begin(): ctx = Context() @@ -111,14 +139,24 @@ def test_insert_at_terminator(): """ ) entry_block = module.body.operations[0].regions[0].blocks[0] + return_op = entry_block.operations[1] ip = InsertionPoint.at_block_terminator(entry_block) assert ip.block == entry_block - assert ip.ref_operation == entry_block.operations[1] - ip.insert(Operation.create("custom.op2")) + assert ip.ref_operation == return_op + custom_op2 = Operation.create("custom.op2") + ip.insert(custom_op2) + InsertionPoint.after(custom_op2).insert(Operation.create("custom.op3")) # CHECK: "custom.op1" # CHECK: "custom.op2" + # CHECK: "custom.op3" module.operation.print() + try: + InsertionPoint.after(return_op).insert(Operation.create("custom.op4")) + except IndexError as e: + # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator. + print(f"ERROR: {e}") + run(test_insert_at_terminator) @@ -187,10 +225,16 @@ def test_insertion_point_context(): with InsertionPoint(entry_block): Operation.create("custom.op2") with InsertionPoint.at_block_begin(entry_block): - Operation.create("custom.opa") + custom_opa = Operation.create("custom.opa") Operation.create("custom.opb") Operation.create("custom.op3") + with InsertionPoint.after(custom_opa): + Operation.create("custom.op4") + Operation.create("custom.op5") + # CHECK: "custom.opa" + # CHECK: "custom.op4" + # CHECK: "custom.op5" # CHECK: "custom.opb" # CHECK: "custom.op1" # CHECK: "custom.op2"