Skip to content

Commit

Permalink
[MLIR] Remove TableGen redundant calls to native calls when creating …
Browse files Browse the repository at this point in the history
…new operations in DRR TableGen files

Summary:
Currently, the TableGen rewrite generates redundant native calls in MLIR DRR files. This is a problem as some native calls may involve significant computations (e.g. when performing constant propagation where every values in a large tensor is touched).

The pattern was as follow:

```c++
if (native-call(args)) tblgen_attrs.emplace_back(rewriter, attribute, native-call(args))
```

The replacement pattern compute `native-call(args)` once and then use it both in the `if` condition and the `emplace_back` call.

Differential Revision: https://reviews.llvm.org/D82101
  • Loading branch information
AlexandreEichenberger authored and jpienaar committed Jun 22, 2020
1 parent f633b07 commit 0164119
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 3 deletions.
14 changes: 14 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -602,6 +602,20 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
def : Pat<(OpJ), (OpK)>;

// Test that natives calls are only called once during rewrites.
def OpM : TEST_Op<"op_m"> {
let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr);
let results = (outs I32);
}
// Pattern add the argument plus a increasing static number hidden in
// OpMTest function. That value is set into the optional argument.
// That way, we will know if operations is called once or twice.
def OpMGetNullAttr : NativeCodeCall<"Attribute()">;
def OpMAttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">;
def OpMVal : NativeCodeCall<"OpMTest($_builder, $0)">;
def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpMVal $attr) ),
[(OpMAttributeIsNull:$optAttr)]>;

// Test `$_` for ignoring op argument match.
def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> {
let arguments = (ins
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Expand Up @@ -32,6 +32,16 @@ static void handleNoResultOp(PatternRewriter &rewriter,
op.operand());
}

// Test that natives calls are only called once during rewrites.
// OpM_Test will return Pi, increased by 1 for each subsequent calls.
// This let us check the number of times OpM_Test was called by inspecting
// the returned value in the MLIR output.
static int64_t opMIncreasingValue = 314159265;
static Attribute OpMTest(PatternRewriter &rewriter, Value val) {
int64_t i = opMIncreasingValue++;
return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
}

namespace {
#include "TestPatterns.inc"
} // end anonymous namespace
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/mlir-tblgen/pattern.mlir
Expand Up @@ -359,3 +359,14 @@ func @generateVariadicOutputOpInNestedPattern() -> (i32) {
%0 = "test.one_i32_out"() : () -> (i32)
return %0 : i32
}

//===----------------------------------------------------------------------===//
// Test that natives calls are only called once during rewrites.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: redundantTest
func @redundantTest(%arg0: i32) -> i32 {
%0 = "test.op_m"(%arg0) : (i32) -> i32
// CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32
return %0 : i32
}
6 changes: 3 additions & 3 deletions mlir/tools/mlir-tblgen/RewriterGen.cpp
Expand Up @@ -1044,11 +1044,11 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
os.indent(6) << formatv(
"SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");

const char *addAttrCmd =
"if (auto tmpAttr = {1}) "
"tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n";
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
const char *addAttrCmd = "if ({1}) {{"
" tblgen_attrs.emplace_back(rewriter."
"getIdentifier(\"{0}\"), {1}); }\n";
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
if (auto subTree = node.getArgAsNestedDag(argIndex)) {
Expand Down

0 comments on commit 0164119

Please sign in to comment.