Skip to content

Commit 7aec3f2

Browse files
[MLIR][Python] Support Python-defined rewrite patterns (#162699)
This PR adds support for defining custom **`RewritePattern`** implementations directly in the Python bindings. Previously, users could define similar patterns using the PDL dialect’s bindings. However, for more complex patterns, this often required writing multiple Python callbacks as PDL native constraints or rewrite functions, which made the overall logic less intuitive—though it could be more performant than a pure Python implementation (especially for simple patterns). With this change, we introduce an additional, straightforward way to define patterns purely in Python, complementing the existing PDL-based approach. ### Example ```python def to_muli(op, rewriter): with rewriter.ip: new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location) rewriter.replace_op(op, new_op.owner) with Context(): patterns = RewritePatternSet() patterns.add(arith.AddIOp, to_muli) # a pattern that rewrites arith.addi to arith.muli frozen = patterns.freeze() module = ... apply_patterns_and_fold_greedily(module, frozen) ``` --------- Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
1 parent 0b462f6 commit 7aec3f2

File tree

5 files changed

+318
-21
lines changed

5 files changed

+318
-21
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
3838
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
3939
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
4040
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
41+
DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
4142

4243
//===----------------------------------------------------------------------===//
4344
/// RewriterBase API inherited from OpBuilder
@@ -302,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
302303
/// FrozenRewritePatternSet API
303304
//===----------------------------------------------------------------------===//
304305

306+
/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet.
307+
/// Note that the ownership of the input set is transferred into the frozen set
308+
/// after this call.
305309
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
306-
mlirFreezeRewritePattern(MlirRewritePatternSet op);
310+
mlirFreezeRewritePattern(MlirRewritePatternSet set);
307311

312+
/// Destroy the given MlirFrozenRewritePatternSet.
308313
MLIR_CAPI_EXPORTED void
309-
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
314+
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set);
310315

311316
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
312317
MlirOperation op, MlirFrozenRewritePatternSet patterns,
@@ -324,6 +329,51 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
324329
MLIR_CAPI_EXPORTED MlirRewriterBase
325330
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
326331

332+
//===----------------------------------------------------------------------===//
333+
/// RewritePattern API
334+
//===----------------------------------------------------------------------===//
335+
336+
/// Callbacks to construct a rewrite pattern.
337+
typedef struct {
338+
/// Optional constructor for the user data.
339+
/// Set to nullptr to disable it.
340+
void (*construct)(void *userData);
341+
/// Optional destructor for the user data.
342+
/// Set to nullptr to disable it.
343+
void (*destruct)(void *userData);
344+
/// The callback function to match against code rooted at the specified
345+
/// operation, and perform the rewrite if the match is successful,
346+
/// corresponding to RewritePattern::matchAndRewrite.
347+
MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
348+
MlirOperation op,
349+
MlirPatternRewriter rewriter,
350+
void *userData);
351+
} MlirRewritePatternCallbacks;
352+
353+
/// Create a rewrite pattern that matches the operation
354+
/// with the given rootName, corresponding to mlir::OpRewritePattern.
355+
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
356+
MlirStringRef rootName, unsigned benefit, MlirContext context,
357+
MlirRewritePatternCallbacks callbacks, void *userData,
358+
size_t nGeneratedNames, MlirStringRef *generatedNames);
359+
360+
//===----------------------------------------------------------------------===//
361+
/// RewritePatternSet API
362+
//===----------------------------------------------------------------------===//
363+
364+
/// Create an empty MlirRewritePatternSet.
365+
MLIR_CAPI_EXPORTED MlirRewritePatternSet
366+
mlirRewritePatternSetCreate(MlirContext context);
367+
368+
/// Destruct the given MlirRewritePatternSet.
369+
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
370+
371+
/// Add the given MlirRewritePattern into a MlirRewritePatternSet.
372+
/// Note that the ownership of the pattern is transferred to the set after this
373+
/// call.
374+
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
375+
MlirRewritePattern pattern);
376+
327377
//===----------------------------------------------------------------------===//
328378
/// PDLPatternModule API
329379
//===----------------------------------------------------------------------===//

mlir/include/mlir/CAPI/Rewrite.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,7 @@
2020
#include "mlir/IR/PatternMatch.h"
2121

2222
DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase)
23+
DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern)
24+
DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
2325

2426
#endif // MLIR_CAPIREWRITER_H

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ class PyPatternRewriter {
4545
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
4646
}
4747

48+
void replaceOp(MlirOperation op, MlirOperation newOp) {
49+
mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
50+
}
51+
52+
void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
53+
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
54+
}
55+
56+
void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
57+
4858
private:
4959
MlirRewriterBase base;
5060
PyMlirContextRef ctx;
@@ -165,13 +175,116 @@ class PyFrozenRewritePatternSet {
165175
MlirFrozenRewritePatternSet set;
166176
};
167177

178+
class PyRewritePatternSet {
179+
public:
180+
PyRewritePatternSet(MlirContext ctx)
181+
: set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
182+
~PyRewritePatternSet() {
183+
if (set.ptr)
184+
mlirRewritePatternSetDestroy(set);
185+
}
186+
187+
void add(MlirStringRef rootName, unsigned benefit,
188+
const nb::callable &matchAndRewrite) {
189+
MlirRewritePatternCallbacks callbacks;
190+
callbacks.construct = [](void *userData) {
191+
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
192+
};
193+
callbacks.destruct = [](void *userData) {
194+
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
195+
};
196+
callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
197+
MlirPatternRewriter rewriter,
198+
void *userData) -> MlirLogicalResult {
199+
nb::handle f(static_cast<PyObject *>(userData));
200+
nb::object res = f(op, PyPatternRewriter(rewriter));
201+
return logicalResultFromObject(res);
202+
};
203+
MlirRewritePattern pattern = mlirOpRewritePattenCreate(
204+
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
205+
/* nGeneratedNames */ 0,
206+
/* generatedNames */ nullptr);
207+
mlirRewritePatternSetAdd(set, pattern);
208+
}
209+
210+
PyFrozenRewritePatternSet freeze() {
211+
MlirRewritePatternSet s = set;
212+
set.ptr = nullptr;
213+
return mlirFreezeRewritePattern(s);
214+
}
215+
216+
private:
217+
MlirRewritePatternSet set;
218+
MlirContext ctx;
219+
};
220+
168221
} // namespace
169222

170223
/// Create the `mlir.rewrite` here.
171224
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
172-
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
173-
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
174-
"The current insertion point of the PatternRewriter.");
225+
//----------------------------------------------------------------------------
226+
// Mapping of the PatternRewriter
227+
//----------------------------------------------------------------------------
228+
nb::
229+
class_<PyPatternRewriter>(m, "PatternRewriter")
230+
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
231+
"The current insertion point of the PatternRewriter.")
232+
.def(
233+
"replace_op",
234+
[](PyPatternRewriter &self, MlirOperation op,
235+
MlirOperation newOp) { self.replaceOp(op, newOp); },
236+
"Replace an operation with a new operation.", nb::arg("op"),
237+
nb::arg("new_op"),
238+
// clang-format off
239+
nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
240+
// clang-format on
241+
)
242+
.def(
243+
"replace_op",
244+
[](PyPatternRewriter &self, MlirOperation op,
245+
const std::vector<MlirValue> &values) {
246+
self.replaceOp(op, values);
247+
},
248+
"Replace an operation with a list of values.", nb::arg("op"),
249+
nb::arg("values"),
250+
// clang-format off
251+
nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
252+
// clang-format on
253+
)
254+
.def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
255+
nb::arg("op"),
256+
// clang-format off
257+
nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
258+
// clang-format on
259+
);
260+
261+
//----------------------------------------------------------------------------
262+
// Mapping of the RewritePatternSet
263+
//----------------------------------------------------------------------------
264+
nb::class_<MlirRewritePattern>(m, "RewritePattern");
265+
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
266+
.def(
267+
"__init__",
268+
[](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
269+
new (&self) PyRewritePatternSet(context.get()->get());
270+
},
271+
"context"_a = nb::none())
272+
.def(
273+
"add",
274+
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
275+
unsigned benefit) {
276+
std::string opName =
277+
nb::cast<std::string>(root.attr("OPERATION_NAME"));
278+
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
279+
fn);
280+
},
281+
"root"_a, "fn"_a, "benefit"_a = 1,
282+
"Add a new rewrite pattern on the given root operation with the "
283+
"callable as the matching and rewriting function and the given "
284+
"benefit.")
285+
.def("freeze", &PyRewritePatternSet::freeze,
286+
"Freeze the pattern set into a frozen one.");
287+
175288
//----------------------------------------------------------------------------
176289
// Mapping of the PDLResultList and PDLModule
177290
//----------------------------------------------------------------------------
@@ -237,7 +350,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
237350
.def(
238351
"freeze",
239352
[](PyPDLPatternModule &self) {
240-
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
353+
return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
241354
mlirRewritePatternSetFromPDLPatternModule(self.get())));
242355
},
243356
nb::keep_alive<0, 1>())

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -270,15 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
270270
/// RewritePatternSet and FrozenRewritePatternSet API
271271
//===----------------------------------------------------------------------===//
272272

273-
static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
274-
assert(module.ptr && "unexpected null module");
275-
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
276-
}
277-
278-
static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
279-
return {module};
280-
}
281-
282273
static inline mlir::FrozenRewritePatternSet *
283274
unwrap(MlirFrozenRewritePatternSet module) {
284275
assert(module.ptr && "unexpected null module");
@@ -290,15 +281,16 @@ wrap(mlir::FrozenRewritePatternSet *module) {
290281
return {module};
291282
}
292283

293-
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
294-
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
295-
op.ptr = nullptr;
284+
MlirFrozenRewritePatternSet
285+
mlirFreezeRewritePattern(MlirRewritePatternSet set) {
286+
auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
287+
set.ptr = nullptr;
296288
return wrap(m);
297289
}
298290

299-
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
300-
delete unwrap(op);
301-
op.ptr = nullptr;
291+
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
292+
delete unwrap(set);
293+
set.ptr = nullptr;
302294
}
303295

304296
MlirLogicalResult
@@ -332,6 +324,77 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
332324
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
333325
}
334326

327+
//===----------------------------------------------------------------------===//
328+
/// RewritePattern API
329+
//===----------------------------------------------------------------------===//
330+
331+
namespace mlir {
332+
333+
class ExternalRewritePattern : public mlir::RewritePattern {
334+
public:
335+
ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
336+
StringRef rootName, PatternBenefit benefit,
337+
MLIRContext *context,
338+
ArrayRef<StringRef> generatedNames)
339+
: RewritePattern(rootName, benefit, context, generatedNames),
340+
callbacks(callbacks), userData(userData) {
341+
if (callbacks.construct)
342+
callbacks.construct(userData);
343+
}
344+
345+
~ExternalRewritePattern() {
346+
if (callbacks.destruct)
347+
callbacks.destruct(userData);
348+
}
349+
350+
LogicalResult matchAndRewrite(Operation *op,
351+
PatternRewriter &rewriter) const override {
352+
return unwrap(callbacks.matchAndRewrite(
353+
wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
354+
wrap(&rewriter), userData));
355+
}
356+
357+
private:
358+
MlirRewritePatternCallbacks callbacks;
359+
void *userData;
360+
};
361+
362+
} // namespace mlir
363+
364+
MlirRewritePattern mlirOpRewritePattenCreate(
365+
MlirStringRef rootName, unsigned benefit, MlirContext context,
366+
MlirRewritePatternCallbacks callbacks, void *userData,
367+
size_t nGeneratedNames, MlirStringRef *generatedNames) {
368+
std::vector<mlir::StringRef> generatedNamesVec;
369+
generatedNamesVec.reserve(nGeneratedNames);
370+
for (size_t i = 0; i < nGeneratedNames; ++i) {
371+
generatedNamesVec.push_back(unwrap(generatedNames[i]));
372+
}
373+
return wrap(new mlir::ExternalRewritePattern(
374+
callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
375+
unwrap(context), generatedNamesVec));
376+
}
377+
378+
//===----------------------------------------------------------------------===//
379+
/// RewritePatternSet API
380+
//===----------------------------------------------------------------------===//
381+
382+
MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
383+
return wrap(new mlir::RewritePatternSet(unwrap(context)));
384+
}
385+
386+
void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
387+
delete unwrap(set);
388+
}
389+
390+
void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
391+
MlirRewritePattern pattern) {
392+
std::unique_ptr<mlir::RewritePattern> patternPtr(
393+
const_cast<mlir::RewritePattern *>(unwrap(pattern)));
394+
pattern.ptr = nullptr;
395+
unwrap(set)->add(std::move(patternPtr));
396+
}
397+
335398
//===----------------------------------------------------------------------===//
336399
/// PDLPatternModule API
337400
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)