Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
DEFINE_C_API_STRUCT(MlirRewritePattern, const void);

//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
Expand Down Expand Up @@ -302,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
/// FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//

/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet.
/// Note that the ownership of the input set is transferred into the frozen set
/// after this call.
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet op);
mlirFreezeRewritePattern(MlirRewritePatternSet set);

/// Destroy the given MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set);

MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
MlirOperation op, MlirFrozenRewritePatternSet patterns,
Expand All @@ -324,6 +329,51 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);

//===----------------------------------------------------------------------===//
/// RewritePattern API
//===----------------------------------------------------------------------===//

/// Callbacks to construct a rewrite pattern.
typedef struct {
/// Optional constructor for the user data.
/// Set to nullptr to disable it.
void (*construct)(void *userData);
/// Optional destructor for the user data.
/// Set to nullptr to disable it.
void (*destruct)(void *userData);
/// The callback function to match against code rooted at the specified
/// operation, and perform the rewrite if the match is successful,
/// corresponding to RewritePattern::matchAndRewrite.
MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
MlirOperation op,
MlirPatternRewriter rewriter,
void *userData);
} MlirRewritePatternCallbacks;

/// Create a rewrite pattern that matches the operation
/// with the given rootName, corresponding to mlir::OpRewritePattern.
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames);

//===----------------------------------------------------------------------===//
/// RewritePatternSet API
//===----------------------------------------------------------------------===//

/// Create an empty MlirRewritePatternSet.
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetCreate(MlirContext context);

/// Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);

/// Add the given MlirRewritePattern into a MlirRewritePatternSet.
/// Note that the ownership of the pattern is transferred to the set after this
/// call.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
MlirRewritePattern pattern);

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/CAPI/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@
#include "mlir/IR/PatternMatch.h"

DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase)
DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern)
DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)

#endif // MLIR_CAPIREWRITER_H
121 changes: 117 additions & 4 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ class PyPatternRewriter {
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}

void replaceOp(MlirOperation op, MlirOperation newOp) {
mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
}

void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
}

void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }

private:
MlirRewriterBase base;
PyMlirContextRef ctx;
Expand Down Expand Up @@ -165,13 +175,116 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};

class PyRewritePatternSet {
public:
PyRewritePatternSet(MlirContext ctx)
: set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
~PyRewritePatternSet() {
if (set.ptr)
mlirRewritePatternSetDestroy(set);
}

void add(MlirStringRef rootName, unsigned benefit,
const nb::callable &matchAndRewrite) {
MlirRewritePatternCallbacks callbacks;
callbacks.construct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
};
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
MlirPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
nb::object res = f(op, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
MlirRewritePattern pattern = mlirOpRewritePattenCreate(
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
mlirRewritePatternSetAdd(set, pattern);
}

PyFrozenRewritePatternSet freeze() {
MlirRewritePatternSet s = set;
set.ptr = nullptr;
return mlirFreezeRewritePattern(s);
}

private:
MlirRewritePatternSet set;
MlirContext ctx;
};

} // namespace

/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
"The current insertion point of the PatternRewriter.");
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
nb::
class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
"The current insertion point of the PatternRewriter.")
.def(
"replace_op",
[](PyPatternRewriter &self, MlirOperation op,
MlirOperation newOp) { self.replaceOp(op, newOp); },
"Replace an operation with a new operation.", nb::arg("op"),
nb::arg("new_op"),
// clang-format off
nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
// clang-format on
)
.def(
"replace_op",
[](PyPatternRewriter &self, MlirOperation op,
const std::vector<MlirValue> &values) {
self.replaceOp(op, values);
},
"Replace an operation with a list of values.", nb::arg("op"),
nb::arg("values"),
// clang-format off
nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
// clang-format on
)
.def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
nb::arg("op"),
// clang-format off
nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
// clang-format on
);

//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
//----------------------------------------------------------------------------
nb::class_<MlirRewritePattern>(m, "RewritePattern");
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
.def(
"__init__",
[](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
new (&self) PyRewritePatternSet(context.get()->get());
},
"context"_a = nb::none())
.def(
"add",
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
unsigned benefit) {
std::string opName =
nb::cast<std::string>(root.attr("OPERATION_NAME"));
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
fn);
},
"root"_a, "fn"_a, "benefit"_a = 1,
"Add a new rewrite pattern on the given root operation with the "
"callable as the matching and rewriting function and the given "
"benefit.")
.def("freeze", &PyRewritePatternSet::freeze,
"Freeze the pattern set into a frozen one.");

//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
Expand Down Expand Up @@ -237,7 +350,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"freeze",
[](PyPDLPatternModule &self) {
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())
Expand Down
93 changes: 78 additions & 15 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//

static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
}

static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
return {module};
}

static inline mlir::FrozenRewritePatternSet *
unwrap(MlirFrozenRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
Expand All @@ -290,15 +281,16 @@ wrap(mlir::FrozenRewritePatternSet *module) {
return {module};
}

MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
op.ptr = nullptr;
MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet set) {
auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
set.ptr = nullptr;
return wrap(m);
}

void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
delete unwrap(op);
op.ptr = nullptr;
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
delete unwrap(set);
set.ptr = nullptr;
}

MlirLogicalResult
Expand Down Expand Up @@ -332,6 +324,77 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}

//===----------------------------------------------------------------------===//
/// RewritePattern API
//===----------------------------------------------------------------------===//

namespace mlir {

class ExternalRewritePattern : public mlir::RewritePattern {
public:
ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
StringRef rootName, PatternBenefit benefit,
MLIRContext *context,
ArrayRef<StringRef> generatedNames)
: RewritePattern(rootName, benefit, context, generatedNames),
callbacks(callbacks), userData(userData) {
if (callbacks.construct)
callbacks.construct(userData);
}

~ExternalRewritePattern() {
if (callbacks.destruct)
callbacks.destruct(userData);
}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
return unwrap(callbacks.matchAndRewrite(
wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
wrap(&rewriter), userData));
}

private:
MlirRewritePatternCallbacks callbacks;
void *userData;
};

} // namespace mlir

MlirRewritePattern mlirOpRewritePattenCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames) {
std::vector<mlir::StringRef> generatedNamesVec;
generatedNamesVec.reserve(nGeneratedNames);
for (size_t i = 0; i < nGeneratedNames; ++i) {
generatedNamesVec.push_back(unwrap(generatedNames[i]));
}
return wrap(new mlir::ExternalRewritePattern(
callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
unwrap(context), generatedNamesVec));
}

//===----------------------------------------------------------------------===//
/// RewritePatternSet API
//===----------------------------------------------------------------------===//

MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
return wrap(new mlir::RewritePatternSet(unwrap(context)));
}

void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
delete unwrap(set);
}

void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
MlirRewritePattern pattern) {
std::unique_ptr<mlir::RewritePattern> patternPtr(
const_cast<mlir::RewritePattern *>(unwrap(pattern)));
pattern.ptr = nullptr;
unwrap(set)->add(std::move(patternPtr));
}

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
Loading