-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] bind InsertionPointAfter #157156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
0b9ec75
to
6e72d75
Compare
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesFull diff: https://github.com/llvm/llvm-project/pull/157156.diff 3 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index bf4950fc1a070..f6f3abf9819e9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2019,7 +2019,7 @@ PyOpView::PyOpView(const nb::object &operationObject)
// PyInsertionPoint.
//------------------------------------------------------------------------------
-PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
+PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
: refOperation(beforeOperationBase.getOperation().getRef()),
@@ -2073,6 +2073,17 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
return PyInsertionPoint{block, std::move(terminatorOpRef)};
}
+PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
+ PyOperation &operation = op.getOperation();
+ MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
+ if (mlirOperationIsNull(nextOperation))
+ return PyInsertionPoint{operation.getBlock()};
+ PyOperationRef nextOpRef =
+ PyOperation::forOperation(operation.getContext(), nextOperation);
+ return PyInsertionPoint{nextOpRef->getOperation().getBlock(),
+ std::move(nextOpRef)};
+}
+
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
@@ -3861,6 +3872,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("block"), "Inserts at the beginning of the block.")
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
nb::arg("block"), "Inserts before the block terminator.")
+ .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+ "Inserts after the operation.")
.def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
"Inserts an operation.")
.def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 0cc0459ebc9a0..1d1ff29533f98 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -821,7 +821,7 @@ class PyInsertionPoint {
public:
/// Creates an insertion point positioned after the last operation in the
/// block, but still inside the block.
- PyInsertionPoint(PyBlock &block);
+ PyInsertionPoint(const PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationBase &beforeOperationBase);
@@ -829,6 +829,9 @@ class PyInsertionPoint {
static PyInsertionPoint atBlockBegin(PyBlock &block);
/// Shortcut to create an insertion point before the block terminator.
static PyInsertionPoint atBlockTerminator(PyBlock &block);
+ /// Shortcut to create an insertion point to the node after the specified
+ /// operation.
+ static PyInsertionPoint after(PyOperationBase &op);
/// Inserts an operation.
void insert(PyOperationBase &operationBase);
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 5eb861a2c0891..a0296227cb050 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -63,6 +63,34 @@ def test_insert_before_operation():
run(test_insert_before_operation)
+# CHECK-LABEL: TEST: test_insert_after_operation
+def test_insert_after_operation():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
+ func.func @foo() -> () {
+ "custom.op1"() : () -> ()
+ "custom.op2"() : () -> ()
+ }
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ custom_op1 = entry_block.operations[0]
+ custom_op2 = entry_block.operations[1]
+ InsertionPoint.after(custom_op1).insert(Operation.create("custom.op3"))
+ InsertionPoint.after(custom_op2).insert(Operation.create("custom.op4"))
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op3"
+ # CHECK: "custom.op2"
+ # CHECK: "custom.op4"
+ module.operation.print()
+
+
+run(test_insert_after_operation)
+
+
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
ctx = Context()
|
57eb1d6
to
8aae86c
Compare
8aae86c
to
b503ef0
Compare
Flyby comment: does |
Yea it returns an |
b503ef0
to
3996584
Compare
Added a test to demo |
No description provided.