diff --git a/llvm/lib/TargetParser/RISCVISAInfo.cpp b/llvm/lib/TargetParser/RISCVISAInfo.cpp index ea0b56b9a1339..ac0f958dbb264 100644 --- a/llvm/lib/TargetParser/RISCVISAInfo.cpp +++ b/llvm/lib/TargetParser/RISCVISAInfo.cpp @@ -847,15 +847,21 @@ Error RISCVISAInfo::checkDependency() { struct ImpliedExtsEntry { StringLiteral Name; - ArrayRef Exts; + const char *ImpliedExt; bool operator<(const ImpliedExtsEntry &Other) const { return Name < Other.Name; } - - bool operator<(StringRef Other) const { return Name < Other; } }; +static bool operator<(const ImpliedExtsEntry &LHS, StringRef RHS) { + return LHS.Name < RHS; +} + +static bool operator<(StringRef LHS, const ImpliedExtsEntry &RHS) { + return LHS < RHS.Name; +} + #define GET_IMPLIED_EXTENSIONS #include "llvm/TargetParser/RISCVTargetParserDef.inc" @@ -880,18 +886,19 @@ void RISCVISAInfo::updateImplication() { while (!WorkList.empty()) { StringRef ExtName = WorkList.pop_back_val(); - auto I = llvm::lower_bound(ImpliedExts, ExtName); - if (I != std::end(ImpliedExts) && I->Name == ExtName) { - for (const char *ImpliedExt : I->Exts) { - if (WorkList.count(ImpliedExt)) - continue; - if (Exts.count(ImpliedExt)) - continue; - auto Version = findDefaultVersion(ImpliedExt); - addExtension(ImpliedExt, Version.value()); - WorkList.insert(ImpliedExt); - } - } + auto Range = std::equal_range(std::begin(ImpliedExts), + std::end(ImpliedExts), ExtName); + std::for_each(Range.first, Range.second, + [&](const ImpliedExtsEntry &Implied) { + const char *ImpliedExt = Implied.ImpliedExt; + if (WorkList.count(ImpliedExt)) + return; + if (Exts.count(ImpliedExt)) + return; + auto Version = findDefaultVersion(ImpliedExt); + addExtension(ImpliedExt, Version.value()); + WorkList.insert(ImpliedExt); + }); } // Add Zcf if Zce and F are enabled on RV32. @@ -902,42 +909,34 @@ void RISCVISAInfo::updateImplication() { } } -struct CombinedExtsEntry { - StringLiteral CombineExt; - ArrayRef RequiredExts; -}; - -static constexpr CombinedExtsEntry CombineIntoExts[] = { - {{"zk"}, {ImpliedExtsZk}}, - {{"zkn"}, {ImpliedExtsZkn}}, - {{"zks"}, {ImpliedExtsZks}}, - {{"zvkn"}, {ImpliedExtsZvkn}}, - {{"zvknc"}, {ImpliedExtsZvknc}}, - {{"zvkng"}, {ImpliedExtsZvkng}}, - {{"zvks"}, {ImpliedExtsZvks}}, - {{"zvksc"}, {ImpliedExtsZvksc}}, - {{"zvksg"}, {ImpliedExtsZvksg}}, +static constexpr StringLiteral CombineIntoExts[] = { + {"zk"}, {"zkn"}, {"zks"}, {"zvkn"}, {"zvknc"}, + {"zvkng"}, {"zvks"}, {"zvksc"}, {"zvksg"}, }; void RISCVISAInfo::updateCombination() { - bool IsNewCombine = false; + bool MadeChange = false; do { - IsNewCombine = false; - for (CombinedExtsEntry CombineIntoExt : CombineIntoExts) { - auto CombineExt = CombineIntoExt.CombineExt; - auto RequiredExts = CombineIntoExt.RequiredExts; + MadeChange = false; + for (StringRef CombineExt : CombineIntoExts) { if (hasExtension(CombineExt)) continue; - bool IsAllRequiredFeatureExist = true; - for (const char *Ext : RequiredExts) - IsAllRequiredFeatureExist &= hasExtension(Ext); - if (IsAllRequiredFeatureExist) { + + // Look up the extension in the ImpliesExt table to find everything it + // depends on. + auto Range = std::equal_range(std::begin(ImpliedExts), + std::end(ImpliedExts), CombineExt); + bool HasAllRequiredFeatures = std::all_of( + Range.first, Range.second, [&](const ImpliedExtsEntry &Implied) { + return hasExtension(Implied.ImpliedExt); + }); + if (HasAllRequiredFeatures) { auto Version = findDefaultVersion(CombineExt); addExtension(CombineExt, Version.value()); - IsNewCombine = true; + MadeChange = true; } } - } while (IsNewCombine); + } while (MadeChange); } void RISCVISAInfo::updateFLen() { diff --git a/llvm/test/TableGen/riscv-target-def.td b/llvm/test/TableGen/riscv-target-def.td index b23c7e4d40198..01c72e07460e5 100644 --- a/llvm/test/TableGen/riscv-target-def.td +++ b/llvm/test/TableGen/riscv-target-def.td @@ -113,10 +113,8 @@ def ROCKET : RISCVTuneProcessorModel<"rocket", // CHECK: #ifdef GET_IMPLIED_EXTENSIONS // CHECK-NEXT: #undef GET_IMPLIED_EXTENSIONS -// CHECK: static const char *ImpliedExtsF[] = {"zicsr"}; - // CHECK: static constexpr ImpliedExtsEntry ImpliedExts[] = { -// CHECK-NEXT: { {"f"}, {ImpliedExtsF} }, +// CHECK-NEXT: { {"f"}, "zicsr"}, // CHECK-NEXT: }; // CHECK: #endif // GET_IMPLIED_EXTENSIONS diff --git a/llvm/utils/TableGen/RISCVTargetDefEmitter.cpp b/llvm/utils/TableGen/RISCVTargetDefEmitter.cpp index 217b531dcfd39..c34c4b3f1881b 100644 --- a/llvm/utils/TableGen/RISCVTargetDefEmitter.cpp +++ b/llvm/utils/TableGen/RISCVTargetDefEmitter.cpp @@ -43,16 +43,6 @@ static void printExtensionTable(raw_ostream &OS, OS << "};\n\n"; } -// Get the extension name from the Record name. This gives the canonical -// capitalization. -static StringRef getExtensionNameFromRecordName(const Record *R) { - StringRef Name = R->getName(); - if (!Name.consume_front("FeatureStdExt")) - Name.consume_front("FeatureVendor"); - - return Name; -} - static void emitRISCVExtensions(RecordKeeper &Records, raw_ostream &OS) { OS << "#ifdef GET_SUPPORTED_EXTENSIONS\n"; OS << "#undef GET_SUPPORTED_EXTENSIONS\n\n"; @@ -71,33 +61,21 @@ static void emitRISCVExtensions(RecordKeeper &Records, raw_ostream &OS) { OS << "#ifdef GET_IMPLIED_EXTENSIONS\n"; OS << "#undef GET_IMPLIED_EXTENSIONS\n\n"; + OS << "\nstatic constexpr ImpliedExtsEntry ImpliedExts[] = {\n"; for (Record *Ext : Extensions) { auto ImpliesList = Ext->getValueAsListOfDefs("Implies"); if (ImpliesList.empty()) continue; - OS << "static const char *ImpliedExts" - << getExtensionNameFromRecordName(Ext) << "[] = {"; + StringRef Name = getExtensionName(Ext); - ListSeparator LS(", "); for (auto *ImpliedExt : ImpliesList) { if (!ImpliedExt->isSubClassOf("RISCVExtension")) continue; - OS << LS << '"' << getExtensionName(ImpliedExt) << '"'; + OS << " { {\"" << Name << "\"}, \"" << getExtensionName(ImpliedExt) + << "\"},\n"; } - - OS << "};\n"; - } - - OS << "\nstatic constexpr ImpliedExtsEntry ImpliedExts[] = {\n"; - for (Record *Ext : Extensions) { - auto ImpliesList = Ext->getValueAsListOfDefs("Implies"); - if (ImpliesList.empty()) - continue; - - OS << " { {\"" << getExtensionName(Ext) << "\"}, {ImpliedExts" - << getExtensionNameFromRecordName(Ext) << "} },\n"; } OS << "};\n\n";