diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 43a0bdaf86cf3..1d0849e479d37 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1872,6 +1872,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> { let results = (outs I32:$output); } +def TestEitherOpC : TEST_Op<"either_op_c"> { + let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1); + let results = (outs I32:$output); +} + def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x), (TestEitherOpB $arg2, $x)>; @@ -1883,6 +1888,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_), $x), (TestEitherOpB $arg2, $x)>; +def : Pat<(TestEitherOpC ConstantAttr, (either $arg1, I32:$arg2)), + (TestEitherOpB $arg1, $arg2)>; + def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> { let arguments = (ins I32:$arg0); let results = (outs I32:$output); diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 90905280c0796..27598fb63a6c8 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -609,8 +609,8 @@ func.func @redundantTest(%arg0: i32) -> i32 { // Test either directive //===----------------------------------------------------------------------===// -// CHECK: @either_dag_leaf_only -func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { +// CHECK-LABEL: @eitherDagLeafOnly +func.func @eitherDagLeafOnly(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32 // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 @@ -618,8 +618,8 @@ func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { return } -// CHECK: @either_dag_leaf_dag_node -func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { +// CHECK-LABEL: @eitherDagLeafDagNode +func.func @eitherDagLeafDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32 // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32 @@ -628,8 +628,8 @@ func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () return } -// CHECK: @either_dag_node_dag_node -func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { +// CHECK-LABEL: @eitherDagNodeDagNode +func.func @eitherDagNodeDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32 %1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32 // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 @@ -639,10 +639,22 @@ func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () return } +// CHECK-LABEL: @testEitherOpWithAttr +func.func @testEitherOpWithAttr(%arg0 : i32, %arg1 : i16) -> () { + // CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32 + %0 = "test.either_op_c"(%arg0, %arg1) {attr = 0 : i32} : (i32, i16) -> i32 + // CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32 + %1 = "test.either_op_c"(%arg1, %arg0) {attr = 0 : i32} : (i16, i32) -> i32 + // CHECK: "test.either_op_c"(%arg0, %arg1) <{attr = 1 : i32}> : (i32, i16) -> i32 + %2 = "test.either_op_c"(%arg0, %arg1) {attr = 1 : i32} : (i32, i16) -> i32 + return +} + //===----------------------------------------------------------------------===// // Test that ops without type deduction can be created with type builders. //===----------------------------------------------------------------------===// +// CHECK-LABEL: @explicitReturnTypeTest func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 { %0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8 // CHECK: "test.op_x"(%arg0) : (i64) -> i32 @@ -650,6 +662,7 @@ func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 { return %0 : i8 } +// CHECK-LABEL: @returnTypeBuilderTest func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 { %0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8 // CHECK: "test.op_x"(%arg0) : (i1) -> i1 @@ -657,6 +670,7 @@ func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 { return %0 : i8 } +// CHECK-LABEL: @multipleReturnTypeBuildTest func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 { %0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1 // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32) @@ -666,6 +680,7 @@ func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 { return %0 : i1 } +// CHECK-LABEL: @copyValueType func.func @copyValueType(%arg0 : i8) -> i32 { %0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32 // CHECK: "test.op_x"(%arg0) : (i8) -> i8 @@ -673,6 +688,7 @@ func.func @copyValueType(%arg0 : i8) -> i32 { return %0 : i32 } +// CHECK-LABEL: @multipleReturnTypeDifferent func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 { %0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64 // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64) @@ -684,6 +700,7 @@ func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 { // Test that multiple trailing directives can be mixed in patterns. //===----------------------------------------------------------------------===// +// CHECK-LABEL: @returnTypeAndLocation func.func @returnTypeAndLocation(%arg0 : i32) -> i1 { %0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1 // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1") @@ -696,6 +713,7 @@ func.func @returnTypeAndLocation(%arg0 : i32) -> i1 { // Test that patterns can create ConstantStrAttr //===----------------------------------------------------------------------===// +// CHECK-LABEL: @testConstantStrAttr func.func @testConstantStrAttr() -> () { // CHECK: test.has_str_value {value = "foo"} test.no_str_value {value = "bar"} @@ -706,6 +724,7 @@ func.func @testConstantStrAttr() -> () { // Test that patterns with variadics propagate sizes //===----------------------------------------------------------------------===// +// CHECK-LABEL: @testVariadic func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64, %crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () { // CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array}> : (i64, f32, f32, f32, f32, i32, i32) -> () diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 58abcc2bee895..75721c89793b5 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { if (isa(opArg)) { auto operandName = formatv("{0}.getODSOperands({1})", castedName, nextOperand); - emitOperandMatch(tree, castedName, operandName.str(), opArgIdx, + emitOperandMatch(tree, castedName, operandName.str(), nextOperand, /*operandMatcher=*/tree.getArgAsLeaf(i), /*argName=*/tree.getArgName(i), opArgIdx, /*variadicSubIndex=*/std::nullopt); @@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, int argIndex, std::optional variadicSubIndex) { Operator &op = tree.getDialectOp(opMap); - auto *operand = cast(op.getArg(operandIndex)); + NamedTypeConstraint operand = op.getOperand(operandIndex); // If a constraint is specified, we need to generate C++ statements to // check the constraint. @@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, // Only need to verify if the matcher's type is different from the one // of op definition. Constraint constraint = operandMatcher.getAsConstraint(); - if (operand->constraint != constraint) { - if (operand->isVariableLength()) { + if (operand.constraint != constraint) { + if (operand.isVariableLength()) { auto error = formatv( "further constrain op {0}'s variadic operand #{1} unsupported now", op.getOperationName(), argIndex); @@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, verifier, opName, self.str(), formatv( "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"", - operand - op.operand_begin(), op.getOperationName(), + operandIndex, op.getOperationName(), escapeString(constraint.getSummary())) .str()); } @@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, // Capture the value // `$_` is a special symbol to ignore op argument matching. if (!argName.empty() && argName != "_") { - auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex, + auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex, variadicSubIndex); if (res == symbolInfoMap.end()) PrintFatalError(loc, formatv("symbol not found: {0}", argName)); @@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree, StringRef variadicTreeName = variadicArgTree.getSymbol(); if (!variadicTreeName.empty()) { auto res = - symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex, + symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex, /*variadicSubIndex=*/std::nullopt); if (res == symbolInfoMap.end()) PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));