Skip to content

Commit

Permalink
[mlir:PDLL] Don't require users to provide operands/results when all …
Browse files Browse the repository at this point in the history
…are variadic

When all operands or results are variadic, zero values is a perfectly valid behavior
to expect, and we shouldn't force the user to provide values in this case. For example,
when creating a call or a return operation we often don't want/need to provide return
values.

Differential Revision: https://reviews.llvm.org/D133721
  • Loading branch information
River707 committed Nov 8, 2022
1 parent 9e57210 commit ec92a12
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 16 deletions.
68 changes: 55 additions & 13 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Expand Up @@ -426,23 +426,23 @@ class Parser {
FailureOr<ast::OperationExpr *>
createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
OpResultTypeContext resultTypeContext,
MutableArrayRef<ast::Expr *> operands,
SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
MutableArrayRef<ast::Expr *> results);
SmallVectorImpl<ast::Expr *> &results);
LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> operands);
SmallVectorImpl<ast::Expr *> &operands);
LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> results);
SmallVectorImpl<ast::Expr *> &results);
void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
const ods::Operation *odsOp);
LogicalResult validateOperationOperandsOrResults(
StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
ast::Type rangeTy);
ast::RangeType rangeTy);
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames);
Expand Down Expand Up @@ -2851,9 +2851,9 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
SMRange loc, const ast::OpNameDecl *name,
OpResultTypeContext resultTypeContext,
MutableArrayRef<ast::Expr *> operands,
SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
MutableArrayRef<ast::Expr *> results) {
SmallVectorImpl<ast::Expr *> &results) {
Optional<StringRef> opNameRef = name->getName();
const ods::Operation *odsOp = lookupODSOperation(opNameRef);

Expand Down Expand Up @@ -2896,7 +2896,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
LogicalResult
Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> operands) {
SmallVectorImpl<ast::Expr *> &operands) {
return validateOperationOperandsOrResults(
"operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
Expand All @@ -2906,7 +2906,7 @@ Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
LogicalResult
Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> results) {
SmallVectorImpl<ast::Expr *> &results) {
return validateOperationOperandsOrResults(
"result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
Expand Down Expand Up @@ -2956,9 +2956,9 @@ void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,

LogicalResult Parser::validateOperationOperandsOrResults(
StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
ast::Type rangeTy) {
ast::RangeType rangeTy) {
// All operation types accept a single range parameter.
if (values.size() == 1) {
if (failed(convertExpressionTo(values[0], rangeTy)))
Expand All @@ -2969,14 +2969,56 @@ LogicalResult Parser::validateOperationOperandsOrResults(
/// If the operation has ODS information, we can more accurately verify the
/// values.
if (odsOpLoc) {
if (odsValues.size() != values.size()) {
auto emitSizeMismatchError = [&] {
return emitErrorAndNote(
loc,
llvm::formatv("invalid number of {0} groups for `{1}`; expected "
"{2}, but got {3}",
groupName, *name, odsValues.size(), values.size()),
*odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
};

// Handle the case where no values were provided.
if (values.empty()) {
// If we don't expect any on the ODS side, we are done.
if (odsValues.empty())
return success();

// If we do, check if we actually need to provide values (i.e. if any of
// the values are actually required).
unsigned numVariadic = 0;
for (const auto &odsValue : odsValues) {
if (!odsValue.isVariableLength())
return emitSizeMismatchError();
++numVariadic;
}

// If we are in a non-rewrite context, we don't need to do anything more.
// Zero-values is a valid constraint on the operation.
if (parserContext != ParserContext::Rewrite)
return success();

// Otherwise, when in a rewrite we may need to provide values to match the
// ODS signature of the operation to create.

// If we only have one variadic value, just use an empty list.
if (numVariadic == 1)
return success();

// Otherwise, create dummy values for each of the entries so that we
// adhere to the ODS signature.
for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
values.push_back(
ast::RangeExpr::create(ctx, loc, /*elements=*/llvm::None, rangeTy));
}
return success();
}

// Verify that the number of values provided matches the number of value
// groups ODS expects.
if (odsValues.size() != values.size())
return emitSizeMismatchError();

auto diagFn = [&](ast::Diagnostic &diag) {
diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
*odsOpLoc);
Expand Down
5 changes: 2 additions & 3 deletions mlir/test/lib/Transforms/TestDialectConversion.pdll
Expand Up @@ -10,9 +10,8 @@
#include "mlir/Transforms/DialectConversion.pdll"

/// Change the result type of a producer.
// FIXME: We shouldn't need to specify arguments for the result cast.
Pattern => replace op<test.cast>(args: ValueRange) -> (results: TypeRange)
with op<test.cast>(args) -> (convertTypes(results));
Pattern => replace op<test.cast> -> (results: TypeRange)
with op<test.cast> -> (convertTypes(results));

/// Pass through test.return conversion.
Pattern => replace op<test.return>(args: ValueRange)
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Expand Up @@ -213,6 +213,34 @@ Pattern {

// -----

// Test that we don't need to provide values if all elements
// are optional.

#include "include/ops.td"

// CHECK: Module
// CHECK: -OperationExpr {{.*}} Type<Op<test.multi_variadic>>
// CHECK-NOT: `Operands`
// CHECK-NOT: `Result Types`
// CHECK: -OperationExpr {{.*}} Type<Op<test.all_variadic>>
// CHECK-NOT: `Operands`
// CHECK-NOT: `Result Types`
// CHECK: -OperationExpr {{.*}} Type<Op<test.multi_variadic>>
// CHECK: `Operands`
// CHECK: -RangeExpr {{.*}} Type<ValueRange>
// CHECK: -RangeExpr {{.*}} Type<ValueRange>
// CHECK: `Result Types`
// CHECK: -RangeExpr {{.*}} Type<TypeRange>
// CHECK: -RangeExpr {{.*}} Type<TypeRange>
Pattern {
rewrite op<test.multi_variadic>() -> () with {
op<test.all_variadic> -> ();
op<test.multi_variadic> -> ();
};
}

// -----

//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/mlir-pdll/Parser/include/ops.td
Expand Up @@ -28,3 +28,8 @@ def OpAllVariadic : Op<Test_Dialect, "all_variadic"> {
def OpMultipleSingleResult : Op<Test_Dialect, "multiple_single_result"> {
let results = (outs I64:$result, I64:$result2);
}

def OpMultiVariadic : Op<Test_Dialect, "multi_variadic"> {
let arguments = (ins Variadic<I64>:$operands, Variadic<I64>:$operand2);
let results = (outs Variadic<I64>:$results, Variadic<I64>:$results2);
}

0 comments on commit ec92a12

Please sign in to comment.