-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Execute same operand name constraints before user constraints. #162526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
For a pattern like this: Pat<(MyOp $x, $x), (...), [(MyCheck $x)]>; The old implementation generates: Pat<(MyOp $x0, $x1), (...), [(MyCheck $x0), ($x0 == $x1)]>; This is not very straightforward, because the $x name appears in the source pattern; it's attempting to assume equality check will be performed as part of the source pattern matching. This commit moves the equality checks before the other constraints, i.e.: Pat<(MyOp $x0, $x1), (...), [($x0 == $x1), (MyCheck $x0)]>;
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Chenguang Wang (wecing) ChangesFor a pattern like this:
The old implementation generates:
This is not very straightforward, because the $x name appears in the source pattern; it's attempting to assume equality check will be performed as part of the source pattern matching. This commit moves the equality checks before the other constraints, i.e.:
Full diff: https://github.com/llvm/llvm-project/pull/162526.diff 4 Files Affected:
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6ea27187655ee..ed62bee3bc152 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> {
let results = (outs I32);
}
+def OpQ : TEST_Op<"op_q"> {
+ let arguments = (ins AnyType, AnyType);
+ let results = (outs AnyType);
+}
+
// Test constant-folding a pattern that maps `(F32) -> SI32`.
def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
let arguments = (ins RankedTensorOf<[F32]>:$operand);
@@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern :
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+// Test equal arguments checks are applied before user provided constraints.
+// CheckIntIs32Bits would throw exceptions if input is not i32.
+def CheckIntIs32Bits : Constraint<CPred<"intIs32Bits($0)">>;
+def TestEqualArgsCheckBeforeUserConstraintsPattern :
+ Pat<(OpQ $x, $x),
+ (replaceWithValue $x),
+ [(CheckIntIs32Bits $x)]>;
+
// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f8b5144e3acb2..d764deb023873 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -70,6 +70,11 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) {
return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
}
+// Requires input value is of i32 type.
+static bool intIs32Bits(Value v) {
+ return mlir::dyn_cast<IntegerType>(v.getType()).getWidth() == 32;
+}
+
namespace {
#include "TestPatterns.inc"
} // namespace
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index bd55338618eec..a67830373e701 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs(
// def TestNestedOpEqualArgsPattern :
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
- // CHECK: %arg1
+ // CHECK: "test.op_o"(%arg1)
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
+ %2 = "test.op_o"(%1) : (i32) -> (i32)
- // CHECK: test.op_p
- // CHECK: test.op_n
- %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+ // CHECK-NEXT: %[[P:.*]] = "test.op_p"
+ // CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]])
+ // CHECK-NEXT: "test.op_o"(%[[N]])
+ %3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
- %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
+ %4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32)
+ %5 = "test.op_o"(%4) : (i32) -> (i32)
return
}
@@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs(
return
}
+func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) {
+ // def TestEqualArgsCheckBeforeUserConstraintsPattern :
+ // Pat<(OpQ $x, $x),
+ // [(CheckIntIs32Bits $x)],
+ // (replaceWithValue $x)>;
+
+ // CHECK: "test.op_q"(%arg0, %arg1)
+ %0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32)
+
+ // CHECK: "test.op_q"(%arg1, %arg0)
+ %1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32)
+
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033daa719f..40bc1a9c3868c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
int depth = 0;
emitMatch(tree, opName, depth);
+ // Some of the operands could be bound to the same symbol name, we need
+ // to enforce equality constraint on those.
+ // This has to happen before user provided constraints, which may assume the
+ // same name checks are already performed, since in the pattern source code
+ // the user provided constraints appear later.
+ // TODO: we should be able to emit equality checks early
+ // and short circuit unnecessary work if vars are not equal.
+ for (auto symbolInfoIt = symbolInfoMap.begin();
+ symbolInfoIt != symbolInfoMap.end();) {
+ auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+ auto startRange = range.first;
+ auto endRange = range.second;
+
+ auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+ for (++startRange; startRange != endRange; ++startRange) {
+ auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+ emitMatchCheck(
+ opName,
+ formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+ formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+ secondOperand));
+ }
+
+ symbolInfoIt = endRange;
+ }
+
for (auto &appliedConstraint : pattern.getConstraints()) {
auto &constraint = appliedConstraint.constraint;
auto &entities = appliedConstraint.entities;
@@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
}
}
- // Some of the operands could be bound to the same symbol name, we need
- // to enforce equality constraint on those.
- // TODO: we should be able to emit equality checks early
- // and short circuit unnecessary work if vars are not equal.
- for (auto symbolInfoIt = symbolInfoMap.begin();
- symbolInfoIt != symbolInfoMap.end();) {
- auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
- auto startRange = range.first;
- auto endRange = range.second;
-
- auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
- for (++startRange; startRange != endRange; ++startRange) {
- auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
- emitMatchCheck(
- opName,
- formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
- formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
- secondOperand));
- }
-
- symbolInfoIt = endRange;
- }
-
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
For a pattern like this:
The old implementation generates:
This is not very straightforward, because the $x name appears in the source pattern; it's attempting to assume equality check will be performed as part of the source pattern matching.
This commit moves the equality checks before the other constraints, i.e.: