60 changes: 55 additions & 5 deletions mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ class NodePrinter {
void printImpl(const EraseStmt *stmt);
void printImpl(const LetStmt *stmt);
void printImpl(const ReplaceStmt *stmt);
void printImpl(const ReturnStmt *stmt);
void printImpl(const RewriteStmt *stmt);

void printImpl(const AttributeExpr *expr);
void printImpl(const CallExpr *expr);
void printImpl(const DeclRefExpr *expr);
void printImpl(const MemberAccessExpr *expr);
void printImpl(const OperationExpr *expr);
Expand All @@ -89,11 +91,13 @@ class NodePrinter {
void printImpl(const OpConstraintDecl *decl);
void printImpl(const TypeConstraintDecl *decl);
void printImpl(const TypeRangeConstraintDecl *decl);
void printImpl(const UserConstraintDecl *decl);
void printImpl(const ValueConstraintDecl *decl);
void printImpl(const ValueRangeConstraintDecl *decl);
void printImpl(const NamedAttributeDecl *decl);
void printImpl(const OpNameDecl *decl);
void printImpl(const PatternDecl *decl);
void printImpl(const UserRewriteDecl *decl);
void printImpl(const VariableDecl *decl);
void printImpl(const Module *module);

Expand Down Expand Up @@ -135,6 +139,7 @@ void NodePrinter::print(Type type) {
print(type.getElementType());
os << "Range";
})
.Case([&](RewriteType) { os << "Rewrite"; })
.Case([&](TupleType type) {
os << "Tuple<";
llvm::interleaveComma(
Expand All @@ -160,17 +165,19 @@ void NodePrinter::print(const Node *node) {
.Case<
// Statements.
const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
const RewriteStmt,
const ReturnStmt, const RewriteStmt,

// Expressions.
const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
const OperationExpr, const TupleExpr, const TypeExpr,
const AttributeExpr, const CallExpr, const DeclRefExpr,
const MemberAccessExpr, const OperationExpr, const TupleExpr,
const TypeExpr,

// Decls.
const AttrConstraintDecl, const OpConstraintDecl,
const TypeConstraintDecl, const TypeRangeConstraintDecl,
const ValueConstraintDecl, const ValueRangeConstraintDecl,
const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
const UserConstraintDecl, const ValueConstraintDecl,
const ValueRangeConstraintDecl, const NamedAttributeDecl,
const OpNameDecl, const PatternDecl, const UserRewriteDecl,
const VariableDecl,

const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
Expand Down Expand Up @@ -199,6 +206,11 @@ void NodePrinter::printImpl(const ReplaceStmt *stmt) {
printChildren("ReplValues", stmt->getReplExprs());
}

void NodePrinter::printImpl(const ReturnStmt *stmt) {
os << "ReturnStmt " << stmt << "\n";
printChildren(stmt->getResultExpr());
}

void NodePrinter::printImpl(const RewriteStmt *stmt) {
os << "RewriteStmt " << stmt << "\n";
printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
Expand All @@ -208,6 +220,14 @@ void NodePrinter::printImpl(const AttributeExpr *expr) {
os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
}

void NodePrinter::printImpl(const CallExpr *expr) {
os << "CallExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getCallableExpr());
printChildren("Arguments", expr->getArguments());
}

void NodePrinter::printImpl(const DeclRefExpr *expr) {
os << "DeclRefExpr " << expr << " Type<";
print(expr->getType());
Expand Down Expand Up @@ -265,6 +285,21 @@ void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
os << "TypeRangeConstraintDecl " << decl << "\n";
}

void NodePrinter::printImpl(const UserConstraintDecl *decl) {
os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
<< "> ResultType<" << decl->getResultType() << ">";
if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
os << " Code<";
llvm::printEscapedString(*codeBlock, os);
os << ">";
}
os << "\n";
printChildren("Inputs", decl->getInputs());
printChildren("Results", decl->getResults());
if (const CompoundStmt *body = decl->getBody())
printChildren(body);
}

void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
os << "ValueConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
Expand Down Expand Up @@ -303,6 +338,21 @@ void NodePrinter::printImpl(const PatternDecl *decl) {
printChildren(decl->getBody());
}

void NodePrinter::printImpl(const UserRewriteDecl *decl) {
os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
<< "> ResultType<" << decl->getResultType() << ">";
if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
os << " Code<";
llvm::printEscapedString(*codeBlock, os);
os << ">";
}
os << "\n";
printChildren("Inputs", decl->getInputs());
printChildren("Results", decl->getResults());
if (const CompoundStmt *body = decl->getBody())
printChildren(body);
}

void NodePrinter::printImpl(const VariableDecl *decl) {
os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
<< "> Type<";
Expand Down
75 changes: 75 additions & 0 deletions mlir/lib/Tools/PDLL/AST/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
RewriteStmt(loc, rootOp, rewriteBody);
}

//===----------------------------------------------------------------------===//
// ReturnStmt
//===----------------------------------------------------------------------===//

ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
return new (ctx.getAllocator().Allocate<ReturnStmt>())
ReturnStmt(loc, resultExpr);
}

//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//
Expand All @@ -118,6 +127,22 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
}

//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//

CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
ArrayRef<Expr *> arguments, Type resultType) {
unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));

CallExpr *expr =
new (rawData) CallExpr(loc, resultType, callable, arguments.size());
std::uninitialized_copy(arguments.begin(), arguments.end(),
expr->getArguments().begin());
return expr;
}

//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -267,6 +292,30 @@ ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
ValueRangeConstraintDecl(loc, typeExpr);
}

//===----------------------------------------------------------------------===//
// UserConstraintDecl
//===----------------------------------------------------------------------===//

UserConstraintDecl *UserConstraintDecl::createImpl(
Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
const CompoundStmt *body, Type resultType) {
unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>(
inputs.size() + results.size());
void *rawData =
ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
if (codeBlock)
codeBlock = codeBlock->copy(ctx.getAllocator());

UserConstraintDecl *decl = new (rawData) UserConstraintDecl(
name, inputs.size(), results.size(), codeBlock, body, resultType);
std::uninitialized_copy(inputs.begin(), inputs.end(),
decl->getInputs().begin());
std::uninitialized_copy(results.begin(), results.end(),
decl->getResults().begin());
return decl;
}

//===----------------------------------------------------------------------===//
// NamedAttributeDecl
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -300,6 +349,32 @@ PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
}

//===----------------------------------------------------------------------===//
// UserRewriteDecl
//===----------------------------------------------------------------------===//

UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
ArrayRef<VariableDecl *> inputs,
ArrayRef<VariableDecl *> results,
Optional<StringRef> codeBlock,
const CompoundStmt *body,
Type resultType) {
unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
inputs.size() + results.size());
void *rawData =
ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
if (codeBlock)
codeBlock = codeBlock->copy(ctx.getAllocator());

UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
name, inputs.size(), results.size(), codeBlock, body, resultType);
std::uninitialized_copy(inputs.begin(), inputs.end(),
decl->getInputs().begin());
std::uninitialized_copy(results.begin(), results.end(),
decl->getResults().begin());
return decl;
}

//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Tools/PDLL/AST/TypeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
using Base::Base;
};

//===----------------------------------------------------------------------===//
// RewriteType
//===----------------------------------------------------------------------===//

struct RewriteTypeStorage : public TypeStorageBase<RewriteTypeStorage> {};

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Tools/PDLL/AST/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ ValueRangeType ValueRangeType::get(Context &context) {
.cast<ValueRangeType>();
}

//===----------------------------------------------------------------------===//
// RewriteType
//===----------------------------------------------------------------------===//

RewriteType RewriteType::get(Context &context) {
return context.getTypeUniquer().get<ImplTy>();
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,9 @@ Token Lexer::lexIdentifier(const char *tokStart) {
.Case("OpName", Token::kw_OpName)
.Case("Pattern", Token::kw_Pattern)
.Case("replace", Token::kw_replace)
.Case("return", Token::kw_return)
.Case("rewrite", Token::kw_rewrite)
.Case("Rewrite", Token::kw_Rewrite)
.Case("type", Token::kw_type)
.Case("Type", Token::kw_Type)
.Case("TypeRange", Token::kw_TypeRange)
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ class Token {
kw_OpName,
kw_Pattern,
kw_replace,
kw_return,
kw_rewrite,
kw_Rewrite,
kw_Type,
kw_TypeRange,
kw_Value,
Expand Down
872 changes: 812 additions & 60 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp

Large diffs are not rendered by default.

160 changes: 160 additions & 0 deletions mlir/test/mlir-pdll/Parser/constraint-failure.pdll
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s

//===----------------------------------------------------------------------===//
// Constraint Structure
//===----------------------------------------------------------------------===//

// CHECK: expected identifier name
Constraint {}

// -----

// CHECK: :6:12: error: `Foo` has already been defined
// CHECK: :5:12: note: see previous definition here
Constraint Foo() { op<>; }
Constraint Foo() { op<>; }

// -----

Constraint Foo() {
// CHECK: `erase` cannot be used within a Constraint
erase op<>;
}

// -----

Constraint Foo() {
// CHECK: `replace` cannot be used within a Constraint
replace;
}

// -----

Constraint Foo() {
// CHECK: `rewrite` cannot be used within a Constraint
rewrite;
}

// -----

Constraint Foo() -> Value {
// CHECK: `return` terminated the `Constraint` body, but found trailing statements afterwards
return _: Value;
return _: Value;
}

// -----

// CHECK: missing return in a `Constraint` expected to return `Value`
Constraint Foo() -> Value {
let value: Value;
}

// -----

// CHECK: expected `Constraint` lambda body to contain a single expression
Constraint Foo() -> Value => let foo: Value;

// -----

// CHECK: unable to convert expression of type `Op` to the expected type of `Attr`
Constraint Foo() -> Attr => op<>;

// -----

Rewrite SomeRewrite();

// CHECK: unable to invoke `Rewrite` within a match section
Constraint Foo() {
SomeRewrite();
}

// -----

Constraint Foo() {
Constraint Foo() {};
}

// -----

//===----------------------------------------------------------------------===//
// Arguments
//===----------------------------------------------------------------------===//

// CHECK: expected `(` to start argument list
Constraint Foo {}

// -----

// CHECK: expected identifier argument name
Constraint Foo(10{}

// -----

// CHECK: expected `:` before argument constraint
Constraint Foo(arg{}

// -----

// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
Constraint Foo(arg: Value<type>){}

// -----

// CHECK: expected `)` to end argument list
Constraint Foo(arg: Value{}

// -----

//===----------------------------------------------------------------------===//
// Results
//===----------------------------------------------------------------------===//

// CHECK: expected identifier constraint
Constraint Foo() -> {}

// -----

// CHECK: cannot create a single-element tuple with an element label
Constraint Foo() -> result: Value;

// -----

// CHECK: cannot create a single-element tuple with an element label
Constraint Foo() -> (result: Value);

// -----

// CHECK: expected identifier constraint
Constraint Foo() -> ();

// -----

// CHECK: expected `:` before result constraint
Constraint Foo() -> (result{};

// -----

// CHECK: expected `)` to end result list
Constraint Foo() -> (Op{};

// -----

// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
Constraint Foo() -> Value<type>){}

// -----

//===----------------------------------------------------------------------===//
// Native Constraints
//===----------------------------------------------------------------------===//

Pattern {
// CHECK: external declarations must be declared in global scope
Constraint ExternalConstraint();
}

// -----

// CHECK: expected `;` after native declaration
Constraint Foo() [{}]
74 changes: 74 additions & 0 deletions mlir/test/mlir-pdll/Parser/constraint.pdll
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s

// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<>>
Constraint Foo();

// -----

// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<>> Code< /* Native Code */ >
Constraint Foo() [{ /* Native Code */ }];

// -----

// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Value>
// CHECK: `Results`
// CHECK: `-VariableDecl {{.*}} Name<> Type<Value>
// CHECK: `-CompoundStmt {{.*}}
// CHECK: `-ReturnStmt {{.*}}
// CHECK: `-DeclRefExpr {{.*}} Type<Value>
// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Value>
Constraint Foo(arg: Value) -> Value => arg;

// -----

// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<result1: Value, result2: Attr>>
// CHECK: `Results`
// CHECK: |-VariableDecl {{.*}} Name<result1> Type<Value>
// CHECK: | `Constraints`
// CHECK: | `-ValueConstraintDecl {{.*}}
// CHECK: `-VariableDecl {{.*}} Name<result2> Type<Attr>
// CHECK: `Constraints`
// CHECK: `-AttrConstraintDecl {{.*}}
// CHECK: `-CompoundStmt {{.*}}
// CHECK: `-ReturnStmt {{.*}}
// CHECK: `-TupleExpr {{.*}} Type<Tuple<result1: Value, result2: Attr>>
// CHECK: |-MemberAccessExpr {{.*}} Member<0> Type<Value>
// CHECK: | `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
// CHECK: `-MemberAccessExpr {{.*}} Member<1> Type<Attr>
// CHECK: `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
Constraint Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">);

// -----

// CHECK: Module
// CHECK: |-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Value>
// CHECK: `Constraints`
// CHECK: `-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
// CHECK: `Results`
// CHECK: `-VariableDecl {{.*}} Name<> Type<Value>
// CHECK: `Constraints`
// CHECK: `-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
Constraint Bar(input: Value);

Constraint Foo(arg: Bar) -> Bar => arg;

// -----

// Test that anonymous constraints are uniquely named.

// CHECK: Module
// CHECK: UserConstraintDecl {{.*}} Name<<anonymous_constraint_0>> ResultType<Tuple<>>
// CHECK: UserConstraintDecl {{.*}} Name<<anonymous_constraint_1>> ResultType<Attr>
Constraint Outer() {
Constraint() {};
Constraint() => attr<"10">;
}
59 changes: 59 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,45 @@ Pattern {

// -----

//===----------------------------------------------------------------------===//
// Call Expr
//===----------------------------------------------------------------------===//

Constraint foo(value: Value);

Pattern {
// CHECK: expected `)` after argument list
foo(_: Value{};
}

// -----

Pattern {
// CHECK: expected a reference to a callable `Constraint` or `Rewrite`, but got: `Op`
let foo: Op;
foo();
}

// -----

Constraint Foo();

Pattern {
// CHECK: invalid number of arguments for constraint call; expected 0, but got 1
Foo(_: Value);
}

// -----

Constraint Foo(arg: Value);

Pattern {
// CHECK: unable to convert expression of type `Attr` to the expected type of `Value`
Foo(attr<"i32">);
}

// -----

//===----------------------------------------------------------------------===//
// Member Access Expr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -105,6 +144,26 @@ Pattern {

// -----

Constraint Foo();

Pattern {
// CHECK: unable to build a tuple with `Constraint` element
let tuple = (Foo);
erase op<>;
}

// -----

Rewrite Foo();

Pattern {
// CHECK: unable to build a tuple with `Rewrite` element
let tuple = (Foo);
erase op<>;
}

// -----

Pattern {
// CHECK: expected expression
let tuple = (10 = _: Value);
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,42 @@ Pattern {

// -----

//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//

// CHECK: Module
// CHECK: |-UserConstraintDecl {{.*}} Name<MakeRootOp> ResultType<Op<my_dialect.foo>>
// CHECK: `-CallExpr {{.*}} Type<Op<my_dialect.foo>>
// CHECK: `-DeclRefExpr {{.*}} Type<Constraint>
// CHECK: `-UserConstraintDecl {{.*}} Name<MakeRootOp> ResultType<Op<my_dialect.foo>>
Constraint MakeRootOp() => op<my_dialect.foo>;

Pattern {
erase MakeRootOp();
}

// -----

// CHECK: Module
// CHECK: |-UserRewriteDecl {{.*}} Name<CreateNewOp> ResultType<Op<my_dialect.foo>>
// CHECK: `-PatternDecl {{.*}}
// CHECK: `-CallExpr {{.*}} Type<Op<my_dialect.foo>>
// CHECK: `-DeclRefExpr {{.*}} Type<Rewrite>
// CHECK: `-UserRewriteDecl {{.*}} Name<CreateNewOp> ResultType<Op<my_dialect.foo>>
// CHECK: `Arguments`
// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<my_dialect.bar>>
// CHECK: `-VariableDecl {{.*}} Name<inputOp> Type<Op<my_dialect.bar>>
Rewrite CreateNewOp(inputs: ValueRange) => op<my_dialect.foo>(inputs);

Pattern {
let inputOp = op<my_dialect.bar>;
replace op<my_dialect.bar>(inputOp) with CreateNewOp(inputOp);
}

// -----

//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
Expand Down
25 changes: 23 additions & 2 deletions mlir/test/mlir-pdll/Parser/pattern-failure.pdll
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s

// CHECK: expected `{` to start pattern body
// CHECK: expected `{` or `=>` to start pattern body
Pattern }

// -----
Expand All @@ -12,6 +12,13 @@ Pattern Foo { erase root: Op; }

// -----

// CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body
Pattern {
return _: Value;
}

// -----

// CHECK: expected Pattern body to terminate with an operation rewrite statement
Pattern {
let value: Value;
Expand All @@ -27,6 +34,20 @@ Pattern {

// -----

// CHECK: expected Pattern lambda body to contain a single operation rewrite statement, such as `erase`, `replace`, or `rewrite`
Pattern => op<>;

// -----

Rewrite SomeRewrite();

// CHECK: unable to invoke `Rewrite` within a match section
Pattern {
SomeRewrite();
}

// -----

//===----------------------------------------------------------------------===//
// Metadata
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/mlir-pdll/Parser/pattern.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ Pattern NamedPattern {
Pattern NamedPattern with benefit(10), recursion {
erase _: Op;
}

// -----

// CHECK: Module
// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
// CHECK: `-CompoundStmt
// CHECK: `-EraseStmt
Pattern NamedPattern => erase _: Op;
161 changes: 161 additions & 0 deletions mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s

//===----------------------------------------------------------------------===//
// Rewrite Structure
//===----------------------------------------------------------------------===//

// CHECK: expected identifier name
Rewrite {}

// -----

// CHECK: :6:9: error: `Foo` has already been defined
// CHECK: :5:9: note: see previous definition here
Rewrite Foo();
Rewrite Foo();

// -----

Rewrite Foo() -> Value {
// CHECK: `return` terminated the `Rewrite` body, but found trailing statements afterwards
return _: Value;
return _: Value;
}

// -----

// CHECK: missing return in a `Rewrite` expected to return `Value`
Rewrite Foo() -> Value {
let value: Value;
}

// -----

// CHECK: missing return in a `Rewrite` expected to return `Value`
Rewrite Foo() -> Value => erase op<my_dialect.foo>;

// -----

// CHECK: unable to convert expression of type `Op<my_dialect.foo>` to the expected type of `Attr`
Rewrite Foo() -> Attr => op<my_dialect.foo>;

// -----

// CHECK: expected `Rewrite` lambda body to contain a single expression or an operation rewrite statement; such as `erase`, `replace`, or `rewrite`
Rewrite Foo() => let foo = op<my_dialect.foo>;

// -----

Constraint ValueConstraint(value: Value);

// CHECK: unable to invoke `Constraint` within a rewrite section
Rewrite Foo(value: Value) {
ValueConstraint(value);
}

// -----

Rewrite Bar();

// CHECK: `Bar` has already been defined
Rewrite Foo() {
Rewrite Bar() {};
}

// -----

//===----------------------------------------------------------------------===//
// Arguments
//===----------------------------------------------------------------------===//

// CHECK: expected `(` to start argument list
Rewrite Foo {}

// -----

// CHECK: expected identifier argument name
Rewrite Foo(10{}

// -----

// CHECK: expected `:` before argument constraint
Rewrite Foo(arg{}

// -----

// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
Rewrite Foo(arg: Value<type>){}

// -----

Constraint ValueConstraint(value: Value);

// CHECK: arguments and results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
Rewrite Foo(arg: ValueConstraint);

// -----

// CHECK: expected `)` to end argument list
Rewrite Foo(arg: Value{}

// -----

//===----------------------------------------------------------------------===//
// Results
//===----------------------------------------------------------------------===//

// CHECK: expected identifier constraint
Rewrite Foo() -> {}

// -----

// CHECK: cannot create a single-element tuple with an element label
Rewrite Foo() -> result: Value;

// -----

// CHECK: cannot create a single-element tuple with an element label
Rewrite Foo() -> (result: Value);

// -----

// CHECK: expected identifier constraint
Rewrite Foo() -> ();

// -----

// CHECK: expected `:` before result constraint
Rewrite Foo() -> (result{};

// -----

// CHECK: expected `)` to end result list
Rewrite Foo() -> (Op{};

// -----

// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
Rewrite Foo() -> Value<type>){}

// -----

Constraint ValueConstraint(value: Value);

// CHECK: results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
Rewrite Foo() -> ValueConstraint;

// -----

//===----------------------------------------------------------------------===//
// Native Rewrites
//===----------------------------------------------------------------------===//

Pattern {
// CHECK: external declarations must be declared in global scope
Rewrite ExternalConstraint();
}

// -----

// CHECK: expected `;` after native declaration
Rewrite Foo() [{}]
58 changes: 58 additions & 0 deletions mlir/test/mlir-pdll/Parser/rewrite.pdll
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s

// CHECK: Module
// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<>>
Rewrite Foo();

// -----

// CHECK: Module
// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<>> Code< /* Native Code */ >
Rewrite Foo() [{ /* Native Code */ }];

// -----

// CHECK: Module
// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Value>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Op>
// CHECK: `Results`
// CHECK: `-VariableDecl {{.*}} Name<> Type<Value>
// CHECK: `-CompoundStmt {{.*}}
// CHECK: `-ReturnStmt {{.*}}
// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<Value>
// CHECK: `-DeclRefExpr {{.*}} Type<Op>
// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Op>
Rewrite Foo(arg: Op) -> Value => arg;

// -----

// CHECK: Module
// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<result1: Value, result2: Attr>>
// CHECK: `Results`
// CHECK: |-VariableDecl {{.*}} Name<result1> Type<Value>
// CHECK: | `Constraints`
// CHECK: | `-ValueConstraintDecl {{.*}}
// CHECK: `-VariableDecl {{.*}} Name<result2> Type<Attr>
// CHECK: `Constraints`
// CHECK: `-AttrConstraintDecl {{.*}}
// CHECK: `-CompoundStmt {{.*}}
// CHECK: `-ReturnStmt {{.*}}
// CHECK: `-TupleExpr {{.*}} Type<Tuple<result1: Value, result2: Attr>>
// CHECK: |-MemberAccessExpr {{.*}} Member<0> Type<Value>
// CHECK: | `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
// CHECK: `-MemberAccessExpr {{.*}} Member<1> Type<Attr>
// CHECK: `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
Rewrite Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">);

// -----

// Test that anonymous Rewrites are uniquely named.

// CHECK: Module
// CHECK: UserRewriteDecl {{.*}} Name<<anonymous_rewrite_0>> ResultType<Tuple<>>
// CHECK: UserRewriteDecl {{.*}} Name<<anonymous_rewrite_1>> ResultType<Attr>
Rewrite Outer() {
Rewrite() {};
Rewrite() => attr<"10">;
}
47 changes: 47 additions & 0 deletions mlir/test/mlir-pdll/Parser/stmt-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,33 @@ Pattern {

// -----

Constraint Foo();

Pattern {
// CHECK: unable to define variable of `Constraint` type
let foo = Foo;
}

// -----

Rewrite Foo();

Pattern {
// CHECK: unable to define variable of `Rewrite` type
let foo = Foo;
}

// -----

Constraint MultiConstraint(arg1: Value, arg2: Value);

Pattern {
// CHECK: `Constraint`s applied via a variable constraint list must take a single input, but got 2
let foo: MultiConstraint;
}

// -----

//===----------------------------------------------------------------------===//
// `replace`
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -276,6 +303,17 @@ Pattern {

// -----

//===----------------------------------------------------------------------===//
// `return`
//===----------------------------------------------------------------------===//

// CHECK: expected `;` after statement
Constraint Foo(arg: Value) -> Value {
return arg
}

// -----

//===----------------------------------------------------------------------===//
// `rewrite`
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -307,3 +345,12 @@ Pattern {
op<>;
};
}

// -----

Pattern {
// CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body
rewrite root: Op with {
return root;
};
}