From 736f47d03eccbd5afcf76a72c9cf923ac0098452 Mon Sep 17 00:00:00 2001 From: Martin Coll Date: Thu, 18 Sep 2025 20:49:43 +0000 Subject: [PATCH 1/2] Add support to TableGen source patterns to match multi-result values by index --- mlir/include/mlir/TableGen/Pattern.h | 5 +++-- mlir/lib/TableGen/Pattern.cpp | 9 +++++++-- mlir/test/lib/Dialect/Test/TestOps.td | 16 ++++++++++++++++ mlir/test/mlir-tblgen/pattern.mlir | 19 +++++++++++++++++++ mlir/tools/mlir-tblgen/RewriterGen.cpp | 13 ++++++++++++- 5 files changed, 57 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 49b2dae62dc22..52069278c5eea 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -433,8 +433,9 @@ class SymbolInfoMap { DagAndConstant(node.getAsOpaquePointer(), operandIndex, variadicSubIndex)); } - static SymbolInfo getResult(const Operator *op) { - return SymbolInfo(op, Kind::Result, std::nullopt); + static SymbolInfo getResult(const Operator *op, int index) { + return SymbolInfo(op, Kind::Result, + DagAndConstant(nullptr, index, std::nullopt)); } static SymbolInfo getValue() { return SymbolInfo(nullptr, Kind::Value, std::nullopt); diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 1a1a58ad271bb..38725050cefe8 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -370,6 +370,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( case Kind::Result: { // If `index` is greater than zero, then we are referencing a specific // result of a multi-result op. The result can still be variadic. + if (index < 0) + index = dagAndConstant->operandIndexOrNumValues; if (index >= 0) { std::string v = std::string(formatv("{0}.getODSResults({1})", name, index)); @@ -442,6 +444,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( return std::string(repl); } case Kind::Result: { + if (index < 0) + index = dagAndConstant->operandIndexOrNumValues; if (index >= 0) { auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); @@ -522,8 +526,9 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, } bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { - std::string name = getValuePackName(symbol).str(); - auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); + int index = -1; + StringRef name = getValuePackName(symbol, &index); + auto inserted = symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index)); return symbolInfoMap.count(inserted->first) == 1; } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 5564264ed8b0b..4bd68a5801bac 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1660,6 +1660,16 @@ def OneResultOp3 : TEST_Op<"one_result3"> { let results = (outs I32:$result1); } +def OneResultOp4 : TEST_Op<"one_result4"> { + let arguments = (ins F32); + let results = (outs F32); +} + +def TwoResultOp2 : TEST_Op<"two_result2"> { + let arguments = (ins); + let results = (outs F32, F32); +} + // Test using multi-result op as a whole def : Pat<(ThreeResultOp MultiResultOpKind1:$kind), (AnotherThreeResultOp $kind)>; @@ -1696,6 +1706,12 @@ def : Pattern< (AnotherTwoResultOp $kind) ]>; +// Test referencing a one-param op whose +// param comes from the first result of a two-result op. +def : Pat< + (OneResultOp4 (TwoResultOp2:$a__1)), + (replaceWithValue $a__0)>; + //===----------------------------------------------------------------------===// // Test Patterns (Variadic Ops) //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index bd55338618eec..cedf528fb8717 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -594,6 +594,25 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar return } +// CHECK-LABEL: @replaceOneResultWithNSuffixArgMatch +func.func @replaceOneResultWithNSuffixArgMatch() -> (f32) { + // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32) + %0:2 = "test.two_result2"() : () -> (f32, f32) + %1 = "test.one_result4"(%0#1) : (f32) -> (f32) + // CHECK: return %0#0 : f32 + return %1 : f32 +} + +// CHECK-LABEL: @replaceOneResultWithNSuffixArgNoMatch +func.func @replaceOneResultWithNSuffixArgNoMatch() -> (f32) { + // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32) + %0:2 = "test.two_result2"() : () -> (f32, f32) + // CHECK: %1 = "test.one_result4"(%0#0) : (f32) -> f32 + %1 = "test.one_result4"(%0#0) : (f32) -> (f32) + // CHECK: return %1 : f32 + return %1 : f32 +} + //===----------------------------------------------------------------------===// // Test patterns that operate on properties //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 605033daa719f..312f5174b7b9e 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -615,10 +615,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { op.getQualCppClassName())); // If the operand's name is set, set to that variable. - auto name = tree.getSymbol(); + int index = -1; + auto name = SymbolInfoMap::getValuePackName(tree.getSymbol(), &index).str(); if (!name.empty()) os << formatv("{0} = {1};\n", name, castedName); + if (index != -1) { + emitMatchCheck(opName, + formatv("(resultNumber{0} == 1)", depth), + formatv("\"{0} does not come from result number {1} type\"", castedName, index)); + } + for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i, ++opArgIdx) { auto opArg = op.getArg(opArgIdx); @@ -662,6 +669,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { "auto *{0} = " "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", argName, castedName, nextOperand); + os.indent() << formatv( + "[[maybe_unused]] auto resultNumber{0} = " + "::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin())).getResultNumber();\n", + depth + 1, castedName, nextOperand); // Null check of operand's definingOp emitMatchCheck( castedName, /*matchStr=*/argName, From 854dbd4a8430ed406366d6a4553e186a3a1287a9 Mon Sep 17 00:00:00 2001 From: Martin Coll Date: Fri, 19 Sep 2025 14:40:39 +0000 Subject: [PATCH 2/2] Code formatting --- mlir/lib/TableGen/Pattern.cpp | 3 ++- mlir/tools/mlir-tblgen/RewriterGen.cpp | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 38725050cefe8..b0dd0451e05e8 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -528,7 +528,8 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { int index = -1; StringRef name = getValuePackName(symbol, &index); - auto inserted = symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index)); + auto inserted = + symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index)); return symbolInfoMap.count(inserted->first) == 1; } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 312f5174b7b9e..75cf7232e729c 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -621,9 +621,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { os << formatv("{0} = {1};\n", name, castedName); if (index != -1) { - emitMatchCheck(opName, - formatv("(resultNumber{0} == 1)", depth), - formatv("\"{0} does not come from result number {1} type\"", castedName, index)); + emitMatchCheck(opName, formatv("(resultNumber{0} == 1)", depth), + formatv("\"{0} does not come from result number {1} type\"", + castedName, index)); } for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; @@ -669,9 +669,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { "auto *{0} = " "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", argName, castedName, nextOperand); - os.indent() << formatv( + os.indent() << formatv( "[[maybe_unused]] auto resultNumber{0} = " - "::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin())).getResultNumber();\n", + "::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin()" + ")).getResultNumber();\n", depth + 1, castedName, nextOperand); // Null check of operand's definingOp emitMatchCheck(