Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,20 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData);

/// This function type is used as callbacks for PDL native constraint functions.
/// Input values can be accessed by `values` with its size `nValues`;
/// output values can be added into `results` by `mlirPDLResultListPushBack*`
/// APIs. And the return value indicates whether the constraint holds.
typedef MlirLogicalResult (*MlirPDLConstraintFunction)(
MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
MlirPDLValue *values, void *userData);

/// Register a constraint function into the given PDL pattern module.
/// `userData` will be provided as an argument to the constraint function.
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLConstraintFunction constraintFn, void *userData);

#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

#undef DEFINE_C_API_STRUCT
Expand Down
37 changes: 32 additions & 5 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ static nb::object objectFromPDLValue(MlirPDLValue value) {
throw std::runtime_error("unsupported PDL value type");
}

static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
MlirPDLValue *values) {
std::vector<nb::object> args;
args.reserve(nValues);
for (size_t i = 0; i < nValues; ++i)
args.push_back(objectFromPDLValue(values[i]));
return args;
}

// Convert the Python object to a boolean.
// If it evaluates to False, treat it as success;
// otherwise, treat it as failure.
Expand Down Expand Up @@ -74,11 +83,22 @@ class PyPDLPatternModule {
size_t nValues, MlirPDLValue *values,
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
std::vector<nb::object> args;
args.reserve(nValues);
for (size_t i = 0; i < nValues; ++i)
args.push_back(objectFromPDLValue(values[i]));
return logicalResultFromObject(f(rewriter, results, args));
return logicalResultFromObject(
f(rewriter, results, objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}

void registerConstraintFunction(const std::string &name,
const nb::callable &fn) {
mlirPDLPatternModuleRegisterConstraintFunction(
get(), mlirStringRefCreate(name.data(), name.size()),
[](MlirPatternRewriter rewriter, MlirPDLResultList results,
size_t nValues, MlirPDLValue *values,
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
f(rewriter, results, objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
Expand Down Expand Up @@ -199,6 +219,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
const nb::callable &fn) {
self.registerRewriteFunction(name, fn);
},
nb::keep_alive<1, 3>())
.def(
"register_constraint_function",
[](PyPDLPatternModule &self, const std::string &name,
const nb::callable &fn) {
self.registerConstraintFunction(name, fn);
},
nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
Expand Down
30 changes: 25 additions & 5 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,21 +398,41 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
unwrap(results)->push_back(unwrap(value));
}

inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
std::vector<MlirPDLValue> mlirValues;
mlirValues.reserve(values.size());
for (auto &value : values) {
mlirValues.push_back(wrap(&value));
}
return mlirValues;
}

void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData) {
unwrap(pdlModule)->registerRewriteFunction(
unwrap(name),
[userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> values) -> LogicalResult {
std::vector<MlirPDLValue> mlirValues;
mlirValues.reserve(values.size());
for (auto &value : values) {
mlirValues.push_back(wrap(&value));
}
std::vector<MlirPDLValue> mlirValues = wrap(values);
return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
mlirValues.size(), mlirValues.data(),
userData));
});
}

void mlirPDLPatternModuleRegisterConstraintFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLConstraintFunction constraintFn, void *userData) {
unwrap(pdlModule)->registerConstraintFunction(
unwrap(name),
[userData, constraintFn](PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> values) -> LogicalResult {
std::vector<MlirPDLValue> mlirValues = wrap(values);
return unwrap(constraintFn(wrap(&rewriter), wrap(&results),
mlirValues.size(), mlirValues.data(),
userData));
});
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
56 changes: 56 additions & 0 deletions mlir/test/python/integration/dialects/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,43 @@ def rew():
)
pdl.ReplaceOp(op0, with_op=newOp)

@pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
def pat():
t = pdl.TypeOp(i32)
v0 = pdl.OperandOp()
v1 = pdl.OperandOp()
v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1])
op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])

@pdl.rewrite()
def rew():
pdl.ReplaceOp(op0, with_values=[v])

def add_fold(rewriter, results, values):
a0, a1 = values
results.append(IntegerAttr.get(i32, a0.value + a1.value))

def is_zero(value):
op = value.owner
if isinstance(op, Operation):
return op.name == "myint.constant" and op.attributes["value"].value == 0
return False

# Check if either operand is a constant zero,
# and append the other operand to the results if so.
def has_zero(rewriter, results, values):
v0, v1 = values
if is_zero(v0):
results.append(v1)
return False
if is_zero(v1):
results.append(v0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I haven't used this pdl API before: are results used for anything? Are they meant to be? Because I'm pretty sure currently (the current state of the PR) they're not and it'd be impossible to make it work (since you're wrapping in a std:: vector)?

Copy link
Member Author

@PragmaTwice PragmaTwice Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh sorry there's no type annotation so easy to get wrong here. results here is typed MlirPDLResultList and we expose an append method for this type (this is the only method of this type for now). And values is typed std::vector<nb::object> so it is a list of Value/Attribute/Type.. .

The results can be used in the pdl.apply_native_constraint, for example

%res = pdl.apply_native_constraint("some_constraint", %v1: pdl.value, %v2: pdl.value) -> pdl.value

Then the callable passed with some_constraint will be called like (pseudocode, the call actually happens in C++):

values = [v1, v2] # corresponding to argument %v1 and %v2
if not some_constraint(rewriter, results, values): # if it succeeds
   assert(len(results) == 1) # results[0] corresponding to %res

And here for has_zero, the story is:

  1. we check if either operand is zero
  2. if no zero, just fail and exit
  3. otherwise, we push the other (non-zero) operand into results and use it as the new op for rewrite

e.g. for x + 0:

  1. zero (the second operand) found! (the constraint holds)
  2. push x (the first operand) to results
  3. x + 0 rewritten to x

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I understand now - I thought MlirPDLResultList was just an ArrayRef but I see it's not, it's actually a container that holds a bunch of SmallVectors.

return False
return True

pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold)
pdl_module.register_constraint_function("has_zero", has_zero)
return pdl_module.freeze()


Expand All @@ -181,3 +212,28 @@ def test_pdl_register_function(module_):
apply_patterns_and_fold_greedily(module_, frozen)

return module_


# CHECK-LABEL: TEST: test_pdl_register_function_constraint
# CHECK: return %arg0 : i32
@construct_and_print_in_module
def test_pdl_register_function_constraint(module_):
load_myint_dialect()

module_ = Module.parse(
"""
func.func @f(%x : i32) -> i32 {
%c0 = "myint.constant"() { value = 1 }: () -> (i32)
%c1 = "myint.constant"() { value = -1 }: () -> (i32)
%a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
%b = "myint.add"(%a, %x): (i32, i32) -> (i32)
%c = "myint.add"(%b, %a): (i32, i32) -> (i32)
func.return %c : i32
}
"""
)

frozen = get_pdl_pattern_fold()
apply_patterns_and_fold_greedily(module_, frozen)

return module_