Skip to content

Commit

Permalink
[mlir-tblgen] Slightly improve the diagnostic message in pattern match
Browse files Browse the repository at this point in the history
Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105883
  • Loading branch information
ChiaHungDuan committed Jul 19, 2021
1 parent 6601be4 commit 9bdf1ab
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions mlir/tools/mlir-tblgen/RewriterGen.cpp
Expand Up @@ -255,7 +255,6 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,

raw_indented_ostream::DelimitedScope scope(os);

os << "if(!" << opName << ") return ::mlir::failure();\n";
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
std::string argName = formatv("arg{0}_{1}", depth, i);
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
Expand All @@ -277,15 +276,15 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
std::tie(hasLocationDirective, locToUse) = getLocation(tree);

auto fmt = tree.getNativeCodeTemplate();
if (fmt.count("$_self") != 1) {
if (fmt.count("$_self") != 1)
PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
"passing the defining Operation");
}

auto nativeCodeCall = std::string(tgfmt(
fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture));

os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
formatv("\"{0} return failure\"", nativeCodeCall));

for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto name = tree.getArgName(i);
Expand Down Expand Up @@ -338,20 +337,21 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
<< '\n');

std::string castedName = formatv("castedOp{0}", depth);
os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
"(void){0};\n",
castedName, opName, op.getQualCppClassName());

// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
os << formatv("if (!{0}) return ::mlir::failure();\n", castedName);
}
if (tree.getNumArgs() != op.getNumArgs()) {
if (depth != 0)
emitMatchCheck(opName, /*matchStr=*/castedName,
formatv("\"{0} is not {1} type\"", castedName,
op.getQualCppClassName()));

if (tree.getNumArgs() != op.getNumArgs())
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
"pattern vs. {2} in definition",
op.getOperationName(), tree.getNumArgs(),
op.getNumArgs()));
}

// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
Expand Down Expand Up @@ -379,7 +379,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
os.indent() << formatv(
"auto *{0} = "
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
argName, castedName, nextOperand++);
argName, castedName, nextOperand);
// Null check of operand's definingOp
emitMatchCheck(castedName, /*matchStr=*/argName,
formatv("\"Operand {0} of {1} has null definingOp\"",
nextOperand++, castedName));
emitMatch(argTree, argName, depth + 1);
os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
os.unindent() << "}\n";
Expand Down

0 comments on commit 9bdf1ab

Please sign in to comment.