Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SelectionDAG] Add space-optimized forms of OPC_CheckPatternPredicate #73319

Merged

Conversation

wangpc-pp
Copy link
Contributor

We record the usage of each PatternPredicate and sort them by
usage.

For the top 8 PatternPredicates, we will emit a
OPC_CheckPatternPredicateN to save one byte.

The old OPC_CheckPatternPredicate2 is renamed to
OPC_CheckPatternPredicateTwoByte.

Overall this reduces the llc binary size with all in-tree targets by
about 93K.

This PR is stacked on #73310.

@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Nov 24, 2023
@wangpc-pp wangpc-pp requested review from topperc and ilovepi and removed request for topperc November 24, 2023 11:49
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 24, 2023

@llvm/pr-subscribers-llvm-selectiondag

Author: Wang Pengcheng (wangpc-pp)

Changes

We record the usage of each PatternPredicate and sort them by
usage.

For the top 8 PatternPredicates, we will emit a
OPC_CheckPatternPredicateN to save one byte.

The old OPC_CheckPatternPredicate2 is renamed to
OPC_CheckPatternPredicateTwoByte.

Overall this reduces the llc binary size with all in-tree targets by
about 93K.

This PR is stacked on #73310.


Full diff: https://github.com/llvm/llvm-project/pull/73319.diff

7 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAGISel.h (+16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp (+36-5)
  • (modified) llvm/test/TableGen/dag-isel-complexpattern.td (+1-1)
  • (modified) llvm/utils/TableGen/CodeGenDAGPatterns.h (+23)
  • (modified) llvm/utils/TableGen/DAGISelMatcher.h (+1-1)
  • (modified) llvm/utils/TableGen/DAGISelMatcherEmitter.cpp (+32-20)
  • (modified) llvm/utils/TableGen/DAGISelMatcherGen.cpp (+8-5)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
index e6513eb6abc8749..b2c398112736c22 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -150,7 +150,15 @@ class SelectionDAGISel : public MachineFunctionPass {
     OPC_CheckChild2Same,
     OPC_CheckChild3Same,
     OPC_CheckPatternPredicate,
+    OPC_CheckPatternPredicate0,
+    OPC_CheckPatternPredicate1,
     OPC_CheckPatternPredicate2,
+    OPC_CheckPatternPredicate3,
+    OPC_CheckPatternPredicate4,
+    OPC_CheckPatternPredicate5,
+    OPC_CheckPatternPredicate6,
+    OPC_CheckPatternPredicate7,
+    OPC_CheckPatternPredicateTwoByte,
     OPC_CheckPredicate,
     OPC_CheckPredicateWithOperands,
     OPC_CheckOpcode,
@@ -176,6 +184,14 @@ class SelectionDAGISel : public MachineFunctionPass {
     OPC_CheckChild2CondCode,
     OPC_CheckValueType,
     OPC_CheckComplexPat,
+    OPC_CheckComplexPat0,
+    OPC_CheckComplexPat1,
+    OPC_CheckComplexPat2,
+    OPC_CheckComplexPat3,
+    OPC_CheckComplexPat4,
+    OPC_CheckComplexPat5,
+    OPC_CheckComplexPat6,
+    OPC_CheckComplexPat7,
     OPC_CheckAndImm,
     OPC_CheckOrImm,
     OPC_CheckImmAllOnesV,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 7d9bebdca127224..f10d05843a372f8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -2690,7 +2690,12 @@ LLVM_ATTRIBUTE_ALWAYS_INLINE static bool CheckChildSame(
 LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
 CheckPatternPredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
                       const SelectionDAGISel &SDISel, bool TwoBytePredNo) {
-  unsigned PredNo = MatcherTable[MatcherIndex++];
+  unsigned Opcode = MatcherTable[MatcherIndex - 1];
+  unsigned PredNo =
+      Opcode == SelectionDAGISel::OPC_CheckPatternPredicate ||
+              Opcode == SelectionDAGISel::OPC_CheckPatternPredicate
+          ? MatcherTable[MatcherIndex++]
+          : Opcode - SelectionDAGISel::OPC_CheckPatternPredicate0;
   if (TwoBytePredNo)
     PredNo |= MatcherTable[MatcherIndex++] << 8;
   return SDISel.CheckPatternPredicate(PredNo);
@@ -2841,10 +2846,18 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
                         Table[Index-1] - SelectionDAGISel::OPC_CheckChild0Same);
     return Index;
   case SelectionDAGISel::OPC_CheckPatternPredicate:
+  case SelectionDAGISel::OPC_CheckPatternPredicate0:
+  case SelectionDAGISel::OPC_CheckPatternPredicate1:
   case SelectionDAGISel::OPC_CheckPatternPredicate2:
+  case SelectionDAGISel::OPC_CheckPatternPredicate3:
+  case SelectionDAGISel::OPC_CheckPatternPredicate4:
+  case SelectionDAGISel::OPC_CheckPatternPredicate5:
+  case SelectionDAGISel::OPC_CheckPatternPredicate6:
+  case SelectionDAGISel::OPC_CheckPatternPredicate7:
+  case SelectionDAGISel::OPC_CheckPatternPredicateTwoByte:
     Result = !::CheckPatternPredicate(
         Table, Index, SDISel,
-        Table[Index - 1] == SelectionDAGISel::OPC_CheckPatternPredicate2);
+        Table[Index - 1] == SelectionDAGISel::OPC_CheckPatternPredicateTwoByte);
     return Index;
   case SelectionDAGISel::OPC_CheckPredicate:
     Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
@@ -3257,9 +3270,17 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
       continue;
 
     case OPC_CheckPatternPredicate:
+    case OPC_CheckPatternPredicate0:
+    case OPC_CheckPatternPredicate1:
     case OPC_CheckPatternPredicate2:
+    case OPC_CheckPatternPredicate3:
+    case OPC_CheckPatternPredicate4:
+    case OPC_CheckPatternPredicate5:
+    case OPC_CheckPatternPredicate6:
+    case OPC_CheckPatternPredicate7:
+    case OPC_CheckPatternPredicateTwoByte:
       if (!::CheckPatternPredicate(MatcherTable, MatcherIndex, *this,
-                                   Opcode == OPC_CheckPatternPredicate2))
+                                   Opcode == OPC_CheckPatternPredicateTwoByte))
         break;
       continue;
     case OPC_CheckPredicate:
@@ -3279,8 +3300,18 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
         break;
       continue;
     }
-    case OPC_CheckComplexPat: {
-      unsigned CPNum = MatcherTable[MatcherIndex++];
+    case OPC_CheckComplexPat:
+    case OPC_CheckComplexPat0:
+    case OPC_CheckComplexPat1:
+    case OPC_CheckComplexPat2:
+    case OPC_CheckComplexPat3:
+    case OPC_CheckComplexPat4:
+    case OPC_CheckComplexPat5:
+    case OPC_CheckComplexPat6:
+    case OPC_CheckComplexPat7: {
+      unsigned CPNum = Opcode == OPC_CheckComplexPat
+                           ? MatcherTable[MatcherIndex++]
+                           : Opcode - OPC_CheckComplexPat0;
       unsigned RecNo = MatcherTable[MatcherIndex++];
       assert(RecNo < RecordedNodes.size() && "Invalid CheckComplexPat");
 
diff --git a/llvm/test/TableGen/dag-isel-complexpattern.td b/llvm/test/TableGen/dag-isel-complexpattern.td
index 40fd03cc8839424..1bb473a9df5a954 100644
--- a/llvm/test/TableGen/dag-isel-complexpattern.td
+++ b/llvm/test/TableGen/dag-isel-complexpattern.td
@@ -22,7 +22,7 @@ def CP32 : ComplexPattern<i32, 0, "SelectCP32">;
 def INSTR : Instruction {
 // CHECK-LABEL: OPC_CheckOpcode, TARGET_VAL(ISD::STORE)
 // CHECK: OPC_CheckType, MVT::i32
-// CHECK: OPC_CheckComplexPat, /*CP*/0, /*#*/1, // SelectCP32:$
+// CHECK: OPC_CheckComplexPat0, /*#*/1, // SelectCP32:$
 // CHECK: Src: (st (add:{ *:[i32] } (CP32:{ *:[i32] }), (CP32:{ *:[i32] })), i64:{ *:[i64] }:$addr)
   let OutOperandList = (outs);
   let InOperandList = (ins GPR64:$addr);
diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.h b/llvm/utils/TableGen/CodeGenDAGPatterns.h
index 2611fe06f55ca53..5155c18a1752461 100644
--- a/llvm/utils/TableGen/CodeGenDAGPatterns.h
+++ b/llvm/utils/TableGen/CodeGenDAGPatterns.h
@@ -1117,6 +1117,12 @@ class CodeGenDAGPatterns {
   std::map<Record*, DAGDefaultOperand, LessRecordByID> DefaultOperands;
   std::map<Record*, DAGInstruction, LessRecordByID> Instructions;
 
+  /// ComplexPatternUsage - Record the usage of ComplexPattern.
+  std::map<const ComplexPattern *, unsigned> ComplexPatternUsage;
+
+  /// PatternPredicateUsage - Record the usage of PatternPredicate.
+  std::map<std::string, unsigned> PatternPredicateUsage;
+
   // Specific SDNode definitions:
   Record *intrinsic_void_sdnode;
   Record *intrinsic_w_chain_sdnode, *intrinsic_wo_chain_sdnode;
@@ -1163,6 +1169,23 @@ class CodeGenDAGPatterns {
     return F->second;
   }
 
+  const std::map<const ComplexPattern *, unsigned> &
+  getComplexPatternUsage() const {
+    return ComplexPatternUsage;
+  }
+
+  void increaseComplexPatternUsage(const ComplexPattern *CP) {
+    ComplexPatternUsage[CP]++;
+  }
+
+  const std::map<std::string, unsigned> &getPatternPredicateUsage() const {
+    return PatternPredicateUsage;
+  }
+
+  void increasePatternPredicateUsage(const std::string &Predicate) {
+    PatternPredicateUsage[Predicate]++;
+  }
+
   const CodeGenIntrinsic &getIntrinsic(Record *R) const {
     for (unsigned i = 0, e = Intrinsics.size(); i != e; ++i)
       if (Intrinsics[i].TheDef == R) return Intrinsics[i];
diff --git a/llvm/utils/TableGen/DAGISelMatcher.h b/llvm/utils/TableGen/DAGISelMatcher.h
index e3cf847edd1273b..9e962b14285c395 100644
--- a/llvm/utils/TableGen/DAGISelMatcher.h
+++ b/llvm/utils/TableGen/DAGISelMatcher.h
@@ -34,7 +34,7 @@ namespace llvm {
   class TreePattern;
 
 Matcher *ConvertPatternToMatcher(const PatternToMatch &Pattern,unsigned Variant,
-                                 const CodeGenDAGPatterns &CGP);
+                                 CodeGenDAGPatterns &CGP);
 void OptimizeMatcher(std::unique_ptr<Matcher> &Matcher,
                      const CodeGenDAGPatterns &CGP);
 void EmitMatcherTable(Matcher *Matcher, const CodeGenDAGPatterns &CGP,
diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index 4a11991036efc11..4482c0f489afe4f 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -60,10 +60,8 @@ class MatcherTableEmitter {
   // all the patterns with "identical" predicates.
   StringMap<TinyPtrVector<TreePattern *>> NodePredicatesByCodeToRun;
 
-  StringMap<unsigned> PatternPredicateMap;
   std::vector<std::string> PatternPredicates;
 
-  DenseMap<const ComplexPattern*, unsigned> ComplexPatternMap;
   std::vector<const ComplexPattern*> ComplexPatterns;
 
 
@@ -85,7 +83,24 @@ class MatcherTableEmitter {
 
 public:
   MatcherTableEmitter(const CodeGenDAGPatterns &cgp)
-      : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {}
+      : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {
+    // Sort ComplexPatterns by usage.
+    auto &Usage = cgp.getComplexPatternUsage();
+    std::vector<std::pair<const ComplexPattern *, unsigned>> PatternList(
+        Usage.begin(), Usage.end());
+    sort(PatternList, [](auto &A, auto &B) { return A.second > B.second; });
+    for (auto &Pattern : PatternList)
+      ComplexPatterns.push_back(Pattern.first);
+
+    // Sort PatternPredicates by usage.
+    auto &PatternPredicateUsage = cgp.getPatternPredicateUsage();
+    std::vector<std::pair<std::string, unsigned>> PatternPredicateList(
+        PatternPredicateUsage.begin(), PatternPredicateUsage.end());
+    sort(PatternPredicateList,
+         [](auto &A, auto &B) { return A.second > B.second; });
+    for (auto &PatternPredicate : PatternPredicateList)
+      PatternPredicates.push_back(PatternPredicate.first);
+  }
 
   unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
                            unsigned StartIdx, raw_ostream &OS);
@@ -138,20 +153,10 @@ class MatcherTableEmitter {
   }
 
   unsigned getPatternPredicate(StringRef PredName) {
-    unsigned &Entry = PatternPredicateMap[PredName];
-    if (Entry == 0) {
-      PatternPredicates.push_back(PredName.str());
-      Entry = PatternPredicates.size();
-    }
-    return Entry-1;
+    return llvm::find(PatternPredicates, PredName) - PatternPredicates.begin();
   }
   unsigned getComplexPat(const ComplexPattern &P) {
-    unsigned &Entry = ComplexPatternMap[&P];
-    if (Entry == 0) {
-      ComplexPatterns.push_back(&P);
-      Entry = ComplexPatterns.size();
-    }
-    return Entry-1;
+    return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin();
   }
 
   unsigned getNodeXFormID(Record *Rec) {
@@ -475,13 +480,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
     StringRef Pred = cast<CheckPatternPredicateMatcher>(N)->getPredicate();
     unsigned PredNo = getPatternPredicate(Pred);
     if (PredNo > 255)
-      OS << "OPC_CheckPatternPredicate2, TARGET_VAL(" << PredNo << "),";
+      OS << "OPC_CheckPatternPredicateTwoByte, TARGET_VAL(" << PredNo << "),";
+    else if (PredNo < 8)
+      OS << "OPC_CheckPatternPredicate" << PredNo << ',';
     else
       OS << "OPC_CheckPatternPredicate, " << PredNo << ',';
     if (!OmitComments)
       OS << " // " << Pred;
     OS << '\n';
-    return 2 + (PredNo > 255);
+    return 2 + (PredNo > 255) - (PredNo < 8);
   }
   case Matcher::CheckPredicate: {
     TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
@@ -625,8 +632,13 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
   case Matcher::CheckComplexPat: {
     const CheckComplexPatMatcher *CCPM = cast<CheckComplexPatMatcher>(N);
     const ComplexPattern &Pattern = CCPM->getPattern();
-    OS << "OPC_CheckComplexPat, /*CP*/" << getComplexPat(Pattern) << ", /*#*/"
-       << CCPM->getMatchNumber() << ',';
+    unsigned PatternNo = getComplexPat(Pattern);
+    if (PatternNo < 8)
+      OS << "OPC_CheckComplexPat" << PatternNo << ", /*#*/"
+         << CCPM->getMatchNumber() << ',';
+    else
+      OS << "OPC_CheckComplexPat, /*CP*/" << PatternNo << ", /*#*/"
+         << CCPM->getMatchNumber() << ',';
 
     if (!OmitComments) {
       OS << " // " << Pattern.getSelectFunc();
@@ -638,7 +650,7 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
         OS << " + chain result";
     }
     OS << '\n';
-    return 3;
+    return PatternNo < 8 ? 2 : 3;
   }
 
   case Matcher::CheckAndImm: {
diff --git a/llvm/utils/TableGen/DAGISelMatcherGen.cpp b/llvm/utils/TableGen/DAGISelMatcherGen.cpp
index d08f57b84b95f08..e305ee7a5f861ed 100644
--- a/llvm/utils/TableGen/DAGISelMatcherGen.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherGen.cpp
@@ -56,7 +56,7 @@ static MVT::SimpleValueType getRegisterValueType(Record *R,
 namespace {
   class MatcherGen {
     const PatternToMatch &Pattern;
-    const CodeGenDAGPatterns &CGP;
+    CodeGenDAGPatterns &CGP;
 
     /// PatWithNoTypes - This is a clone of Pattern.getSrcPattern() that starts
     /// out with all of the types removed.  This allows us to insert type checks
@@ -102,7 +102,7 @@ namespace {
     /// which should have future checks stuck into its Next position.
     Matcher *CurPredicate;
   public:
-    MatcherGen(const PatternToMatch &pattern, const CodeGenDAGPatterns &cgp);
+    MatcherGen(const PatternToMatch &pattern, CodeGenDAGPatterns &cgp);
 
     bool EmitMatcherCode(unsigned Variant);
     void EmitResultCode();
@@ -146,7 +146,7 @@ namespace {
 } // end anonymous namespace
 
 MatcherGen::MatcherGen(const PatternToMatch &pattern,
-                       const CodeGenDAGPatterns &cgp)
+                       CodeGenDAGPatterns &cgp)
 : Pattern(pattern), CGP(cgp), NextRecordedOperandNo(0),
   TheMatcher(nullptr), CurPredicate(nullptr) {
   // We need to produce the matcher tree for the patterns source pattern.  To do
@@ -572,8 +572,10 @@ bool MatcherGen::EmitMatcherCode(unsigned Variant) {
   // If the pattern has a predicate on it (e.g. only enabled when a subtarget
   // feature is around, do the check).
   std::string PredicateCheck = Pattern.getPredicateCheck();
-  if (!PredicateCheck.empty())
+  if (!PredicateCheck.empty()) {
+    CGP.increasePatternPredicateUsage(PredicateCheck);
     AddMatcher(new CheckPatternPredicateMatcher(PredicateCheck));
+  }
 
   // Now that we've completed the structural type match, emit any ComplexPattern
   // checks (e.g. addrmode matches).  We emit this after the structural match
@@ -601,6 +603,7 @@ bool MatcherGen::EmitMatcherCode(unsigned Variant) {
 
     // Emit a CheckComplexPat operation, which does the match (aborting if it
     // fails) and pushes the matched operands onto the recorded nodes list.
+    CGP.increaseComplexPatternUsage(CP);
     AddMatcher(new CheckComplexPatMatcher(*CP, RecNodeEntry, N->getName(),
                                           NextRecordedOperandNo));
 
@@ -1081,7 +1084,7 @@ void MatcherGen::EmitResultCode() {
 /// the specified variant.  If the variant number is invalid, this returns null.
 Matcher *llvm::ConvertPatternToMatcher(const PatternToMatch &Pattern,
                                        unsigned Variant,
-                                       const CodeGenDAGPatterns &CGP) {
+                                       CodeGenDAGPatterns &CGP) {
   MatcherGen Gen(Pattern, CGP);
 
   // Generate the code for the matcher.

Copy link

github-actions bot commented Nov 24, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@wangpc-pp
Copy link
Contributor Author

Ping.

@wangpc-pp wangpc-pp force-pushed the main-matcher-table-check-pattern-predicate branch from 4ceedce to b4ab8e2 Compare December 12, 2023 11:41
@wangpc-pp
Copy link
Contributor Author

Ping.

1 similar comment
@wangpc-pp
Copy link
Contributor Author

Ping.

llvm/utils/TableGen/DAGISelMatcherEmitter.cpp Outdated Show resolved Hide resolved
llvm/utils/TableGen/DAGISelMatcherEmitter.cpp Outdated Show resolved Hide resolved
llvm/utils/TableGen/DAGISelMatcherGen.cpp Outdated Show resolved Hide resolved
llvm/utils/TableGen/DAGISelMatcherGen.cpp Outdated Show resolved Hide resolved
@wangpc-pp wangpc-pp force-pushed the main-matcher-table-check-pattern-predicate branch 2 times, most recently from 97cc38f to 21ff8a1 Compare January 8, 2024 08:35
We record the usage of each `PatternPredicate` and sort them by
usage.

For the top 8 `PatternPredicate`s, we will emit a
`OPC_CheckPatternPredicateN` to save one byte.

The old `OPC_CheckPatternPredicate2` is renamed to
`OPC_CheckPatternPredicateTwoByte`.

Overall this reduces the llc binary size with all in-tree targets by
about 93K.
@wangpc-pp wangpc-pp force-pushed the main-matcher-table-check-pattern-predicate branch from 21ff8a1 to a9e617e Compare January 11, 2024 07:35
@wangpc-pp wangpc-pp merged commit 5c8d123 into llvm:main Jan 11, 2024
3 of 4 checks passed
@wangpc-pp wangpc-pp deleted the main-matcher-table-check-pattern-predicate branch January 11, 2024 07:36
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…llvm#73319)

We record the usage of each `PatternPredicate` and sort them by
usage.

For the top 8 `PatternPredicate`s, we will emit a
`OPC_CheckPatternPredicateN` to save one byte.

The old `OPC_CheckPatternPredicate2` is renamed to
`OPC_CheckPatternPredicateTwoByte`.

Overall this reduces the llc binary size with all in-tree targets by
about 93K.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants