Skip to content

Conversation

wecing
Copy link
Contributor

@wecing wecing commented Oct 8, 2025

For a pattern like this:

Pat<(MyOp $x, $x),
    (...),
    [(MyCheck $x)]>;

The old implementation generates:

Pat<(MyOp $x0, $x1),
    (...),
    [(MyCheck $x0),
     ($x0 == $x1)]>;

This is not very straightforward, because the $x name appears in the source pattern; it's attempting to assume equality check will be performed as part of the source pattern matching.

This commit moves the equality checks before the other constraints, i.e.:

Pat<(MyOp $x0, $x1),
    (...),
    [($x0 == $x1),
     (MyCheck $x0)]>;

For a pattern like this:

    Pat<(MyOp $x, $x),
        (...),
        [(MyCheck $x)]>;

The old implementation generates:

    Pat<(MyOp $x0, $x1),
        (...),
        [(MyCheck $x0),
         ($x0 == $x1)]>;

This is not very straightforward, because the $x name appears in the
source pattern; it's attempting to assume equality check will be
performed as part of the source pattern matching.

This commit moves the equality checks before the other constraints, i.e.:

    Pat<(MyOp $x0, $x1),
        (...),
        [($x0 == $x1),
         (MyCheck $x0)]>;
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 8, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Chenguang Wang (wecing)

Changes

For a pattern like this:

Pat&lt;(MyOp $x, $x),
    (...),
    [(MyCheck $x)]&gt;;

The old implementation generates:

Pat&lt;(MyOp $x0, $x1),
    (...),
    [(MyCheck $x0),
     ($x0 == $x1)]&gt;;

This is not very straightforward, because the $x name appears in the source pattern; it's attempting to assume equality check will be performed as part of the source pattern matching.

This commit moves the equality checks before the other constraints, i.e.:

Pat&lt;(MyOp $x0, $x1),
    (...),
    [($x0 == $x1),
     (MyCheck $x0)]&gt;;

Full diff: https://github.com/llvm/llvm-project/pull/162526.diff

4 Files Affected:

  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+13)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+5)
  • (modified) mlir/test/mlir-tblgen/pattern.mlir (+23-5)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+26-23)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6ea27187655ee..ed62bee3bc152 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> {
   let results = (outs I32);
 }
 
+def OpQ : TEST_Op<"op_q"> {
+  let arguments = (ins AnyType, AnyType);
+  let results = (outs AnyType);
+}
+
 // Test constant-folding a pattern that maps `(F32) -> SI32`.
 def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
   let arguments = (ins RankedTensorOf<[F32]>:$operand);
@@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern :
 def TestMultipleEqualArgsPattern :
   Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
 
+// Test equal arguments checks are applied before user provided constraints.
+// CheckIntIs32Bits would throw exceptions if input is not i32.
+def CheckIntIs32Bits : Constraint<CPred<"intIs32Bits($0)">>;
+def TestEqualArgsCheckBeforeUserConstraintsPattern :
+  Pat<(OpQ $x, $x),
+      (replaceWithValue $x),
+      [(CheckIntIs32Bits $x)]>;
+
 // Test for memrefs normalization of an op with normalizable memrefs.
 def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f8b5144e3acb2..d764deb023873 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -70,6 +70,11 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) {
   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
 }
 
+// Requires input value is of i32 type.
+static bool intIs32Bits(Value v) {
+  return mlir::dyn_cast<IntegerType>(v.getType()).getWidth() == 32;
+}
+
 namespace {
 #include "TestPatterns.inc"
 } // namespace
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index bd55338618eec..a67830373e701 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs(
   // def TestNestedOpEqualArgsPattern :
   //   Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
 
-  // CHECK: %arg1
+  // CHECK: "test.op_o"(%arg1)
   %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
     : (i32, i32, i32, i32, i32, i32) -> (i32)
   %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
+  %2 = "test.op_o"(%1) : (i32) -> (i32)
 
-  // CHECK: test.op_p
-  // CHECK: test.op_n
-  %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+  // CHECK-NEXT: %[[P:.*]] = "test.op_p"
+  // CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]])
+  // CHECK-NEXT: "test.op_o"(%[[N]])
+  %3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
     : (i32, i32, i32, i32, i32, i32) -> (i32)
-  %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
+  %4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32)
+  %5 = "test.op_o"(%4) : (i32) -> (i32)
 
   return
 }
@@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs(
   return
 }
 
+func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) {
+  // def TestEqualArgsCheckBeforeUserConstraintsPattern :
+  //   Pat<(OpQ $x, $x),
+  //       [(CheckIntIs32Bits $x)],
+  //       (replaceWithValue $x)>;
+
+  // CHECK: "test.op_q"(%arg0, %arg1)
+  %0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32)
+
+  // CHECK: "test.op_q"(%arg1, %arg0)
+  %1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32)
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test Symbol Binding
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033daa719f..40bc1a9c3868c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
   int depth = 0;
   emitMatch(tree, opName, depth);
 
+  // Some of the operands could be bound to the same symbol name, we need
+  // to enforce equality constraint on those.
+  // This has to happen before user provided constraints, which may assume the
+  // same name checks are already performed, since in the pattern source code
+  // the user provided constraints appear later.
+  // TODO: we should be able to emit equality checks early
+  // and short circuit unnecessary work if vars are not equal.
+  for (auto symbolInfoIt = symbolInfoMap.begin();
+       symbolInfoIt != symbolInfoMap.end();) {
+    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+    auto startRange = range.first;
+    auto endRange = range.second;
+
+    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+    for (++startRange; startRange != endRange; ++startRange) {
+      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+      emitMatchCheck(
+          opName,
+          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+                  secondOperand));
+    }
+
+    symbolInfoIt = endRange;
+  }
+
   for (auto &appliedConstraint : pattern.getConstraints()) {
     auto &constraint = appliedConstraint.constraint;
     auto &entities = appliedConstraint.entities;
@@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
     }
   }
 
-  // Some of the operands could be bound to the same symbol name, we need
-  // to enforce equality constraint on those.
-  // TODO: we should be able to emit equality checks early
-  // and short circuit unnecessary work if vars are not equal.
-  for (auto symbolInfoIt = symbolInfoMap.begin();
-       symbolInfoIt != symbolInfoMap.end();) {
-    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
-    auto startRange = range.first;
-    auto endRange = range.second;
-
-    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
-    for (++startRange; startRange != endRange; ++startRange) {
-      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
-      emitMatchCheck(
-          opName,
-          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
-          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
-                  secondOperand));
-    }
-
-    symbolInfoIt = endRange;
-  }
-
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
 }
 

Copy link

github-actions bot commented Oct 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants