Skip to content

Commit ea64828

Browse files
committed
[mlir:PDL] Expand how native constraint/rewrite functions can be defined
This commit refactors the expected form of native constraint and rewrite functions, and greatly reduces the necessary user complexity required when defining a native function. Namely, this commit adds in automatic processing of the necessary PDLValue glue code, and allows for users to define constraint/rewrite functions using the C++ types that they actually want to use. As an example, lets see a simple example rewrite defined today: ``` static void rewriteFn(PatternRewriter &rewriter, PDLResultList &results, ArrayRef<PDLValue> args) { ValueRange operandValues = args[0].cast<ValueRange>(); TypeRange typeValues = args[1].cast<TypeRange>(); ... // Create an operation at some point and pass it back to PDL. Operation *op = rewriter.create<SomeOp>(...); results.push_back(op); } ``` After this commit, that same rewrite could be defined as: ``` static Operation *rewriteFn(PatternRewriter &rewriter ValueRange operandValues, TypeRange typeValues) { ... // Create an operation at some point and pass it back to PDL. return rewriter.create<SomeOp>(...); } ``` Differential Revision: https://reviews.llvm.org/D122086
1 parent f5e48a2 commit ea64828

File tree

13 files changed

+776
-309
lines changed

13 files changed

+776
-309
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct function_traits<ReturnType (ClassType::*)(Args...) const, false> {
129129
/// Overload for class function types.
130130
template <typename ClassType, typename ReturnType, typename... Args>
131131
struct function_traits<ReturnType (ClassType::*)(Args...), false>
132-
: function_traits<ReturnType (ClassType::*)(Args...) const> {};
132+
: public function_traits<ReturnType (ClassType::*)(Args...) const> {};
133133
/// Overload for non-class function types.
134134
template <typename ReturnType, typename... Args>
135135
struct function_traits<ReturnType (*)(Args...), false> {
@@ -143,6 +143,9 @@ struct function_traits<ReturnType (*)(Args...), false> {
143143
template <size_t i>
144144
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
145145
};
146+
template <typename ReturnType, typename... Args>
147+
struct function_traits<ReturnType (*const)(Args...), false>
148+
: public function_traits<ReturnType (*)(Args...)> {};
146149
/// Overload for non-class function type references.
147150
template <typename ReturnType, typename... Args>
148151
struct function_traits<ReturnType (&)(Args...), false>

mlir/docs/PDLL.md

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,17 +1006,11 @@ External constraints are those registered explicitly with the `RewritePatternSet
10061006
the C++ PDL API. For example, the constraints above may be registered as:
10071007

10081008
```c++
1009-
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
1010-
static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) {
1011-
Value value = pdlValue.cast<Value>();
1012-
1009+
static LogicalResult hasOneUseImpl(PatternRewriter &rewriter, Value value) {
10131010
return success(value.hasOneUse());
10141011
}
1015-
static LogicalResult hasSameElementTypeImpl(ArrayRef<PDLValue> pdlValues,
1016-
PatternRewriter &rewriter) {
1017-
Value value1 = pdlValues[0].cast<Value>();
1018-
Value value2 = pdlValues[1].cast<Value>();
1019-
1012+
static LogicalResult hasSameElementTypeImpl(PatternRewriter &rewriter,
1013+
Value value1, Value Value2) {
10201014
return success(value1.getType().cast<ShapedType>().getElementType() ==
10211015
value2.getType().cast<ShapedType>().getElementType());
10221016
}
@@ -1307,14 +1301,10 @@ External rewrites are those registered explicitly with the `RewritePatternSet` v
13071301
the C++ PDL API. For example, the rewrite above may be registered as:
13081302

13091303
```c++
1310-
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
1311-
static void buildOpImpl(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
1312-
PDLResultList &results) {
1313-
Value value = args[0].cast<Value>();
1314-
1304+
static Operation *buildOpImpl(PDLResultList &results, Value value) {
13151305
// insert special rewrite logic here.
13161306
Operation *resultOp = ...;
1317-
results.push_back(resultOp);
1307+
return resultOp;
13181308
}
13191309

13201310
void registerNativeRewrite(RewritePatternSet &patterns) {

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,14 @@ def PDL_ApplyNativeRewriteOp
6868

6969
```mlir
7070
// Apply a native rewrite method that returns an attribute.
71-
%ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute
71+
%ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %attr1) : !pdl.attribute
7272
```
7373

7474
```c++
7575
// The native rewrite as defined in C++:
76-
static void myNativeFunc(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
77-
PDLResultList &results) {
78-
Value arg0 = args[0].cast<Value>();
79-
Value arg1 = args[1].cast<Value>();
80-
81-
// Just push back the first param attribute.
82-
results.push_back(param0);
76+
static Attribute myNativeFunc(PatternRewriter &rewriter, Value arg0, Attribute arg1) {
77+
// Just return the second arg.
78+
return arg1;
8379
}
8480

8581
void registerNativeRewrite(PDLPatternModule &pdlModule) {

mlir/include/mlir/IR/Builders.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,8 @@ class OpBuilder : public Builder {
409409

410410
/// Creates an operation with the given fields.
411411
Operation *create(Location loc, StringAttr opName, ValueRange operands,
412-
TypeRange types, ArrayRef<NamedAttribute> attributes = {},
412+
TypeRange types = {},
413+
ArrayRef<NamedAttribute> attributes = {},
413414
BlockRange successors = {},
414415
MutableArrayRef<std::unique_ptr<Region>> regions = {});
415416

0 commit comments

Comments
 (0)