Skip to content

Commit 68264af

Browse files
committed
refactor to a class
1 parent 32f16be commit 68264af

File tree

3 files changed

+36
-17
lines changed

3 files changed

+36
-17
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
20462046
: refOperation(beforeOperationBase.getOperation().getRef()),
20472047
block((*refOperation)->getBlock()) {}
20482048

2049+
PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
2050+
: refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
2051+
20492052
void PyInsertionPoint::insert(PyOperationBase &operationBase) {
20502053
PyOperation &operation = operationBase.getOperation();
20512054
if (operation.isAttached())

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,8 @@ class PyInsertionPoint {
841841
PyInsertionPoint(const PyBlock &block);
842842
/// Creates an insertion point positioned before a reference operation.
843843
PyInsertionPoint(PyOperationBase &beforeOperationBase);
844+
/// Creates an insertion point positioned before a reference operation.
845+
PyInsertionPoint(PyOperationRef beforeOperationRef);
844846

845847
/// Shortcut to create an insertion point at the beginning of the block.
846848
static PyInsertionPoint atBlockBegin(PyBlock &block);

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,31 @@ using namespace mlir::python;
2626

2727
namespace {
2828

29+
class PyPatternRewriter {
30+
public:
31+
PyPatternRewriter(MlirPatternRewriter rewriter)
32+
: rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
33+
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
34+
35+
PyInsertionPoint getInsertionPoint() const {
36+
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
37+
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
38+
39+
if (mlirOperationIsNull(op)) {
40+
MlirOperation owner = mlirBlockGetParentOperation(block);
41+
auto parent = PyOperation::forOperation(ctx, owner);
42+
return PyInsertionPoint(PyBlock(parent, block));
43+
}
44+
45+
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
46+
}
47+
48+
private:
49+
MlirPatternRewriter rewriter [[maybe_unused]];
50+
MlirRewriterBase base;
51+
PyMlirContextRef ctx;
52+
};
53+
2954
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3055
static nb::object objectFromPDLValue(MlirPDLValue value) {
3156
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -84,7 +109,8 @@ class PyPDLPatternModule {
84109
void *userData) -> MlirLogicalResult {
85110
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
86111
return logicalResultFromObject(
87-
f(rewriter, results, objectsFromPDLValues(nValues, values)));
112+
f(PyPatternRewriter(rewriter), results,
113+
objectsFromPDLValues(nValues, values)));
88114
},
89115
fn.ptr());
90116
}
@@ -98,7 +124,8 @@ class PyPDLPatternModule {
98124
void *userData) -> MlirLogicalResult {
99125
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
100126
return logicalResultFromObject(
101-
f(rewriter, results, objectsFromPDLValues(nValues, values)));
127+
f(PyPatternRewriter(rewriter), results,
128+
objectsFromPDLValues(nValues, values)));
102129
},
103130
fn.ptr());
104131
}
@@ -143,21 +170,8 @@ class PyFrozenRewritePatternSet {
143170

144171
/// Create the `mlir.rewrite` here.
145172
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
146-
nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
147-
.def("ip", [](MlirPatternRewriter rewriter) {
148-
MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
149-
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
150-
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
151-
MlirOperation owner = mlirBlockGetParentOperation(block);
152-
auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
153-
->getRef();
154-
if (mlirOperationIsNull(op)) {
155-
auto parent = PyOperation::forOperation(ctx, owner);
156-
return PyInsertionPoint(PyBlock(parent, block));
157-
}
158-
159-
return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
160-
});
173+
nb::class_<PyPatternRewriter>(m, "PyPatternRewriter")
174+
.def("ip", &PyPatternRewriter::getInsertionPoint);
161175
//----------------------------------------------------------------------------
162176
// Mapping of the PDLResultList and PDLModule
163177
//----------------------------------------------------------------------------

0 commit comments

Comments
 (0)