Skip to content

Commit 08d7377

Browse files
committed
[mlir] Enable DRR variadic operand matching
This commit enables DRR rewriter to match a fixed number of sub-operands as a variadic operand. Differential Review: https://reviews.llvm.org/D157359
1 parent 8e946fe commit 08d7377

File tree

7 files changed

+459
-34
lines changed

7 files changed

+459
-34
lines changed

mlir/docs/DeclarativeRewrites.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,50 @@ correspond to multiple actual values.
647647

648648
[TODO]
649649

650+
#### Match variadic operand
651+
652+
Use the `variadic` DAG node to match a variadic operand with a fixed number of
653+
actual sub-operands.
654+
655+
For example, assume that `ConcatenateOp` is an operation with a variadic
656+
operand:
657+
658+
```tablegen
659+
def ConcatenateOp : TEST_Op<"concatenate"> {
660+
let arguments = (ins
661+
Variadic<AnyTensor>:$inputs,
662+
I32Attr:$axis
663+
);
664+
665+
let results = (outs
666+
AnyTensor:$output
667+
);
668+
}
669+
```
670+
671+
We can match `ConcatenateOp` with exactly 2 actual operands with:
672+
673+
```tablegen
674+
def : Pat<(ConcatenateOp (variadic $input0, $input1), $axis),
675+
...>;
676+
```
677+
678+
The variadic sub-operands can be sub-DAGs to be matched:
679+
680+
```tablegen
681+
def : Pat<(ConcatenateOp (variadic (SomeOp $a), (AnotherOp $b, $c)), $axis),
682+
(OtherOp $a, $b, $c)>;
683+
```
684+
685+
The variadic DAG can be bound to a symbol, which refers to the full
686+
`operand_range`:
687+
688+
```tablegen
689+
def : Pat<(ConcatenateOp (variadic:$inputs $input0, $input1),
690+
ConstantAttr<I32Attr, "0">),
691+
(VStackOp $inputs)>;
692+
```
693+
650694
### Supplying additional constraints
651695

652696
Constraints can be placed on op arguments when matching. But sometimes we need

mlir/include/mlir/IR/PatternBase.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,22 @@ def returnType;
218218
// `either` while pattern matching.
219219
def either;
220220

221+
// Directive used to match variadic operands. This directive only matches if
222+
// the variadic operand has the same length as the specified formal
223+
// sub-dags.
224+
//
225+
// ```
226+
// (VariadicOp (variadic:$input1 $input1a, $input1b),
227+
// (variadic:$input2 $input2a, $input2b, $input2c),
228+
// $attr1, $attr2)
229+
// ```
230+
//
231+
// The pattern above only matches if the `$input1` operand is of length 2,
232+
// `$input2` is of length 3, and all sub-dags match respectively. The `$input1`
233+
// symbol denotes the full variadic operand range. The `$input1a` symbol
234+
// denotes the first operand in the variadic sub-operands.
235+
def variadic;
236+
221237
//===----------------------------------------------------------------------===//
222238
// Common value constraints
223239
//===----------------------------------------------------------------------===//

mlir/include/mlir/TableGen/Pattern.h

Lines changed: 115 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/ADT/StringMap.h"
2323
#include "llvm/ADT/StringSet.h"
2424

25+
#include <optional>
2526
#include <unordered_map>
2627

2728
namespace llvm {
@@ -189,6 +190,9 @@ class DagNode {
189190
// Returns whether this DAG is an `either` specifier.
190191
bool isEither() const;
191192

193+
// Returns whether this DAG is a `variadic` specifier.
194+
bool isVariadic() const;
195+
192196
// Returns true if this DAG node is an operation.
193197
bool isOperation() const;
194198

@@ -268,9 +272,94 @@ class SymbolInfoMap {
268272
// Allow SymbolInfoMap to access private methods.
269273
friend class SymbolInfoMap;
270274

271-
// DagNode and DagLeaf are accessed by value which means it can't be used as
272-
// identifier here. Use an opaque pointer type instead.
273-
using DagAndConstant = std::pair<const void *, int>;
275+
// Structure to uniquely distinguish different locations of the symbols.
276+
//
277+
// * If a symbol is defined as an operand of an operation, `dag` specifies
278+
// the DAG of the operation, `operandIndexOrNumValues` specifies the
279+
// operand index, and `variadicSubIndex` must be set to `std::nullopt`.
280+
//
281+
// * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG
282+
// of the parent operation, `operandIndexOrNumValues` specifies the
283+
// declared operand index of the variadic operand in the parent
284+
// operation.
285+
//
286+
// - If the symbol is defined as the result of the `variadic` DAG, the
287+
// `variadicSubIndex` must be set to `std::nullopt`, which means that
288+
// the symbol binds to the full operand range.
289+
//
290+
// - If the symbol is defined as an operand, the `variadicSubIndex` must
291+
// be set to the index within the variadic sub-operand list.
292+
//
293+
// * If a symbol is defined in an `either` DAG, `dag` specifies the DAG
294+
// of the parent operation, `operandIndexOrNumValues` specifies the
295+
// operand index in the parent operation (not necessarily the index in the
296+
// DAG).
297+
//
298+
// * If a symbol is defined as a result, `operandIndexOrNumValues` specifies
298+
// the number of returned values.
300+
//
301+
// Example 1:
302+
//
303+
// def : Pat<(OpA $input0, $input1), ...>;
304+
//
305+
// $input0: (OpA, 0, nullopt)
306+
// $input1: (OpA, 1, nullopt)
307+
//
308+
// Example 2:
309+
//
310+
// def : Pat<(OpB (variadic:$input0 $input0a, $input0b),
311+
// (variadic:$input1 $input1a, $input1b, $input1c)),
312+
// ...>;
313+
//
314+
// $input0: (OpB, 0, nullopt)
315+
// $input0a: (OpB, 0, 0)
316+
// $input0b: (OpB, 0, 1)
317+
// $input1: (OpB, 1, nullopt)
318+
// $input1a: (OpB, 1, 0)
319+
// $input1b: (OpB, 1, 1)
320+
// $input1c: (OpB, 1, 2)
321+
//
322+
// Example 3:
323+
//
324+
// def : Pat<(OpC $input0, (either $input1, $input2)), ...>;
325+
//
326+
// $input0: (OpC, 0, nullopt)
327+
// $input1: (OpC, 1, nullopt)
328+
// $input2: (OpC, 2, nullopt)
329+
//
330+
// Example 4:
331+
//
332+
// def ThreeResultOp : TEST_Op<...> {
333+
// let results = (outs
334+
// AnyType:$result1,
335+
// AnyType:$result2,
336+
// AnyType:$result3
337+
// );
338+
// }
339+
//
340+
// def : Pat<...,
341+
// (ThreeResultOp:$result ...)>;
342+
//
343+
// $result: (nullptr, 3, nullopt)
344+
//
345+
struct DagAndConstant {
346+
// DagNode and DagLeaf are accessed by value which means it can't be used
347+
// as identifier here. Use an opaque pointer type instead.
348+
const void *dag;
349+
int operandIndexOrNumValues;
350+
std::optional<int> variadicSubIndex;
351+
352+
DagAndConstant(const void *dag, int operandIndexOrNumValues,
353+
std::optional<int> variadicSubIndex)
354+
: dag(dag), operandIndexOrNumValues(operandIndexOrNumValues),
355+
variadicSubIndex(variadicSubIndex) {}
356+
357+
bool operator==(const DagAndConstant &rhs) const {
358+
return dag == rhs.dag &&
359+
operandIndexOrNumValues == rhs.operandIndexOrNumValues &&
360+
variadicSubIndex == rhs.variadicSubIndex;
361+
}
362+
};
274363

275364
// What kind of entity this symbol represents:
276365
// * Attr: op attribute
@@ -288,14 +377,18 @@ class SymbolInfoMap {
288377

289378
// Static methods for creating SymbolInfo.
290379
static SymbolInfo getAttr(const Operator *op, int index) {
291-
return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
380+
return SymbolInfo(op, Kind::Attr,
381+
DagAndConstant(nullptr, index, std::nullopt));
292382
}
293383
static SymbolInfo getAttr() {
294384
return SymbolInfo(nullptr, Kind::Attr, std::nullopt);
295385
}
296-
static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
386+
static SymbolInfo
387+
getOperand(DagNode node, const Operator *op, int operandIndex,
388+
std::optional<int> variadicSubIndex = std::nullopt) {
297389
return SymbolInfo(op, Kind::Operand,
298-
DagAndConstant(node.getAsOpaquePointer(), index));
390+
DagAndConstant(node.getAsOpaquePointer(), operandIndex,
391+
variadicSubIndex));
299392
}
300393
static SymbolInfo getResult(const Operator *op) {
301394
return SymbolInfo(op, Kind::Result, std::nullopt);
@@ -305,7 +398,7 @@ class SymbolInfoMap {
305398
}
306399
static SymbolInfo getMultipleValues(int numValues) {
307400
return SymbolInfo(nullptr, Kind::MultipleValues,
308-
DagAndConstant(nullptr, numValues));
401+
DagAndConstant(nullptr, numValues, std::nullopt));
309402
}
310403

311404
// Returns the number of static values this symbol corresponds to.
@@ -333,18 +426,23 @@ class SymbolInfoMap {
333426
const char *separator) const;
334427

335428
// The argument index (for `Attr` and `Operand` only)
336-
int getArgIndex() const { return (*dagAndConstant).second; }
429+
int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; }
337430

338431
// The number of values in the MultipleValue
339-
int getSize() const { return (*dagAndConstant).second; }
432+
int getSize() const { return dagAndConstant->operandIndexOrNumValues; }
433+
434+
// The variadic sub-operands index (for variadic `Operand` only)
435+
std::optional<int> getVariadicSubIndex() const {
436+
return dagAndConstant->variadicSubIndex;
437+
}
340438

341439
const Operator *op; // The op where the bound entity belongs
342440
Kind kind; // The kind of the bound entity
343441

344-
// The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
345-
// the size of MultipleValue symbol). Note that operands may be bound to the
346-
// same symbol, use the DagNode and index to distinguish them. For `Attr`
347-
// and MultipleValue, the Dag part will be nullptr.
442+
// The tuple of DagNode pointer and two constant values (for `Attr`,
443+
// `Operand` and the size of MultipleValue symbol). Note that operands may
444+
// be bound to the same symbol, use the DagNode and index to distinguish
445+
// them. For `Attr` and MultipleValue, the Dag part will be nullptr.
348446
std::optional<DagAndConstant> dagAndConstant;
349447

350448
// Alternative name for the symbol. It is used in case the name
@@ -367,7 +465,8 @@ class SymbolInfoMap {
367465
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
368466
// Returns false if `symbol` is already bound and symbols are not operands.
369467
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
370-
int argIndex);
468+
int argIndex,
469+
std::optional<int> variadicSubIndex = std::nullopt);
371470

372471
// Binds the given `symbol` to the results the given `op`. Returns false if
373472
// `symbol` is already bound.
@@ -397,7 +496,8 @@ class SymbolInfoMap {
397496
// Returns an iterator to the information of the given symbol named as `key`,
398497
// with index `argIndex` for operator `op`.
399498
const_iterator findBoundSymbol(StringRef key, DagNode node,
400-
const Operator &op, int argIndex) const;
499+
const Operator &op, int argIndex,
500+
std::optional<int> variadicSubIndex) const;
401501
const_iterator findBoundSymbol(StringRef key,
402502
const SymbolInfo &symbolInfo) const;
403503

mlir/lib/TableGen/Pattern.cpp

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ bool DagNode::isNativeCodeCall() const {
115115

116116
bool DagNode::isOperation() const {
117117
return !isNativeCodeCall() && !isReplaceWithValue() &&
118-
!isLocationDirective() && !isReturnTypeDirective() && !isEither();
118+
!isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
119+
!isVariadic();
119120
}
120121

121122
llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -193,6 +194,11 @@ bool DagNode::isEither() const {
193194
return dagOpDef->getName() == "either";
194195
}
195196

197+
bool DagNode::isVariadic() const {
198+
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
199+
return dagOpDef->getName() == "variadic";
200+
}
201+
196202
void DagNode::print(raw_ostream &os) const {
197203
if (node)
198204
node->print(os);
@@ -296,9 +302,10 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
296302
case Kind::Operand: {
297303
assert(index < 0);
298304
auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
299-
// If this operand is variadic, then return a range. Otherwise, return the
300-
// value itself.
301-
if (operand->isVariableLength()) {
305+
// If this operand is variadic and this SymbolInfo doesn't have a range
306+
// index, then return the full variadic operand_range. Otherwise, return
307+
// the value itself.
308+
if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
302309
auto repl = formatv(fmt, name);
303310
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
304311
return std::string(repl);
@@ -426,17 +433,19 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
426433
}
427434

428435
bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
429-
const Operator &op, int argIndex) {
436+
const Operator &op, int argIndex,
437+
std::optional<int> variadicSubIndex) {
430438
StringRef name = getValuePackName(symbol);
431439
if (name != symbol) {
432440
auto error = formatv(
433441
"symbol '{0}' with trailing index cannot bind to op argument", symbol);
434442
PrintFatalError(loc, error);
435443
}
436444

437-
auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
438-
? SymbolInfo::getAttr(&op, argIndex)
439-
: SymbolInfo::getOperand(node, &op, argIndex);
445+
auto symInfo =
446+
op.getArg(argIndex).is<NamedAttribute *>()
447+
? SymbolInfo::getAttr(&op, argIndex)
448+
: SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
440449

441450
std::string key = symbol.str();
442451
if (symbolInfoMap.count(key)) {
@@ -499,8 +508,10 @@ SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
499508

500509
SymbolInfoMap::const_iterator
501510
SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
502-
int argIndex) const {
503-
return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
511+
int argIndex,
512+
std::optional<int> variadicSubIndex) const {
513+
return findBoundSymbol(
514+
key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
504515
}
505516

506517
SymbolInfoMap::const_iterator
@@ -831,6 +842,33 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
831842
}
832843
};
833844

845+
// The operand in `variadic` DAG should be bound to the operation in the
846+
// parent DagNode. The range index must be included as well to distinguish
847+
// (potentially) repeating argName within the `variadic` DAG.
848+
auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
849+
int opArgIdx) {
850+
auto treeName = tree.getSymbol();
851+
if (!treeName.empty()) {
852+
// If treeName is specified, bind to the full variadic operand_range.
853+
verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
854+
std::nullopt),
855+
treeName);
856+
}
857+
858+
for (int i = 0; i < tree.getNumArgs(); ++i) {
859+
if (DagNode subTree = tree.getArgAsNestedDag(i)) {
860+
collectBoundSymbols(subTree, infoMap, isSrcPattern);
861+
} else {
862+
auto argName = tree.getArgName(i);
863+
if (!argName.empty() && argName != "_") {
864+
verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
865+
/*variadicSubIndex=*/i),
866+
argName);
867+
}
868+
}
869+
}
870+
};
871+
834872
for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
835873
if (auto treeArg = tree.getArgAsNestedDag(i)) {
836874
if (treeArg.isEither()) {
@@ -843,6 +881,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
843881
//
844882
// (FooOp arg0, arg1, arg2)
845883
++opArgIdx;
884+
} else if (treeArg.isVariadic()) {
885+
collectSymbolInVariadic(tree, treeArg, opArgIdx);
846886
} else {
847887
// This DAG node argument is a DAG node itself. Go inside recursively.
848888
collectBoundSymbols(treeArg, infoMap, isSrcPattern);

0 commit comments

Comments
 (0)