@@ -26,6 +26,31 @@ using namespace mlir::python;
26
26
27
27
namespace {
28
28
29
+ class PyPatternRewriter {
30
+ public:
31
+ PyPatternRewriter (MlirPatternRewriter rewriter)
32
+ : rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
33
+ ctx (PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
34
+
35
+ PyInsertionPoint getInsertionPoint () const {
36
+ MlirBlock block = mlirRewriterBaseGetInsertionBlock (base);
37
+ MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion (base);
38
+
39
+ if (mlirOperationIsNull (op)) {
40
+ MlirOperation owner = mlirBlockGetParentOperation (block);
41
+ auto parent = PyOperation::forOperation (ctx, owner);
42
+ return PyInsertionPoint (PyBlock (parent, block));
43
+ }
44
+
45
+ return PyInsertionPoint (PyOperation::forOperation (ctx, op));
46
+ }
47
+
48
+ private:
49
+ MlirPatternRewriter rewriter [[maybe_unused]];
50
+ MlirRewriterBase base;
51
+ PyMlirContextRef ctx;
52
+ };
53
+
29
54
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
30
55
static nb::object objectFromPDLValue (MlirPDLValue value) {
31
56
if (MlirValue v = mlirPDLValueAsValue (value); !mlirValueIsNull (v))
@@ -84,7 +109,8 @@ class PyPDLPatternModule {
84
109
void *userData) -> MlirLogicalResult {
85
110
nb::handle f = nb::handle (static_cast <PyObject *>(userData));
86
111
return logicalResultFromObject (
87
- f (rewriter, results, objectsFromPDLValues (nValues, values)));
112
+ f (PyPatternRewriter (rewriter), results,
113
+ objectsFromPDLValues (nValues, values)));
88
114
},
89
115
fn.ptr ());
90
116
}
@@ -98,7 +124,8 @@ class PyPDLPatternModule {
98
124
void *userData) -> MlirLogicalResult {
99
125
nb::handle f = nb::handle (static_cast <PyObject *>(userData));
100
126
return logicalResultFromObject (
101
- f (rewriter, results, objectsFromPDLValues (nValues, values)));
127
+ f (PyPatternRewriter (rewriter), results,
128
+ objectsFromPDLValues (nValues, values)));
102
129
},
103
130
fn.ptr ());
104
131
}
@@ -143,21 +170,8 @@ class PyFrozenRewritePatternSet {
143
170
144
171
// / Create the `mlir.rewrite` here.
145
172
void mlir::python::populateRewriteSubmodule (nb::module_ &m) {
146
- nb::class_<MlirPatternRewriter>(m, " PatternRewriter" )
147
- .def (" ip" , [](MlirPatternRewriter rewriter) {
148
- MlirRewriterBase base = mlirPatternRewriterAsBase (rewriter);
149
- MlirBlock block = mlirRewriterBaseGetInsertionBlock (base);
150
- MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion (base);
151
- MlirOperation owner = mlirBlockGetParentOperation (block);
152
- auto ctx = PyMlirContext::forContext (mlirRewriterBaseGetContext (base))
153
- ->getRef ();
154
- if (mlirOperationIsNull (op)) {
155
- auto parent = PyOperation::forOperation (ctx, owner);
156
- return PyInsertionPoint (PyBlock (parent, block));
157
- }
158
-
159
- return PyInsertionPoint (*PyOperation::forOperation (ctx, op).get ());
160
- });
173
+ nb::class_<PyPatternRewriter>(m, " PyPatternRewriter" )
174
+ .def (" ip" , &PyPatternRewriter::getInsertionPoint);
161
175
// ----------------------------------------------------------------------------
162
176
// Mapping of the PDLResultList and PDLModule
163
177
// ----------------------------------------------------------------------------
0 commit comments