Skip to content
Merged
52 changes: 52 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);

//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
Expand Down Expand Up @@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
DEFINE_C_API_STRUCT(MlirPDLValue, const void);
DEFINE_C_API_STRUCT(MlirPDLResultList, void);

MLIR_CAPI_EXPORTED MlirPDLPatternModule
mlirPDLPatternModuleFromModule(MlirModule op);
Expand All @@ -323,6 +326,55 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);

MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);

/// Cast the MlirPDLValue to an MlirValue.
/// Return a null value if the cast fails, just like llvm::dyn_cast.
MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);

/// Cast the MlirPDLValue to an MlirType.
/// Return a null value if the cast fails, just like llvm::dyn_cast.
MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);

/// Cast the MlirPDLValue to an MlirOperation.
/// Return a null value if the cast fails, just like llvm::dyn_cast.
MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);

/// Cast the MlirPDLValue to an MlirAttribute.
/// Return a null value if the cast fails, just like llvm::dyn_cast.
MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);

/// Push the MlirValue into the given MlirPDLResultList.
MLIR_CAPI_EXPORTED void
mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);

/// Push the MlirType into the given MlirPDLResultList.
MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
MlirType value);

/// Push the MlirOperation into the given MlirPDLResultList.
MLIR_CAPI_EXPORTED void
mlirPDLResultListPushBackOperation(MlirPDLResultList results,
MlirOperation value);

/// Push the MlirAttribute into the given MlirPDLResultList.
MLIR_CAPI_EXPORTED void
mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
MlirAttribute value);

/// This function type is used as callbacks for PDL native rewrite functions.
/// Input values can be accessed by `values` with its size `nValues`;
/// output values can be added into `results` by `mlirPDLResultListPushBack*`
/// APIs. And the return value indicates whether the rewrite succeeds.
typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
MlirPDLValue *values, void *userData);

/// Register a rewrite function into the given PDL pattern module.
/// `userData` will be provided as an argument to the rewrite function.
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData);

#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

#undef DEFINE_C_API_STRUCT
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/PDLPatternMatch.h.inc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one random q: why is this file named .h.inc? @jpienaar?

Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public:
/// value is not an instance of `T`.
template <typename T,
typename ResultT = std::conditional_t<
std::is_convertible<T, bool>::value, T, std::optional<T>>>
std::is_constructible_v<bool, T>, T, std::optional<T>>>
Copy link
Member Author

@PragmaTwice PragmaTwice Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some classes like mlir::Value, mlir::Type .. have an explicit operator bool(), and is_convertible will ignore explicit conversions. So here we replace it with is_constructible_v to make it work for these types, so that mlir::Value can be used instead of weird std::optional<mlir::Value> (since mlir::Value is nullable itself).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I don't understand this one - also is this needed for this PR?

Copy link
Member Author

@PragmaTwice PragmaTwice Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. Due to this review suggestion #159926 (comment), now we use dyn_cast instead of isa and cast of PDLValue. And this change is to address issues in PDLValue::dyn_cast. (previously we did't need this change : )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can change the API back to the origin form if we want to avoid touching that header. Maybe @jpienaar has idea for this : )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah it's fine I just didn't see that comment about dyn_cast

ResultT dyn_cast() const {
return isa<T>() ? castImpl<T>() : ResultT();
}
Expand Down
103 changes: 98 additions & 5 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
#include "Rewrite.h"

#include "IRModule.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"

namespace nb = nanobind;
using namespace mlir;
Expand All @@ -24,6 +27,31 @@ using namespace mlir::python;
namespace {

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
return nb::cast(v);
if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
return nb::cast(v);
if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
return nb::cast(v);
if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
return nb::cast(v);

throw std::runtime_error("unsupported PDL value type");
}

// Convert the Python object to a boolean.
// If it evaluates to False, treat it as success;
// otherwise, treat it as failure.
// Note that None is considered success.
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
if (obj.is_none())
return mlirLogicalResultSuccess();

return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
: mlirLogicalResultSuccess();
}

/// Owning Wrapper around a PDLPatternModule.
class PyPDLPatternModule {
public:
Expand All @@ -38,6 +66,23 @@ class PyPDLPatternModule {
}
MlirPDLPatternModule get() { return module; }

void registerRewriteFunction(const std::string &name,
const nb::callable &fn) {
mlirPDLPatternModuleRegisterRewriteFunction(
get(), mlirStringRefCreate(name.data(), name.size()),
[](MlirPatternRewriter rewriter, MlirPDLResultList results,
size_t nValues, MlirPDLValue *values,
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
std::vector<nb::object> args;
args.reserve(nValues);
for (size_t i = 0; i < nValues; ++i)
args.push_back(objectFromPDLValue(values[i]));
return logicalResultFromObject(f(rewriter, results, args));
},
fn.ptr());
}

private:
MlirPDLPatternModule module;
};
Expand Down Expand Up @@ -78,10 +123,48 @@ class PyFrozenRewritePatternSet {

/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<MlirPDLResultList>(m, "PDLResultList")
.def(
"append",
[](MlirPDLResultList results, const PyValue &value) {
mlirPDLResultListPushBackValue(results, value);
},
// clang-format off
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
// clang-format on
)
.def(
"append",
[](MlirPDLResultList results, const PyOperation &op) {
mlirPDLResultListPushBackOperation(results, op);
},
// clang-format off
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
// clang-format on
)
.def(
"append",
[](MlirPDLResultList results, const PyType &type) {
mlirPDLResultListPushBackType(results, type);
},
// clang-format off
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
// clang-format on
)
.def(
"append",
[](MlirPDLResultList results, const PyAttribute &attr) {
mlirPDLResultListPushBackAttribute(results, attr);
},
// clang-format off
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
// clang-format on
);
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
Expand All @@ -93,10 +176,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
// clang-format on
"module"_a, "Create a PDL module from the given module.")
.def("freeze", [](PyPDLPatternModule &self) {
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
});
.def(
"freeze",
[](PyPDLPatternModule &self) {
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())
.def(
"register_rewrite_function",
[](PyPDLPatternModule &self, const std::string &name,
const nb::callable &fn) {
self.registerRewriteFunction(name, fn);
},
nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
Expand Down
84 changes: 84 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}

//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//

inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
assert(rewriter.ptr && "unexpected null rewriter");
return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
}

inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
return {rewriter};
}

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -331,4 +346,73 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
op.ptr = nullptr;
return wrap(m);
}

inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
assert(value.ptr && "unexpected null PDL value");
return static_cast<const mlir::PDLValue *>(value.ptr);
}

inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }

inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
assert(results.ptr && "unexpected null PDL results");
return static_cast<mlir::PDLResultList *>(results.ptr);
}

inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
return {results};
}

MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Value>());
}

MlirType mlirPDLValueAsType(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Type>());
}

MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Operation *>());
}

MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Attribute>());
}

void mlirPDLResultListPushBackValue(MlirPDLResultList results,
MlirValue value) {
unwrap(results)->push_back(unwrap(value));
}

void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
unwrap(results)->push_back(unwrap(value));
}

void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
MlirOperation value) {
unwrap(results)->push_back(unwrap(value));
}

void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
MlirAttribute value) {
unwrap(results)->push_back(unwrap(value));
}

void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData) {
unwrap(pdlModule)->registerRewriteFunction(
unwrap(name),
[userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> values) -> LogicalResult {
std::vector<MlirPDLValue> mlirValues;
mlirValues.reserve(values.size());
for (auto &value : values) {
mlirValues.push_back(wrap(&value));
}
return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
mlirValues.size(), mlirValues.data(),
userData));
});
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
Loading