Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> {
let results = (outs I32);
}

def OpQ : TEST_Op<"op_q"> {
let arguments = (ins AnyType, AnyType);
let results = (outs AnyType);
}

// Test constant-folding a pattern that maps `(F32) -> SI32`.
def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
let arguments = (ins RankedTensorOf<[F32]>:$operand);
Expand Down Expand Up @@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern :
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;

// Test equal arguments checks are applied before user provided constraints.
def AssertBinOpEqualArgsAndReturnTrue : Constraint<
CPred<"assertBinOpEqualArgsAndReturnTrue($0)">>;
def TestEqualArgsCheckBeforeUserConstraintsPattern :
Pat<(OpQ:$op $x, $x),
(replaceWithValue $x),
[(AssertBinOpEqualArgsAndReturnTrue $op)]>;

// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) {
return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
}

static bool assertBinOpEqualArgsAndReturnTrue(Value v) {
Operation *operation = v.getDefiningOp();
if (operation->getOperand(0) != operation->getOperand(1)) {
// Name binding equality check must happen before user-defined constraints,
// thus this must not be triggered.
llvm::report_fatal_error("Arguments are not equal");
}
return true;
}

namespace {
#include "TestPatterns.inc"
} // namespace
Expand Down
28 changes: 23 additions & 5 deletions mlir/test/mlir-tblgen/pattern.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs(
// def TestNestedOpEqualArgsPattern :
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;

// CHECK: %arg1
// CHECK: "test.op_o"(%arg1)
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
%2 = "test.op_o"(%1) : (i32) -> (i32)

// CHECK: test.op_p
// CHECK: test.op_n
%2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
// CHECK-NEXT: %[[P:.*]] = "test.op_p"
// CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]])
// CHECK-NEXT: "test.op_o"(%[[N]])
%3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
%4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32)
%5 = "test.op_o"(%4) : (i32) -> (i32)

return
}
Expand Down Expand Up @@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs(
return
}

func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) {
// def TestEqualArgsCheckBeforeUserConstraintsPattern :
// Pat<(OpQ:$op $x, $x),
// (replaceWithValue $x),
// [(AssertBinOpEqualArgsAndReturnTrue $op)]>;

// CHECK: "test.op_q"(%arg0, %arg1)
%0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32)

// CHECK: "test.op_q"(%arg1, %arg0)
%1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32)

return
}

//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//
Expand Down
49 changes: 26 additions & 23 deletions mlir/tools/mlir-tblgen/RewriterGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
int depth = 0;
emitMatch(tree, opName, depth);

// Some of the operands could be bound to the same symbol name, we need
// to enforce equality constraint on those.
// This has to happen before user provided constraints, which may assume the
// same name checks are already performed, since in the pattern source code
// the user provided constraints appear later.
// TODO: we should be able to emit equality checks early
// and short circuit unnecessary work if vars are not equal.
for (auto symbolInfoIt = symbolInfoMap.begin();
symbolInfoIt != symbolInfoMap.end();) {
auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
auto startRange = range.first;
auto endRange = range.second;

auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
for (++startRange; startRange != endRange; ++startRange) {
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
emitMatchCheck(
opName,
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
secondOperand));
}

symbolInfoIt = endRange;
}

for (auto &appliedConstraint : pattern.getConstraints()) {
auto &constraint = appliedConstraint.constraint;
auto &entities = appliedConstraint.entities;
Expand Down Expand Up @@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
}
}

// Some of the operands could be bound to the same symbol name, we need
// to enforce equality constraint on those.
// TODO: we should be able to emit equality checks early
// and short circuit unnecessary work if vars are not equal.
for (auto symbolInfoIt = symbolInfoMap.begin();
symbolInfoIt != symbolInfoMap.end();) {
auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
auto startRange = range.first;
auto endRange = range.second;

auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
for (++startRange; startRange != endRange; ++startRange) {
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
emitMatchCheck(
opName,
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
secondOperand));
}

symbolInfoIt = endRange;
}

LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}

Expand Down