Skip to content

Commit

Permalink
[mlir:PDLL] Rework the C++ generation of native Constraint/Rewrite ar…
Browse files Browse the repository at this point in the history
…guments and results

The current translation uses the old "ugly"/"raw" form which used PDLValue for the arguments
and results. This commit updates the C++ generation to use the recently added sugar that
allows for directly using the desired types for the arguments and result of PDL functions.
In addition, this commit also properly imports the C++ class for ODS operations, constraints,
and interfaces. This allows for a much more convienent C++ API than previously granted
with the raw/low-level types.

Differential Revision: https://reviews.llvm.org/D124817
  • Loading branch information
River707 committed May 31, 2022
1 parent 0429472 commit 1c2edb0
Show file tree
Hide file tree
Showing 18 changed files with 415 additions and 167 deletions.
178 changes: 156 additions & 22 deletions mlir/docs/PDLL.md
Expand Up @@ -1146,17 +1146,86 @@ Pattern {
```

The arguments of the constraint are accessible within the code block via the
same name. The type of these native variables are mapped directly to the
corresponding MLIR type of the [core constraint](#core-constraints) used. For
example, an `Op` corresponds to a variable of type `Operation *`.
same name. See the ["type translation"](#native-constraint-type-translations) below for
detailed information on how PDLL types are converted to native types. In addition to the
PDLL arguments, the code block may also access the current `PatternRewriter` using
`rewriter`. The result type of the native constraint function is implicitly defined
as a `::mlir::LogicalResult`.

The results of the constraint can be populated using the provided `results`
variable. This variable is a `PDLResultList`, and expects results to be
populated in the order that they are defined within the result list of the
constraint declaration.
Taking the constraints defined above as an example, these function would roughly be
translated into:

In addition to the above, the code block may also access the current
`PatternRewriter` using `rewriter`.
```c++
LogicalResult HasOneUse(PatternRewriter &rewriter, Value value) {
return success(value.hasOneUse());
}
LogicalResult HasSameElementType(Value value1, Value value2) {
return success(value1.getType().cast<ShapedType>().getElementType() ==
value2.getType().cast<ShapedType>().getElementType());
}
```
TODO: Native constraints should also be allowed to return values in certain cases.
###### Native Constraint Type Translations
The types of argument and result variables are generally mapped to the corresponding
MLIR type of the [constraint](#constraints) used. Below is a detailed description
of how the mapped type of a variable is determined for the various different types of
constraints.
* Attr, Op, Type, TypeRange, Value, ValueRange:
These are all core constraints, and are mapped directly to the MLIR equivalent
(that their names suggest), namely:
* `Attr` -> "::mlir::Attribute"
* `Op` -> "::mlir::Operation *"
* `Type` -> "::mlir::Type"
* `TypeRange` -> "::mlir::TypeRange"
* `Value` -> "::mlir::Value"
* `ValueRange` -> "::mlir::ValueRange"
* Op<dialect.name>
A named operation constraint has a unique translation. If the ODS registration of the
referenced operation has been included, the qualified C++ is used. If the ODS information
is not available, this constraint maps to "::mlir::Operation *", similarly to the unnamed
variant. For example, given the following:
```pdll
// `my_ops.td` provides the ODS definition of the `my_dialect` operations, such as
// `my_dialect.bar` used below.
#include "my_ops.td"
Constraint Cst(op: Op<my_dialect.bar>) [{
return success(op ... );
}];
```

The native type used for `op` may be of the form `my_dialect::BarOp`, as opposed to the
default `::mlir::Operation *`. Below is a sample translation of the above constraint:

```c++
LogicalResult Cst(my_dialect::BarOp op) {
return success(op ... );
}
```
* Imported ODS Constraints
Aside from the core constraints, certain constraints imported from ODS may use a unique
native type. How to enable this unique type depends on the ODS constraint construct that
was imported:
* `Attr` constraints
- Imported `Attr` constraints utilize the `storageType` field for native type translation.
* `Type` constraints
- Imported `Type` constraints utilize the `cppClassName` field for native type translation.
* `AttrInterface`/`OpInterface`/`TypeInterface` constraints
- Imported interfaces utilize the `cppClassName` field for native type translation.
#### Defining Constraints Inline
Expand Down Expand Up @@ -1414,10 +1483,7 @@ be defined by specifying a string code block after the rewrite declaration:
```pdll
Rewrite BuildOp(value: Value) -> (foo: Op<my_dialect.foo>, bar: Op<my_dialect.bar>) [{
// We push back the results into the `results` variable in the order defined
// by the result list of the rewrite declaration.
results.push_back(rewriter.create<my_dialect::FooOp>(value));
results.push_back(rewriter.create<my_dialect::BarOp>());
return {rewriter.create<my_dialect::FooOp>(value), rewriter.create<my_dialect::BarOp>()};
}];
Pattern {
Expand All @@ -1431,17 +1497,85 @@ Pattern {
```

The arguments of the rewrite are accessible within the code block via the
same name. The type of these native variables are mapped directly to the
corresponding MLIR type of the [core constraint](#core-constraints) used. For
example, an `Op` corresponds to a variable of type `Operation *`.
same name. See the ["type translation"](#native-rewrite-type-translations) below for
detailed information on how PDLL types are converted to native types. In addition to the
PDLL arguments, the code block may also access the current `PatternRewriter` using
`rewriter`. See the ["result translation"](#native-rewrite-result-translation) section
for detailed information on how the result type of the native function is determined.

Taking the rewrite defined above as an example, this function would roughly be
translated into:

```c++
std::tuple<my_dialect::FooOp, my_dialect::BarOp> BuildOp(Value value) {
return {rewriter.create<my_dialect::FooOp>(value), rewriter.create<my_dialect::BarOp>()};
}
```
The results of the rewrite can be populated using the provided `results`
variable. This variable is a `PDLResultList`, and expects results to be
populated in the order that they are defined within the result list of the
rewrite declaration.
###### Native Rewrite Type Translations
In addition to the above, the code block may also access the current
`PatternRewriter` using `rewriter`.
The types of argument and result variables are generally mapped to the corresponding
MLIR type of the [constraint](#constraints) used. The rules of native `Rewrite` type translation
are identical to those of native `Constraint`s, please view the corresponding
[native `Constraint` type translation](#native-constraint-type-translations) section for a
detailed description of how the mapped type of a variable is determined.
###### Native Rewrite Result Translation
The results of a native rewrite are directly translated to the results of the native function,
using the type translation rules [described above](#native-rewrite-type-translations). The section
below describes the various result translation scenarios:
* Zero Result
```pdll
Rewrite createOp() [{
rewriter.create<my_dialect::FooOp>();
}];
```

In the case where a native `Rewrite` has no results, the native function returns `void`:

```c++
void createOp(PatternRewriter &rewriter) {
rewriter.create<my_dialect::FooOp>();
}
```
* Single Result
```pdll
Rewrite createOp() -> Op<my_dialect.foo> [{
return rewriter.create<my_dialect::FooOp>();
}];
```

In the case where a native `Rewrite` has a single result, the native function returns the corresponding
native type for that single result:

```c++
my_dialect::FooOp createOp(PatternRewriter &rewriter) {
return rewriter.create<my_dialect::FooOp>();
}
```
* Multi Result
```pdll
Rewrite complexRewrite(value: Value) -> (Op<my_dialect.foo>, FunctionOpInterface) [{
...
}];
```

In the case where a native `Rewrite` has multiple results, the native function returns a `std::tuple<...>`
containing the corresponding native types for each of the results:

```c++
std::tuple<my_dialect::FooOp, FunctionOpInterface>
complexRewrite(PatternRewriter &rewriter, Value value) {
...
}
```
#### Defining Rewrites Inline
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Expand Up @@ -943,9 +943,13 @@ struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
" to be of type: " + llvm::getTypeName<T>());
});
}
using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;

static T processAsArg(BaseT baseValue) {
return baseValue.template cast<T>();
}
using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;

static void processAsResult(PatternRewriter &, PDLResultList &results,
T value) {
results.push_back(value);
Expand All @@ -967,6 +971,8 @@ template <>
struct ProcessPDLValue<StringRef>
: public ProcessPDLValueBasedOn<StringRef, StringAttr> {
static StringRef processAsArg(StringAttr value) { return value.getValue(); }
using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;

static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
StringRef value) {
results.push_back(rewriter.getStringAttr(value));
Expand Down
53 changes: 41 additions & 12 deletions mlir/include/mlir/Tools/PDLL/AST/Nodes.h
Expand Up @@ -506,6 +506,7 @@ class OperationExpr final
NamedAttributeDecl *> {
public:
static OperationExpr *create(Context &ctx, SMRange loc,
const ods::Operation *odsOp,
const OpNameDecl *nameDecl,
ArrayRef<Expr *> operands,
ArrayRef<Expr *> resultTypes,
Expand Down Expand Up @@ -830,16 +831,15 @@ class ValueRangeConstraintDecl
/// - This is a constraint which is defined using only PDLL constructs.
class UserConstraintDecl final
: public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
llvm::TrailingObjects<UserConstraintDecl, VariableDecl *> {
llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
public:
/// Create a native constraint with the given optional code block.
static UserConstraintDecl *createNative(Context &ctx, const Name &name,
ArrayRef<VariableDecl *> inputs,
ArrayRef<VariableDecl *> results,
Optional<StringRef> codeBlock,
Type resultType) {
return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
resultType);
static UserConstraintDecl *
createNative(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
Type resultType, ArrayRef<StringRef> nativeInputTypes = {}) {
return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
/*body=*/nullptr, resultType);
}

/// Create a PDLL constraint with the given body.
Expand All @@ -848,8 +848,8 @@ class UserConstraintDecl final
ArrayRef<VariableDecl *> results,
const CompoundStmt *body,
Type resultType) {
return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
body, resultType);
return createImpl(ctx, name, inputs, /*nativeInputTypes=*/llvm::None,
results, /*codeBlock=*/llvm::None, body, resultType);
}

/// Return the name of the constraint.
Expand All @@ -863,6 +863,10 @@ class UserConstraintDecl final
return const_cast<UserConstraintDecl *>(this)->getInputs();
}

/// Return the explicit native type to use for the given input. Returns None
/// if no explicit type was set.
Optional<StringRef> getNativeInputType(unsigned index) const;

/// Return the explicit results of the constraint declaration. May be empty,
/// even if the constraint has results (e.g. in the case of inferred results).
MutableArrayRef<VariableDecl *> getResults() {
Expand Down Expand Up @@ -891,10 +895,12 @@ class UserConstraintDecl final
/// components.
static UserConstraintDecl *
createImpl(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
ArrayRef<StringRef> nativeInputTypes,
ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
const CompoundStmt *body, Type resultType);

UserConstraintDecl(const Name &name, unsigned numInputs, unsigned numResults,
UserConstraintDecl(const Name &name, unsigned numInputs,
bool hasNativeInputTypes, unsigned numResults,
Optional<StringRef> codeBlock, const CompoundStmt *body,
Type resultType)
: Base(name.getLoc(), &name), numInputs(numInputs),
Expand All @@ -916,8 +922,14 @@ class UserConstraintDecl final
/// The result type of the constraint.
Type resultType;

/// Flag indicating if this constraint has explicit native input types.
bool hasNativeInputTypes;

/// Allow access to various internals.
friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *>;
friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
return numInputs + numResults;
}
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1145,6 +1157,23 @@ class CallableDecl : public Decl {
return cast<UserRewriteDecl>(this)->getResultType();
}

/// Return the explicit results of the declaration. Note that these may be
/// empty, even if the callable has results (e.g. in the case of inferred
/// results).
ArrayRef<VariableDecl *> getResults() const {
if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
return cst->getResults();
return cast<UserRewriteDecl>(this)->getResults();
}

/// Return the optional code block of this callable, if this is a native
/// callable with a provided implementation.
Optional<StringRef> getCodeBlock() const {
if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
return cst->getCodeBlock();
return cast<UserRewriteDecl>(this)->getCodeBlock();
}

/// Support LLVM type casting facilities.
static bool classof(const Node *decl) {
return isa<UserConstraintDecl, UserRewriteDecl>(decl);
Expand Down
11 changes: 10 additions & 1 deletion mlir/include/mlir/Tools/PDLL/AST/Types.h
Expand Up @@ -14,6 +14,10 @@

namespace mlir {
namespace pdll {
namespace ods {
class Operation;
} // namespace ods

namespace ast {
class Context;

Expand Down Expand Up @@ -151,10 +155,15 @@ class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
/// Return an instance of the Operation type with an optional operation name.
/// If no name is provided, this type may refer to any operation.
static OperationType get(Context &context,
Optional<StringRef> name = llvm::None);
Optional<StringRef> name = llvm::None,
const ods::Operation *odsOp = nullptr);

/// Return the name of this operation type, or None if it doesn't have on.
Optional<StringRef> getName() const;

/// Return the ODS operation that this type refers to, or nullptr if the ODS
/// operation is unknown.
const ods::Operation *getODSOperation() const;
};

//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Tools/PDLL/ODS/Context.h
Expand Up @@ -63,7 +63,8 @@ class Context {
/// operation already existed).
std::pair<Operation *, bool>
insertOperation(StringRef name, StringRef summary, StringRef desc,
bool supportsResultTypeInferrence, SMLoc loc);
StringRef nativeClassName, bool supportsResultTypeInferrence,
SMLoc loc);

/// Lookup an operation registered with the given name, or null if no
/// operation with that name is registered.
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
Expand Up @@ -35,7 +35,8 @@ class Dialect {
/// operation already existed).
std::pair<Operation *, bool>
insertOperation(StringRef name, StringRef summary, StringRef desc,
bool supportsResultTypeInferrence, SMLoc loc);
StringRef nativeClassName, bool supportsResultTypeInferrence,
SMLoc loc);

/// Lookup an operation registered with the given name, or null if no
/// operation with that name is registered.
Expand Down

0 comments on commit 1c2edb0

Please sign in to comment.