diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index e0112af6b5b037..611bc5c1c05eaa 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -255,7 +255,6 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, raw_indented_ostream::DelimitedScope scope(os); - os << "if(!" << opName << ") return ::mlir::failure();\n"; for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { std::string argName = formatv("arg{0}_{1}", depth, i); if (DagNode argTree = tree.getArgAsNestedDag(i)) { @@ -277,15 +276,15 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, std::tie(hasLocationDirective, locToUse) = getLocation(tree); auto fmt = tree.getNativeCodeTemplate(); - if (fmt.count("$_self") != 1) { + if (fmt.count("$_self") != 1) PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " "passing the defining Operation"); - } auto nativeCodeCall = std::string(tgfmt( fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture)); - os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n"; + emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall), + formatv("\"{0} return failure\"", nativeCodeCall)); for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto name = tree.getArgName(i); @@ -338,20 +337,21 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { << '\n'); std::string castedName = formatv("castedOp{0}", depth); - os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); " + os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); " "(void){0};\n", castedName, opName, op.getQualCppClassName()); + // Skip the operand matching at depth 0 as the pattern rewriter already does. - if (depth != 0) { - // Skip if there is no defining operation (e.g., arguments to function). - os << formatv("if (!{0}) return ::mlir::failure();\n", castedName); - } - if (tree.getNumArgs() != op.getNumArgs()) { + if (depth != 0) + emitMatchCheck(opName, /*matchStr=*/castedName, + formatv("\"{0} is not {1} type\"", castedName, + op.getQualCppClassName())); + + if (tree.getNumArgs() != op.getNumArgs()) PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " "pattern vs. {2} in definition", op.getOperationName(), tree.getNumArgs(), op.getNumArgs())); - } // If the operand's name is set, set to that variable. auto name = tree.getSymbol(); @@ -379,7 +379,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { os.indent() << formatv( "auto *{0} = " "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", - argName, castedName, nextOperand++); + argName, castedName, nextOperand); + // Null check of operand's definingOp + emitMatchCheck(castedName, /*matchStr=*/argName, + formatv("\"Operand {0} of {1} has null definingOp\"", + nextOperand++, castedName)); emitMatch(argTree, argName, depth + 1); os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); os.unindent() << "}\n";