Skip to content

Commit

Permalink
Support variadic ops in declarative rewrite rules
Browse files Browse the repository at this point in the history
This CL extends declarative rewrite rules to support matching and
generating ops with variadic operands/results. For this, the
generated `matchAndRewrite()` method for each pattern now are
changed to

* Use "range" types for the local variables used to store captured
  values (`operand_range` for operands, `ArrayRef<Value *>` for
  values, *Op for results). This allows us to have a unified way
  of handling both single values and value ranges.
* Create local variables for each operand for op creation. If the
  operand is variadic, then a `SmallVector<Value*>` will be created
  to collect all values for that operand; otherwise a `Value*` will
  be created.
* Use a collective result type builder. All result types are
  specified via a single parameter to the builder.

We can use one result pattern to replace multiple results of the
matched root op. When that happens, it will require specifying
types for multiple results. Add a new collective-type builder.

PiperOrigin-RevId: 264588559
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Aug 21, 2019
1 parent 69cf811 commit 31cfee6
Show file tree
Hide file tree
Showing 6 changed files with 482 additions and 127 deletions.
29 changes: 25 additions & 4 deletions mlir/include/mlir/TableGen/Pattern.h
Expand Up @@ -262,8 +262,19 @@ class SymbolInfoMap {
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values). `name` is the
// name of the C++ variable that this symbol bounds to. `index` should only
// be used for indexing results.
std::string getValueAndRangeUse(StringRef name, int index) const;
// be used for indexing results. `fmt` is used to format each value.
// `separator` is used to separate values if this is a value range.
std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;

// Returns a string containing the C++ expression for referencing this
// symbol as a value range regardless of how many static values this symbol
// represents. `name` is the name of the C++ variable that this symbol
// bounds to. `index` should only be used for indexing results. `fmt` is
// used to format each value. `separator` is used to separate values in the
// range.
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;

const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
Expand Down Expand Up @@ -309,8 +320,18 @@ class SymbolInfoMap {

// Returns a string containing the C++ expression for referencing this
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values).
std::string getValueAndRangeUse(StringRef symbol) const;
// range (if this symbol represents multiple static values). `fmt` is used to
// format each value. `separator` is used to seperate values if `symbol`
// represents a value range.
std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;

// Returns a string containing the C++ expression for referencing this
// symbol as a value range regardless of how many static values this symbol
// represents. `fmt` is used to format each value. `seperator` is used to
// separate values in the range.
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;

// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on sucess. Returns `symbol`
Expand Down
103 changes: 87 additions & 16 deletions mlir/lib/TableGen/Pattern.cpp
Expand Up @@ -204,45 +204,99 @@ tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
return formatv("{0} {1};\n", type, name);
}
case Kind::Operand:
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
return formatv("Operation::operand_range {0}(op0->getOperands());\n", name);
}
case Kind::Value: {
return formatv("Value *{0};\n", name);
return formatv("ArrayRef<Value *> {0};\n", name);
}
case Kind::Result: {
// Use the op itself for the results.
// Use the op itself for captured results.
return formatv("{0} {1};\n", op->getQualCppClassName(), name);
}
}
llvm_unreachable("unknown kind");
}

std::string
tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name,
int index) const {
std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
switch (kind) {
case Kind::Attr: {
assert(index < 0);
return formatv(fmt, name);
}
case Kind::Operand: {
assert(index < 0);
auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariadic()) {
return formatv(fmt, name);
}
return formatv(fmt, formatv("(*{0}.begin())", name));
}
case Kind::Result: {
// If `index` is greater than zero, then we are referencing a specific
// result of a multi-result op. The result can still be variadic.
if (index >= 0) {
std::string v = formatv("{0}.getODSResults({1})", name, index);
if (!op->getResult(index).isVariadic())
v = formatv("(*{0}.begin())", v);
return formatv(fmt, v);
}

// We are referencing all results of the multi-result op. A specific result
// can either be a value or a range. Then join them with `separator`.
SmallVector<std::string, 4> values;
values.reserve(op->getNumResults());

for (int i = 0, e = op->getNumResults(); i < e; ++i) {
std::string v = formatv("{0}.getODSResults({1})", name, i);
if (!op->getResult(i).isVariadic()) {
v = formatv("(*{0}.begin())", v);
}
values.push_back(formatv(fmt, v));
}
return llvm::join(values, separator);
}
case Kind::Value: {
assert(index < 0);
assert(op == nullptr);
return formatv(fmt, name);
}
}
}

std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
switch (kind) {
case Kind::Attr:
case Kind::Operand: {
assert(index < 0 && "only allowed for symbol bound to result");
return name;
return formatv(fmt, name);
}
case Kind::Result: {
// TODO(b/133341698): The following is incorrect for variadic results. We
// should use getODSResults().
if (index >= 0) {
return formatv("{0}.getOperation()->getResult({1})", name, index);
return formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
}

// If referencing multiple results, compose a comma-separated list.
// We are referencing all results of the multi-result op. Each result should
// have a value range, and then join them with `separator`.
SmallVector<std::string, 4> values;
values.reserve(op->getNumResults());

for (int i = 0, e = op->getNumResults(); i < e; ++i) {
values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i));
values.push_back(
formatv(fmt, formatv("{0}.getODSResults({1})", name, i)));
}
return llvm::join(values, ", ");
return llvm::join(values, separator);
}
case Kind::Value: {
assert(index < 0 && "only allowed for symbol bound to result");
assert(op == nullptr);
return name;
return formatv(fmt, formatv("{{{0}}", name));
}
}
llvm_unreachable("unknown kind");
Expand Down Expand Up @@ -294,7 +348,24 @@ int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
return find(name)->getValue().getStaticValueCount();
}

std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
std::string
tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);

auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}

return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
}

std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol,
const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);

Expand All @@ -304,7 +375,7 @@ std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
PrintFatalError(loc, error);
}

return it->getValue().getValueAndRangeUse(name, index);
return it->getValue().getAllRangeUse(name, index, fmt, separator);
}

//===----------------------------------------------------------------------===//
Expand Down
110 changes: 110 additions & 0 deletions mlir/test/lib/TestDialect/TestOps.td
Expand Up @@ -505,6 +505,116 @@ def : Pattern<
(AnotherTwoResultOp MultiResultOpKind6)
]>;

//===----------------------------------------------------------------------===//
// Test Patterns (Variadic Ops)

def OneVResOneVOperandOp1 : TEST_Op<"one_variadic_out_one_variadic_in1"> {
let arguments = (ins Variadic<I32>:$inputs);
let results = (outs Variadic<I32>:$outputs);
}
def OneVResOneVOperandOp2 : TEST_Op<"one_variadic_out_one_variadic_in2"> {
let arguments = (ins Variadic<I32>:$inputs);
let results = (outs Variadic<I32>:$outputs);
}

// Rewrite an op with one variadic operand and one variadic result to
// another similiar op.
def : Pat<(OneVResOneVOperandOp1 $inputs), (OneVResOneVOperandOp2 $inputs)>;

def MixedVOperandOp1 : TEST_Op<"mixed_variadic_in1",
[SameVariadicOperandSize]> {
let arguments = (ins
Variadic<I32>:$input1,
F32:$input2,
Variadic<I32>:$input3
);
}

def MixedVOperandOp2 : TEST_Op<"mixed_variadic_in2",
[SameVariadicOperandSize]> {
let arguments = (ins
Variadic<I32>:$input1,
F32:$input2,
Variadic<I32>:$input3
);
}

// Rewrite an op with both variadic operands and normal operands.
def : Pat<(MixedVOperandOp1 $input1, $input2, $input3),
(MixedVOperandOp2 $input1, $input2, $input3)>;

def MixedVResultOp1 : TEST_Op<"mixed_variadic_out1", [SameVariadicResultSize]> {
let results = (outs
Variadic<I32>:$output1,
F32:$output2,
Variadic<I32>:$output3
);
}

def MixedVResultOp2 : TEST_Op<"mixed_variadic_out2", [SameVariadicResultSize]> {
let results = (outs
Variadic<I32>:$output1,
F32:$output2,
Variadic<I32>:$output3
);
}

// Rewrite an op with both variadic results and normal results.
// Note that because we are generating the op with a top-level result pattern,
// we are able to deduce the correct result types for the generated op using
// the information from the matched root op.
def : Pat<(MixedVResultOp1), (MixedVResultOp2)>;

def OneI32ResultOp : TEST_Op<"one_i32_out"> {
let results = (outs I32:$output);
}

def MixedVOperandOp3 : TEST_Op<"mixed_variadic_in3",
[SameVariadicOperandSize]> {
let arguments = (ins
I32:$input1,
Variadic<I32>:$input2,
Variadic<I32>:$input3,
I32Attr:$count
);

let results = (outs I32:$output);
}

def MixedVResultOp3 : TEST_Op<"mixed_variadic_out3",
[SameVariadicResultSize]> {
let arguments = (ins I32Attr:$count);

let results = (outs
I32:$output1,
Variadic<I32>:$output2,
Variadic<I32>:$output3
);

// We will use this op in a nested result pattern, where we cannot deduce the
// result type. So need to provide a builder not requiring result types.
let builders = [
OpBuilder<
"Builder *builder, OperationState *state, IntegerAttr count",
[{
auto i32Type = builder->getIntegerType(32);
state->addTypes(i32Type); // $ouput1
SmallVector<Type, 4> types(count.getInt(), i32Type);
state->addTypes(types); // $ouput2
state->addTypes(types); // $ouput3
state->addAttribute("count", count);
}]>
];
}

// Generates an op with variadic results using nested pattern.
def : Pat<(OneI32ResultOp),
(MixedVOperandOp3
(MixedVResultOp3:$results__0 ConstantAttr<I32Attr, "2">),
(replaceWithValue $results__1),
(replaceWithValue $results__2),
ConstantAttr<I32Attr, "2">)>;

//===----------------------------------------------------------------------===//
// Test Legalization
//===----------------------------------------------------------------------===//
Expand Down
52 changes: 52 additions & 0 deletions mlir/test/mlir-tblgen/pattern.mlir
Expand Up @@ -215,3 +215,55 @@ func @useAuxiliaryOpToReplaceMultiResultOp() -> (i32, f32, f32) {
%0:3 = "test.three_result"() {kind = 6} : () -> (i32, f32, f32)
return %0#0, %0#1, %0#2 : i32, f32, f32
}

//===----------------------------------------------------------------------===//
// Test Multi-result Ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @replaceOneVariadicOutOneVariadicInOp
func @replaceOneVariadicOutOneVariadicInOp(%arg0: i32, %arg1: i32, %arg2: i32) -> (i32, i32, i32, i32, i32, i32) {
// CHECK: %[[cnt1:.*]] = "test.one_variadic_out_one_variadic_in2"(%arg0)
// CHECK: %[[cnt2:.*]]:2 = "test.one_variadic_out_one_variadic_in2"(%arg0, %arg1)
// CHECK: %[[cnt3:.*]]:3 = "test.one_variadic_out_one_variadic_in2"(%arg0, %arg1, %arg2)
// CHECK: return %[[cnt1]], %[[cnt2]]#0, %[[cnt2]]#1, %[[cnt3]]#0, %[[cnt3]]#1, %[[cnt3]]#2

%0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> (i32)
%1:2 = "test.one_variadic_out_one_variadic_in1"(%arg0, %arg1) : (i32, i32) -> (i32, i32)
%2:3 = "test.one_variadic_out_one_variadic_in1"(%arg0, %arg1, %arg2) : (i32, i32, i32) -> (i32, i32, i32)
return %0, %1#0, %1#1, %2#0, %2#1, %2#2 : i32, i32, i32, i32, i32, i32
}

// CHECK-LABEL: @replaceMixedVariadicInputOp
func @replaceMixedVariadicInputOp(%arg0: i32, %arg1: f32, %arg2: i32) -> () {
// CHECK: "test.mixed_variadic_in2"(%arg1)
// CHECK: "test.mixed_variadic_in2"(%arg0, %arg1, %arg2)
// CHECK: "test.mixed_variadic_in2"(%arg0, %arg0, %arg1, %arg2, %arg2)

"test.mixed_variadic_in1"(%arg1) : (f32) -> ()
"test.mixed_variadic_in1"(%arg0, %arg1, %arg2) : (i32, f32, i32) -> ()
"test.mixed_variadic_in1"(%arg0, %arg0, %arg1, %arg2, %arg2) : (i32, i32, f32, i32, i32) -> ()
return
}

// CHECK-LABEL: @replaceMixedVariadicOutputOp
func @replaceMixedVariadicOutputOp() -> (f32, i32, f32, i32, i32, i32, f32, i32, i32) {
// CHECK: %[[cnt1:.*]] = "test.mixed_variadic_out2"()
// CHECK: %[[cnt3:.*]]:3 = "test.mixed_variadic_out2"()
// CHECK: %[[cnt5:.*]]:5 = "test.mixed_variadic_out2"()
// CHECK: return %[[cnt1]], %[[cnt3]]#0, %[[cnt3]]#1, %[[cnt3]]#2, %[[cnt5]]#0, %[[cnt5]]#1, %[[cnt5]]#2, %[[cnt5]]#3, %[[cnt5]]#4

%0 = "test.mixed_variadic_out1"() : () -> (f32)
%1:3 = "test.mixed_variadic_out1"() : () -> (i32, f32, i32)
%2:5 = "test.mixed_variadic_out1"() : () -> (i32, i32, f32, i32, i32)
return %0, %1#0, %1#1, %1#2, %2#0, %2#1, %2#2, %2#3, %2#4 : f32, i32, f32, i32, i32, i32, f32, i32, i32
}

// CHECK-LABEL: @generateVaridicOutputOpInNestedPattern
func @generateVaridicOutputOpInNestedPattern() -> (i32) {
// CHECK: %[[cnt5:.*]]:5 = "test.mixed_variadic_out3"()
// CHECK: %[[res:.*]] = "test.mixed_variadic_in3"(%[[cnt5]]#0, %[[cnt5]]#1, %[[cnt5]]#2, %[[cnt5]]#3, %[[cnt5]]#4)
// CHECK: return %[[res]]

%0 = "test.one_i32_out"() : () -> (i32)
return %0 : i32
}

0 comments on commit 31cfee6

Please sign in to comment.