diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index a6c8ebbe767055..5264b953aaabcd 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -106,6 +106,10 @@ class Parser { FailureOr parseTopLevelDecl(); FailureOr parseNamedAttributeDecl(); + FailureOr + parseLambdaBody(function_ref processStatementFn, + bool expectTerminalSemicolon = true); + FailureOr parsePatternLambdaBody(); FailureOr parsePatternDecl(); LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); @@ -547,6 +551,36 @@ FailureOr Parser::parseNamedAttributeDecl() { return ast::NamedAttributeDecl::create(ctx, name, attrValue); } +FailureOr Parser::parseLambdaBody( + function_ref processStatementFn, + bool expectTerminalSemicolon) { + consumeToken(Token::equal_arrow); + + // Parse the single statement of the lambda body. + SMLoc bodyStartLoc = curToken.getStartLoc(); + pushDeclScope(); + FailureOr singleStatement = parseStmt(expectTerminalSemicolon); + bool failedToParse = + failed(singleStatement) || failed(processStatementFn(*singleStatement)); + popDeclScope(); + if (failedToParse) + return failure(); + + SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); + return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); +} + +FailureOr Parser::parsePatternLambdaBody() { + return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { + if (isa(statement)) + return success(); + return emitError( + statement->getLoc(), + "expected Pattern lambda body to contain a single operation " + "rewrite statement, such as `erase`, `replace`, or `rewrite`"); + }); +} + FailureOr Parser::parsePatternDecl() { SMRange loc = curToken.getLoc(); consumeToken(Token::kw_Pattern); @@ -568,29 +602,37 @@ FailureOr Parser::parsePatternDecl() { // Parse the pattern body. ast::CompoundStmt *body; - if (curToken.isNot(Token::l_brace)) - return emitError("expected `{` to start pattern body"); - FailureOr bodyResult = parseCompoundStmt(); - if (failed(bodyResult)) - return failure(); - body = *bodyResult; - - // Verify the body of the pattern. - auto bodyIt = body->begin(), bodyE = body->end(); - for (; bodyIt != bodyE; ++bodyIt) { - // Break when we've found the rewrite statement. - if (isa(*bodyIt)) - break; - } - if (bodyIt == bodyE) { - return emitError(loc, - "expected Pattern body to terminate with an operation " - "rewrite statement, such as `erase`"); - } - if (std::next(bodyIt) != bodyE) { - return emitError((*std::next(bodyIt))->getLoc(), - "Pattern body was terminated by an operation " - "rewrite statement, but found trailing statements"); + // Handle a lambda body. + if (curToken.is(Token::equal_arrow)) { + FailureOr bodyResult = parsePatternLambdaBody(); + if (failed(bodyResult)) + return failure(); + body = *bodyResult; + } else { + if (curToken.isNot(Token::l_brace)) + return emitError("expected `{` or `=>` to start pattern body"); + FailureOr bodyResult = parseCompoundStmt(); + if (failed(bodyResult)) + return failure(); + body = *bodyResult; + + // Verify the body of the pattern. + auto bodyIt = body->begin(), bodyE = body->end(); + for (; bodyIt != bodyE; ++bodyIt) { + // Break when we've found the rewrite statement. + if (isa(*bodyIt)) + break; + } + if (bodyIt == bodyE) { + return emitError(loc, + "expected Pattern body to terminate with an operation " + "rewrite statement, such as `erase`"); + } + if (std::next(bodyIt) != bodyE) { + return emitError((*std::next(bodyIt))->getLoc(), + "Pattern body was terminated by an operation " + "rewrite statement, but found trailing statements"); + } } return createPatternDecl(loc, name, metadata, body); diff --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll index caa084cda0b68e..42ea6578211205 100644 --- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll @@ -1,6 +1,6 @@ // RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s -// CHECK: expected `{` to start pattern body +// CHECK: expected `{` or `=>` to start pattern body Pattern } // ----- @@ -27,6 +27,11 @@ Pattern { // ----- +// CHECK: expected Pattern lambda body to contain a single operation rewrite statement, such as `erase`, `replace`, or `rewrite` +Pattern => op<>; + +// ----- + //===----------------------------------------------------------------------===// // Metadata //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/pattern.pdll b/mlir/test/mlir-pdll/Parser/pattern.pdll index 1a7851606213ec..f0b2046e4b1b84 100644 --- a/mlir/test/mlir-pdll/Parser/pattern.pdll +++ b/mlir/test/mlir-pdll/Parser/pattern.pdll @@ -23,3 +23,11 @@ Pattern NamedPattern { Pattern NamedPattern with benefit(10), recursion { erase _: Op; } + +// ----- + +// CHECK: Module +// CHECK: `-PatternDecl {{.*}} Name +// CHECK: `-CompoundStmt +// CHECK: `-EraseStmt +Pattern NamedPattern => erase _: Op;