diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 49b2dae62dc22..52069278c5eea 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -433,8 +433,9 @@ class SymbolInfoMap { DagAndConstant(node.getAsOpaquePointer(), operandIndex, variadicSubIndex)); } - static SymbolInfo getResult(const Operator *op) { - return SymbolInfo(op, Kind::Result, std::nullopt); + static SymbolInfo getResult(const Operator *op, int index) { + return SymbolInfo(op, Kind::Result, + DagAndConstant(nullptr, index, std::nullopt)); } static SymbolInfo getValue() { return SymbolInfo(nullptr, Kind::Value, std::nullopt); diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 1a1a58ad271bb..b0dd0451e05e8 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -370,6 +370,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( case Kind::Result: { // If `index` is greater than zero, then we are referencing a specific // result of a multi-result op. The result can still be variadic. + if (index < 0) + index = dagAndConstant->operandIndexOrNumValues; if (index >= 0) { std::string v = std::string(formatv("{0}.getODSResults({1})", name, index)); @@ -442,6 +444,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( return std::string(repl); } case Kind::Result: { + if (index < 0) + index = dagAndConstant->operandIndexOrNumValues; if (index >= 0) { auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); @@ -522,8 +526,10 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, } bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { - std::string name = getValuePackName(symbol).str(); - auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); + int index = -1; + StringRef name = getValuePackName(symbol, &index); + auto inserted = + symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index)); return symbolInfoMap.count(inserted->first) == 1; } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 5564264ed8b0b..4bd68a5801bac 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1660,6 +1660,16 @@ def OneResultOp3 : TEST_Op<"one_result3"> { let results = (outs I32:$result1); } +def OneResultOp4 : TEST_Op<"one_result4"> { + let arguments = (ins F32); + let results = (outs F32); +} + +def TwoResultOp2 : TEST_Op<"two_result2"> { + let arguments = (ins); + let results = (outs F32, F32); +} + // Test using multi-result op as a whole def : Pat<(ThreeResultOp MultiResultOpKind1:$kind), (AnotherThreeResultOp $kind)>; @@ -1696,6 +1706,12 @@ def : Pattern< (AnotherTwoResultOp $kind) ]>; +// Test referencing a one-param op whose +// param comes from the first result of a two-result op. +def : Pat< + (OneResultOp4 (TwoResultOp2:$a__1)), + (replaceWithValue $a__0)>; + //===----------------------------------------------------------------------===// // Test Patterns (Variadic Ops) //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index bd55338618eec..cedf528fb8717 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -594,6 +594,25 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar return } +// CHECK-LABEL: @replaceOneResultWithNSuffixArgMatch +func.func @replaceOneResultWithNSuffixArgMatch() -> (f32) { + // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32) + %0:2 = "test.two_result2"() : () -> (f32, f32) + %1 = "test.one_result4"(%0#1) : (f32) -> (f32) + // CHECK: return %0#0 : f32 + return %1 : f32 +} + +// CHECK-LABEL: @replaceOneResultWithNSuffixArgNoMatch +func.func @replaceOneResultWithNSuffixArgNoMatch() -> (f32) { + // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32) + %0:2 = "test.two_result2"() : () -> (f32, f32) + // CHECK: %1 = "test.one_result4"(%0#0) : (f32) -> f32 + %1 = "test.one_result4"(%0#0) : (f32) -> (f32) + // CHECK: return %1 : f32 + return %1 : f32 +} + //===----------------------------------------------------------------------===// // Test patterns that operate on properties //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 605033daa719f..75cf7232e729c 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -615,10 +615,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { op.getQualCppClassName())); // If the operand's name is set, set to that variable. - auto name = tree.getSymbol(); + int index = -1; + auto name = SymbolInfoMap::getValuePackName(tree.getSymbol(), &index).str(); if (!name.empty()) os << formatv("{0} = {1};\n", name, castedName); + if (index != -1) { + emitMatchCheck(opName, formatv("(resultNumber{0} == 1)", depth), + formatv("\"{0} does not come from result number {1} type\"", + castedName, index)); + } + for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i, ++opArgIdx) { auto opArg = op.getArg(opArgIdx); @@ -662,6 +669,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { "auto *{0} = " "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", argName, castedName, nextOperand); + os.indent() << formatv( + "[[maybe_unused]] auto resultNumber{0} = " + "::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin()" + ")).getResultNumber();\n", + depth + 1, castedName, nextOperand); // Null check of operand's definingOp emitMatchCheck( castedName, /*matchStr=*/argName,