Skip to content

Commit

Permalink
TableGen/ISel: Allow PatFrag predicate code to access captured operands
Browse files Browse the repository at this point in the history
Summary:
This simplifies writing predicates for pattern fragments that are
automatically re-associated or commuted.

For example, a followup patch adds patterns for fragments of the form
(add (shl $x, $y), $z) to the AMDGPU backend. Such patterns are
automatically commuted to (add $z, (shl $x, $y)), which makes it basically
impossible to refer to $x, $y, and $z generically in the PredicateCode.

With this change, the PredicateCode can refer to $x, $y, and $z simply
as `Operands[i]`.

Test confirmed that there are no changes to any of the generated files
when building all (non-experimental) targets.

Change-Id: I61c00ace7eed42c1d4edc4c5351174b56b77a79c

Reviewers: arsenm, rampitec, RKSimon, craig.topper, hfinkel, uweigand

Subscribers: wdng, tpr, llvm-commits

Differential Revision: https://reviews.llvm.org/D51994

llvm-svn: 347992
  • Loading branch information
nhaehnle committed Nov 30, 2018
1 parent 4830fdd commit 445b0b6
Show file tree
Hide file tree
Showing 11 changed files with 314 additions and 91 deletions.
12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class SelectionDAGISel : public MachineFunctionPass {
OPC_CheckChild2Same, OPC_CheckChild3Same,
OPC_CheckPatternPredicate,
OPC_CheckPredicate,
OPC_CheckPredicateWithOperands,
OPC_CheckOpcode,
OPC_SwitchOpcode,
OPC_CheckType,
Expand Down Expand Up @@ -267,6 +268,17 @@ class SelectionDAGISel : public MachineFunctionPass {
llvm_unreachable("Tblgen should generate the implementation of this!");
}

/// CheckNodePredicateWithOperands - This function is generated by tblgen in
/// the target.
/// It runs node predicate number PredNo and returns true if it succeeds or
/// false if it fails. The number is a private implementation detail to the
/// code tblgen produces.
virtual bool CheckNodePredicateWithOperands(
SDNode *N, unsigned PredNo,
const SmallVectorImpl<SDValue> &Operands) const {
llvm_unreachable("Tblgen should generate the implementation of this!");
}

virtual bool CheckComplexPattern(SDNode *Root, SDNode *Parent, SDValue N,
unsigned PatternNo,
SmallVectorImpl<std::pair<SDValue, SDNode*> > &Result) {
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,15 @@ class PatFrags<dag ops, list<dag> frags, code pred = [{}],
code ImmediateCode = [{}];
SDNodeXForm OperandTransform = xform;

// When this is set, the PredicateCode may refer to a constant Operands
// vector which contains the captured nodes of the DAG, in the order listed
// by the Operands field above.
//
// This is useful when Fragments involves associative / commutative
// operators: a single piece of code can easily refer to all operands even
// when re-associated / commuted variants of the fragment are matched.
bit PredicateCodeUsesOperands = 0;

// Define a few pre-packaged predicates. This helps GlobalISel import
// existing rules from SelectionDAG for many common cases.
// They will be tested prior to the code in pred and must not be used in
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3207,6 +3207,18 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
N.getNode()))
break;
continue;
case OPC_CheckPredicateWithOperands: {
unsigned OpNum = MatcherTable[MatcherIndex++];
SmallVector<SDValue, 8> Operands;

for (unsigned i = 0; i < OpNum; ++i)
Operands.push_back(RecordedNodes[MatcherTable[MatcherIndex++]].first);

unsigned PredNo = MatcherTable[MatcherIndex++];
if (!CheckNodePredicateWithOperands(N.getNode(), PredNo, Operands))
break;
continue;
}
case OPC_CheckComplexPat: {
unsigned CPNum = MatcherTable[MatcherIndex++];
unsigned RecNo = MatcherTable[MatcherIndex++];
Expand Down
74 changes: 54 additions & 20 deletions llvm/utils/TableGen/CodeGenDAGPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,20 @@ TypeInfer::ValidateOnExit::~ValidateOnExit() {
}
#endif


//===----------------------------------------------------------------------===//
// ScopedName Implementation
//===----------------------------------------------------------------------===//

bool ScopedName::operator==(const ScopedName &o) const {
return Scope == o.Scope && Identifier == o.Identifier;
}

bool ScopedName::operator!=(const ScopedName &o) const {
return !(*this == o);
}


//===----------------------------------------------------------------------===//
// TreePredicateFn Implementation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1069,6 +1083,9 @@ bool TreePredicateFn::isPredefinedPredicateEqualTo(StringRef Field,
return false;
return Result == Value;
}
bool TreePredicateFn::usesOperands() const {
return isPredefinedPredicateEqualTo("PredicateCodeUsesOperands", true);
}
bool TreePredicateFn::isLoad() const {
return isPredefinedPredicateEqualTo("IsLoad", true);
}
Expand Down Expand Up @@ -1250,7 +1267,7 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {
else
Result = " auto *N = cast<" + ClassName.str() + ">(Node);\n";

return Result + getPredCode();
return (Twine(Result) + " (void)N;\n" + getPredCode()).str();
}

//===----------------------------------------------------------------------===//
Expand All @@ -1276,7 +1293,7 @@ static unsigned getPatternSize(const TreePatternNode *P,

// If this node has some predicate function that must match, it adds to the
// complexity of this node.
if (!P->getPredicateFns().empty())
if (!P->getPredicateCalls().empty())
++Size;

// Count children in the count if they are also nodes.
Expand All @@ -1296,7 +1313,7 @@ static unsigned getPatternSize(const TreePatternNode *P,
Size += 5; // Matches a ConstantSDNode (+3) and a specific value (+2).
else if (Child->getComplexPatternInfo(CGP))
Size += getPatternSize(Child, CGP);
else if (!Child->getPredicateFns().empty())
else if (!Child->getPredicateCalls().empty())
++Size;
}
}
Expand Down Expand Up @@ -1751,13 +1768,19 @@ void TreePatternNode::print(raw_ostream &OS) const {
OS << ")";
}

for (const TreePredicateFn &Pred : PredicateFns)
OS << "<<P:" << Pred.getFnName() << ">>";
for (const TreePredicateCall &Pred : PredicateCalls) {
OS << "<<P:";
if (Pred.Scope)
OS << Pred.Scope << ":";
OS << Pred.Fn.getFnName() << ">>";
}
if (TransformFn)
OS << "<<X:" << TransformFn->getName() << ">>";
if (!getName().empty())
OS << ":$" << getName();

for (const ScopedName &Name : NamesAsPredicateArg)
OS << ":$pred:" << Name.getScope() << ":" << Name.getIdentifier();
}
void TreePatternNode::dump() const {
print(errs());
Expand All @@ -1774,7 +1797,7 @@ bool TreePatternNode::isIsomorphicTo(const TreePatternNode *N,
const MultipleUseVarSet &DepVars) const {
if (N == this) return true;
if (N->isLeaf() != isLeaf() || getExtTypes() != N->getExtTypes() ||
getPredicateFns() != N->getPredicateFns() ||
getPredicateCalls() != N->getPredicateCalls() ||
getTransformFn() != N->getTransformFn())
return false;

Expand Down Expand Up @@ -1812,8 +1835,9 @@ TreePatternNodePtr TreePatternNode::clone() const {
getNumTypes());
}
New->setName(getName());
New->setNamesAsPredicateArg(getNamesAsPredicateArg());
New->Types = Types;
New->setPredicateFns(getPredicateFns());
New->setPredicateCalls(getPredicateCalls());
New->setTransformFn(getTransformFn());
return New;
}
Expand Down Expand Up @@ -1845,8 +1869,8 @@ void TreePatternNode::SubstituteFormalArguments(
// We found a use of a formal argument, replace it with its value.
TreePatternNodePtr NewChild = ArgMap[Child->getName()];
assert(NewChild && "Couldn't find formal argument!");
assert((Child->getPredicateFns().empty() ||
NewChild->getPredicateFns() == Child->getPredicateFns()) &&
assert((Child->getPredicateCalls().empty() ||
NewChild->getPredicateCalls() == Child->getPredicateCalls()) &&
"Non-empty child predicate clobbered!");
setChild(i, std::move(NewChild));
}
Expand Down Expand Up @@ -1892,8 +1916,8 @@ void TreePatternNode::InlinePatternFragments(
return;

for (auto NewChild : ChildAlternatives[i])
assert((Child->getPredicateFns().empty() ||
NewChild->getPredicateFns() == Child->getPredicateFns()) &&
assert((Child->getPredicateCalls().empty() ||
NewChild->getPredicateCalls() == Child->getPredicateCalls()) &&
"Non-empty child predicate clobbered!");
}

Expand All @@ -1911,7 +1935,8 @@ void TreePatternNode::InlinePatternFragments(

// Copy over properties.
R->setName(getName());
R->setPredicateFns(getPredicateFns());
R->setNamesAsPredicateArg(getNamesAsPredicateArg());
R->setPredicateCalls(getPredicateCalls());
R->setTransformFn(getTransformFn());
for (unsigned i = 0, e = getNumTypes(); i != e; ++i)
R->setType(i, getExtType(i));
Expand Down Expand Up @@ -1946,20 +1971,28 @@ void TreePatternNode::InlinePatternFragments(
return;
}

TreePredicateFn PredFn(Frag);
unsigned Scope = 0;
if (TreePredicateFn(Frag).usesOperands())
Scope = TP.getDAGPatterns().allocateScope();

// Compute the map of formal to actual arguments.
std::map<std::string, TreePatternNodePtr> ArgMap;
for (unsigned i = 0, e = Frag->getNumArgs(); i != e; ++i) {
const TreePatternNodePtr &Child = getChildShared(i);
TreePatternNodePtr Child = getChildShared(i);
if (Scope != 0) {
Child = Child->clone();
Child->addNameAsPredicateArg(ScopedName(Scope, Frag->getArgName(i)));
}
ArgMap[Frag->getArgName(i)] = Child;
}

// Loop over all fragment alternatives.
for (auto Alternative : Frag->getTrees()) {
TreePatternNodePtr FragTree = Alternative->clone();

TreePredicateFn PredFn(Frag);
if (!PredFn.isAlwaysTrue())
FragTree->addPredicateFn(PredFn);
FragTree->addPredicateCall(PredFn, Scope);

// Resolve formal arguments to their actual value.
if (Frag->getNumArgs())
Expand All @@ -1972,8 +2005,8 @@ void TreePatternNode::InlinePatternFragments(
FragTree->UpdateNodeType(i, getExtType(i), TP);

// Transfer in the old predicates.
for (const TreePredicateFn &Pred : getPredicateFns())
FragTree->addPredicateFn(Pred);
for (const TreePredicateCall &Pred : getPredicateCalls())
FragTree->addPredicateCall(Pred);

// The fragment we inlined could have recursive inlining that is needed. See
// if there are any pattern fragments in it and inline them as needed.
Expand Down Expand Up @@ -3596,7 +3629,7 @@ void CodeGenDAGPatterns::parseInstructionPattern(
TreePatternNodePtr OpNode = InVal->clone();

// No predicate is useful on the result.
OpNode->clearPredicateFns();
OpNode->clearPredicateCalls();

// Promote the xform function to be an explicit node if set.
if (Record *Xform = OpNode->getTransformFn()) {
Expand Down Expand Up @@ -4251,7 +4284,8 @@ static void CombineChildVariants(

// Copy over properties.
R->setName(Orig->getName());
R->setPredicateFns(Orig->getPredicateFns());
R->setNamesAsPredicateArg(Orig->getNamesAsPredicateArg());
R->setPredicateCalls(Orig->getPredicateCalls());
R->setTransformFn(Orig->getTransformFn());
for (unsigned i = 0, e = Orig->getNumTypes(); i != e; ++i)
R->setType(i, Orig->getExtType(i));
Expand Down Expand Up @@ -4303,7 +4337,7 @@ GatherChildrenOfAssociativeOpcode(TreePatternNodePtr N,
Record *Operator = N->getOperator();

// Only permit raw nodes.
if (!N->getName().empty() || !N->getPredicateFns().empty() ||
if (!N->getName().empty() || !N->getPredicateCalls().empty() ||
N->getTransformFn()) {
Children.push_back(N);
return;
Expand Down
Loading

0 comments on commit 445b0b6

Please sign in to comment.