Skip to content

Conversation

colltoaction
Copy link

The changes proposed in this pull request enhance the matching capabilities of the TableGen language. It extends the "__N suffix" approach, only previously available in replacement patterns.

The new logic is exercised in the following TableGen pattern. Notice OneResultOp4 references a two-result operation, and multi-result operations previously never matched. With these changes, the "__1" suffix means this pattern will now match when the parameter received is #1 from TwoResultOp2, as seen in the MLIR example below.

def : Pat<
  (OneResultOp4 (TwoResultOp2:$a__1)),
  (replaceWithValue $a__0)>;
%0:2 = "test.two_result2"() : () -> (f32, f32)
%1 = "test.one_result4"(%0#1) : (f32) -> (f32)
return %1 : f32

CC @jpienaar following up after our Discord conversation. Hope you find this is a good addition! Thanks in advance for reviewing my code.

Thank you!

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 18, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Martin Coll (colltoaction)

Changes

The changes proposed in this pull request enhance the matching capabilities of the TableGen language. It extends the "__N suffix" approach, only previously available in replacement patterns.

The new logic is exercised in the following TableGen pattern. Notice OneResultOp4 references a two-result operation, and multi-result operations previously never matched. With these changes, the "__1" suffix means this pattern will now match when the parameter received is #<!-- -->1 from TwoResultOp2, as seen in the MLIR example below.

def : Pat&lt;
  (OneResultOp4 (TwoResultOp2:$a__1)),
  (replaceWithValue $a__0)&gt;;
%0:2 = "test.two_result2"() : () -&gt; (f32, f32)
%1 = "test.one_result4"(%0#<!-- -->1) : (f32) -&gt; (f32)
return %1 : f32

CC @jpienaar following up after our Discord conversation. Hope you find this is a good addition! Thanks in advance for reviewing my code.

Thank you!


Full diff: https://github.com/llvm/llvm-project/pull/159656.diff

5 Files Affected:

  • (modified) mlir/include/mlir/TableGen/Pattern.h (+3-2)
  • (modified) mlir/lib/TableGen/Pattern.cpp (+7-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+16)
  • (modified) mlir/test/mlir-tblgen/pattern.mlir (+19)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+12-1)
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 49b2dae62dc22..52069278c5eea 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -433,8 +433,9 @@ class SymbolInfoMap {
                         DagAndConstant(node.getAsOpaquePointer(), operandIndex,
                                        variadicSubIndex));
     }
-    static SymbolInfo getResult(const Operator *op) {
-      return SymbolInfo(op, Kind::Result, std::nullopt);
+    static SymbolInfo getResult(const Operator *op, int index) {
+      return SymbolInfo(op, Kind::Result,
+                        DagAndConstant(nullptr, index, std::nullopt));
     }
     static SymbolInfo getValue() {
       return SymbolInfo(nullptr, Kind::Value, std::nullopt);
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 1a1a58ad271bb..38725050cefe8 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -370,6 +370,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
   case Kind::Result: {
     // If `index` is greater than zero, then we are referencing a specific
     // result of a multi-result op. The result can still be variadic.
+    if (index < 0)
+      index = dagAndConstant->operandIndexOrNumValues;
     if (index >= 0) {
       std::string v =
           std::string(formatv("{0}.getODSResults({1})", name, index));
@@ -442,6 +444,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
     return std::string(repl);
   }
   case Kind::Result: {
+    if (index < 0)
+      index = dagAndConstant->operandIndexOrNumValues;
     if (index >= 0) {
       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
       LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
@@ -522,8 +526,9 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
 }
 
 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
-  std::string name = getValuePackName(symbol).str();
-  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
+  int index = -1;
+  StringRef name = getValuePackName(symbol, &index);
+  auto inserted = symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index));
 
   return symbolInfoMap.count(inserted->first) == 1;
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5564264ed8b0b..4bd68a5801bac 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1660,6 +1660,16 @@ def OneResultOp3 : TEST_Op<"one_result3"> {
   let results = (outs I32:$result1);
 }
 
+def OneResultOp4 : TEST_Op<"one_result4"> {
+  let arguments = (ins F32);
+  let results = (outs F32);
+}
+
+def TwoResultOp2 : TEST_Op<"two_result2"> {
+  let arguments = (ins);
+  let results = (outs F32, F32);
+}
+
 // Test using multi-result op as a whole
 def : Pat<(ThreeResultOp MultiResultOpKind1:$kind),
           (AnotherThreeResultOp $kind)>;
@@ -1696,6 +1706,12 @@ def : Pattern<
         (AnotherTwoResultOp $kind)
     ]>;
 
+// Test referencing a one-param op whose
+// param comes from the first result of a two-result op.
+def : Pat<
+  (OneResultOp4 (TwoResultOp2:$a__1)),
+  (replaceWithValue $a__0)>;
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Variadic Ops)
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index bd55338618eec..cedf528fb8717 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -594,6 +594,25 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar
   return
 }
 
+// CHECK-LABEL: @replaceOneResultWithNSuffixArgMatch
+func.func @replaceOneResultWithNSuffixArgMatch() -> (f32) {
+  // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
+  %0:2 = "test.two_result2"() : () -> (f32, f32)
+  %1 = "test.one_result4"(%0#1) : (f32) -> (f32)
+  // CHECK: return %0#0 : f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @replaceOneResultWithNSuffixArgNoMatch
+func.func @replaceOneResultWithNSuffixArgNoMatch() -> (f32) {
+  // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
+  %0:2 = "test.two_result2"() : () -> (f32, f32)
+  // CHECK: %1 = "test.one_result4"(%0#0) : (f32) -> f32
+  %1 = "test.one_result4"(%0#0) : (f32) -> (f32)
+  // CHECK: return %1 : f32
+  return %1 : f32
+}
+
 //===----------------------------------------------------------------------===//
 // Test patterns that operate on properties
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033daa719f..312f5174b7b9e 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -615,10 +615,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                            op.getQualCppClassName()));
 
   // If the operand's name is set, set to that variable.
-  auto name = tree.getSymbol();
+  int index = -1;
+  auto name = SymbolInfoMap::getValuePackName(tree.getSymbol(), &index).str();
   if (!name.empty())
     os << formatv("{0} = {1};\n", name, castedName);
 
+  if (index != -1) {
+    emitMatchCheck(opName,
+      formatv("(resultNumber{0} == 1)", depth),
+      formatv("\"{0} does not come from result number {1} type\"", castedName, index));
+  }
+
   for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
        ++i, ++opArgIdx) {
     auto opArg = op.getArg(opArgIdx);
@@ -662,6 +669,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
           "auto *{0} = "
           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
           argName, castedName, nextOperand);
+        os.indent() << formatv(
+          "[[maybe_unused]] auto resultNumber{0} = "
+          "::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin())).getResultNumber();\n",
+          depth + 1, castedName, nextOperand);
       // Null check of operand's definingOp
       emitMatchCheck(
           castedName, /*matchStr=*/argName,

Copy link

github-actions bot commented Sep 19, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/TableGen/Pattern.h mlir/lib/TableGen/Pattern.cpp mlir/tools/mlir-tblgen/RewriterGen.cpp

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 75cf7232e..ab49a60bd 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -621,9 +621,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     os << formatv("{0} = {1};\n", name, castedName);
 
   if (index != -1) {
-   emitMatchCheck(opName, formatv("(resultNumber{0} == 1)", depth),
-                  formatv("\"{0} does not come from result number {1} type\"",
-                          castedName, index));
+    emitMatchCheck(opName, formatv("(resultNumber{0} == 1)", depth),
+                   formatv("\"{0} does not come from result number {1} type\"",
+                           castedName, index));
   }
 
   for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;

@colltoaction
Copy link
Author

Any ideas why Windows is not passing at the moment? I see a dyn_cast error, but maybe it's not the one I added?

@joker-eph
Copy link
Collaborator

joker-eph commented Sep 19, 2025

The CI hadn't run yet, where did you see this error?
Edit: oh it ran on the the initial commit before you formatted.

@colltoaction
Copy link
Author

I missed a space in the formatter 🤦🏼. Anything I can check in the Windows build before I push?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants