Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/include/mlir/TableGen/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,9 @@ class SymbolInfoMap {
DagAndConstant(node.getAsOpaquePointer(), operandIndex,
variadicSubIndex));
}
static SymbolInfo getResult(const Operator *op) {
return SymbolInfo(op, Kind::Result, std::nullopt);
static SymbolInfo getResult(const Operator *op, int index) {
return SymbolInfo(op, Kind::Result,
DagAndConstant(nullptr, index, std::nullopt));
}
static SymbolInfo getValue() {
return SymbolInfo(nullptr, Kind::Value, std::nullopt);
Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/TableGen/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
case Kind::Result: {
// If `index` is greater than zero, then we are referencing a specific
// result of a multi-result op. The result can still be variadic.
if (index < 0)
index = dagAndConstant->operandIndexOrNumValues;
if (index >= 0) {
std::string v =
std::string(formatv("{0}.getODSResults({1})", name, index));
Expand Down Expand Up @@ -442,6 +444,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
return std::string(repl);
}
case Kind::Result: {
if (index < 0)
index = dagAndConstant->operandIndexOrNumValues;
if (index >= 0) {
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
Expand Down Expand Up @@ -522,8 +526,10 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
}

bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
std::string name = getValuePackName(symbol).str();
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
int index = -1;
StringRef name = getValuePackName(symbol, &index);
auto inserted =
symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index));

return symbolInfoMap.count(inserted->first) == 1;
}
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,16 @@ def OneResultOp3 : TEST_Op<"one_result3"> {
let results = (outs I32:$result1);
}

def OneResultOp4 : TEST_Op<"one_result4"> {
let arguments = (ins F32);
let results = (outs F32);
}

def TwoResultOp2 : TEST_Op<"two_result2"> {
let arguments = (ins);
let results = (outs F32, F32);
}

// Test using multi-result op as a whole
def : Pat<(ThreeResultOp MultiResultOpKind1:$kind),
(AnotherThreeResultOp $kind)>;
Expand Down Expand Up @@ -1696,6 +1706,12 @@ def : Pattern<
(AnotherTwoResultOp $kind)
]>;

// Test referencing a one-param op whose
// param comes from the first result of a two-result op.
def : Pat<
(OneResultOp4 (TwoResultOp2:$a__1)),
(replaceWithValue $a__0)>;

//===----------------------------------------------------------------------===//
// Test Patterns (Variadic Ops)
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/mlir-tblgen/pattern.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,25 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar
return
}

// CHECK-LABEL: @replaceOneResultWithNSuffixArgMatch
func.func @replaceOneResultWithNSuffixArgMatch() -> (f32) {
// CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
%0:2 = "test.two_result2"() : () -> (f32, f32)
%1 = "test.one_result4"(%0#1) : (f32) -> (f32)
// CHECK: return %0#0 : f32
return %1 : f32
}

// CHECK-LABEL: @replaceOneResultWithNSuffixArgNoMatch
func.func @replaceOneResultWithNSuffixArgNoMatch() -> (f32) {
// CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
%0:2 = "test.two_result2"() : () -> (f32, f32)
// CHECK: %1 = "test.one_result4"(%0#0) : (f32) -> f32
%1 = "test.one_result4"(%0#0) : (f32) -> (f32)
// CHECK: return %1 : f32
return %1 : f32
}

//===----------------------------------------------------------------------===//
// Test patterns that operate on properties
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 13 additions & 1 deletion mlir/tools/mlir-tblgen/RewriterGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,10 +615,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
op.getQualCppClassName()));

// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
int index = -1;
auto name = SymbolInfoMap::getValuePackName(tree.getSymbol(), &index).str();
if (!name.empty())
os << formatv("{0} = {1};\n", name, castedName);

if (index != -1) {
emitMatchCheck(opName, formatv("(resultNumber{0} == 1)", depth),
formatv("\"{0} does not come from result number {1} type\"",
castedName, index));
}

for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
++i, ++opArgIdx) {
auto opArg = op.getArg(opArgIdx);
Expand Down Expand Up @@ -662,6 +669,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
"auto *{0} = "
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
argName, castedName, nextOperand);
os.indent() << formatv(
"[[maybe_unused]] auto resultNumber{0} = "
"::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin()"
")).getResultNumber();\n",
depth + 1, castedName, nextOperand);
// Null check of operand's definingOp
emitMatchCheck(
castedName, /*matchStr=*/argName,
Expand Down
Loading