2222#include " llvm/ADT/StringMap.h"
2323#include " llvm/ADT/StringSet.h"
2424
25+ #include < optional>
2526#include < unordered_map>
2627
2728namespace llvm {
@@ -189,6 +190,9 @@ class DagNode {
189190 // Returns whether this DAG is an `either` specifier.
190191 bool isEither () const ;
191192
193+ // Returns whether this DAG is an `variadic` specifier.
194+ bool isVariadic () const ;
195+
192196 // Returns true if this DAG node is an operation.
193197 bool isOperation () const ;
194198
@@ -268,9 +272,94 @@ class SymbolInfoMap {
268272 // Allow SymbolInfoMap to access private methods.
269273 friend class SymbolInfoMap ;
270274
271- // DagNode and DagLeaf are accessed by value which means it can't be used as
272- // identifier here. Use an opaque pointer type instead.
273- using DagAndConstant = std::pair<const void *, int >;
275+ // Structure to uniquely distinguish different locations of the symbols.
276+ //
277+ // * If a symbol is defined as an operand of an operation, `dag` specifies
278+ // the DAG of the operation, `operandIndexOrNumValues` specifies the
279+ // operand index, and `variadicSubIndex` must be set to `std::nullopt`.
280+ //
281+ // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG
282+ // of the parent operation, `operandIndexOrNumValues` specifies the
283+ // declared operand index of the variadic operand in the parent
284+ // operation.
285+ //
286+ // - If the symbol is defined as a result of `variadic` DAG, the
287+ // `variadicSubIndex` must be set to `std::nullopt`, which means that
288+ // the symbol binds to the full operand range.
289+ //
290+ // - If the symbol is defined as a operand, the `variadicSubIndex` must
291+ // be set to the index within the variadic sub-operand list.
292+ //
293+ // * If a symbol is defined in a `either` DAG, `dag` specifies the DAG
294+ // of the parent operation, `operandIndexOrNumValues` specifies the
295+ // operand index in the parent operation (not necessary the index in the
296+ // DAG).
297+ //
298+ // * If a symbol is defined as a result, specifies the number of returning
299+ // value.
300+ //
301+ // Example 1:
302+ //
303+ // def : Pat<(OpA $input0, $input1), ...>;
304+ //
305+ // $input0: (OpA, 0, nullopt)
306+ // $input1: (OpA, 1, nullopt)
307+ //
308+ // Example 2:
309+ //
310+ // def : Pat<(OpB (variadic:$input0 $input0a, $input0b),
311+ // (variadic:$input1 $input1a, $input1b, $input1c)),
312+ // ...>;
313+ //
314+ // $input0: (OpB, 0, nullopt)
315+ // $input0a: (OpB, 0, 0)
316+ // $input0b: (OpB, 0, 1)
317+ // $input1: (OpB, 1, nullopt)
318+ // $input1a: (OpB, 1, 0)
319+ // $input1b: (OpB, 1, 1)
320+ // $input1c: (OpB, 1, 2)
321+ //
322+ // Example 3:
323+ //
324+ // def : Pat<(OpC $input0, (either $input1, $input2)), ...>;
325+ //
326+ // $input0: (OpC, 0, nullopt)
327+ // $input1: (OpC, 1, nullopt)
328+ // $input2: (OpC, 2, nullopt)
329+ //
330+ // Example 4:
331+ //
332+ // def ThreeResultOp : TEST_Op<...> {
333+ // let results = (outs
334+ // AnyType:$result1,
335+ // AnyType:$result2,
336+ // AnyType:$result3
337+ // );
338+ // }
339+ //
340+ // def : Pat<...,
341+ // (ThreeResultOp:$result ...)>;
342+ //
343+ // $result: (nullptr, 3, nullopt)
344+ //
345+ struct DagAndConstant {
346+ // DagNode and DagLeaf are accessed by value which means it can't be used
347+ // as identifier here. Use an opaque pointer type instead.
348+ const void *dag;
349+ int operandIndexOrNumValues;
350+ std::optional<int > variadicSubIndex;
351+
352+ DagAndConstant (const void *dag, int operandIndexOrNumValues,
353+ std::optional<int > variadicSubIndex)
354+ : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues),
355+ variadicSubIndex (variadicSubIndex) {}
356+
357+ bool operator ==(const DagAndConstant &rhs) const {
358+ return dag == rhs.dag &&
359+ operandIndexOrNumValues == rhs.operandIndexOrNumValues &&
360+ variadicSubIndex == rhs.variadicSubIndex ;
361+ }
362+ };
274363
275364 // What kind of entity this symbol represents:
276365 // * Attr: op attribute
@@ -288,14 +377,18 @@ class SymbolInfoMap {
288377
289378 // Static methods for creating SymbolInfo.
290379 static SymbolInfo getAttr (const Operator *op, int index) {
291- return SymbolInfo (op, Kind::Attr, DagAndConstant (nullptr , index));
380+ return SymbolInfo (op, Kind::Attr,
381+ DagAndConstant (nullptr , index, std::nullopt ));
292382 }
293383 static SymbolInfo getAttr () {
294384 return SymbolInfo (nullptr , Kind::Attr, std::nullopt );
295385 }
296- static SymbolInfo getOperand (DagNode node, const Operator *op, int index) {
386+ static SymbolInfo
387+ getOperand (DagNode node, const Operator *op, int operandIndex,
388+ std::optional<int > variadicSubIndex = std::nullopt ) {
297389 return SymbolInfo (op, Kind::Operand,
298- DagAndConstant (node.getAsOpaquePointer (), index));
390+ DagAndConstant (node.getAsOpaquePointer (), operandIndex,
391+ variadicSubIndex));
299392 }
300393 static SymbolInfo getResult (const Operator *op) {
301394 return SymbolInfo (op, Kind::Result, std::nullopt );
@@ -305,7 +398,7 @@ class SymbolInfoMap {
305398 }
306399 static SymbolInfo getMultipleValues (int numValues) {
307400 return SymbolInfo (nullptr , Kind::MultipleValues,
308- DagAndConstant (nullptr , numValues));
401+ DagAndConstant (nullptr , numValues, std:: nullopt ));
309402 }
310403
311404 // Returns the number of static values this symbol corresponds to.
@@ -333,18 +426,23 @@ class SymbolInfoMap {
333426 const char *separator) const ;
334427
335428 // The argument index (for `Attr` and `Operand` only)
336- int getArgIndex () const { return (* dagAndConstant). second ; }
429+ int getArgIndex () const { return dagAndConstant-> operandIndexOrNumValues ; }
337430
338431 // The number of values in the MultipleValue
339- int getSize () const { return (*dagAndConstant).second ; }
432+ int getSize () const { return dagAndConstant->operandIndexOrNumValues ; }
433+
434+ // The variadic sub-operands index (for variadic `Operand` only)
435+ std::optional<int > getVariadicSubIndex () const {
436+ return dagAndConstant->variadicSubIndex ;
437+ }
340438
341439 const Operator *op; // The op where the bound entity belongs
342440 Kind kind; // The kind of the bound entity
343441
344- // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
345- // the size of MultipleValue symbol). Note that operands may be bound to the
346- // same symbol, use the DagNode and index to distinguish them. For `Attr`
347- // and MultipleValue, the Dag part will be nullptr.
442+ // The tuple of DagNode pointer and two constant values (for `Attr`,
443+ // `Operand` and the size of MultipleValue symbol). Note that operands may
444+ // be bound to the same symbol, use the DagNode and index to distinguish
445+ // them. For `Attr` and MultipleValue, the Dag part will be nullptr.
348446 std::optional<DagAndConstant> dagAndConstant;
349447
350448 // Alternative name for the symbol. It is used in case the name
@@ -367,7 +465,8 @@ class SymbolInfoMap {
367465 // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
368466 // Returns false if `symbol` is already bound and symbols are not operands.
369467 bool bindOpArgument (DagNode node, StringRef symbol, const Operator &op,
370- int argIndex);
468+ int argIndex,
469+ std::optional<int > variadicSubIndex = std::nullopt );
371470
372471 // Binds the given `symbol` to the results the given `op`. Returns false if
373472 // `symbol` is already bound.
@@ -397,7 +496,8 @@ class SymbolInfoMap {
397496 // Returns an iterator to the information of the given symbol named as `key`,
398497 // with index `argIndex` for operator `op`.
399498 const_iterator findBoundSymbol (StringRef key, DagNode node,
400- const Operator &op, int argIndex) const ;
499+ const Operator &op, int argIndex,
500+ std::optional<int > variadicSubIndex) const ;
401501 const_iterator findBoundSymbol (StringRef key,
402502 const SymbolInfo &symbolInfo) const ;
403503
0 commit comments