diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 42edad255ce82..04c54f9b0e53d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -17,7 +17,9 @@ #include "llvm/TargetParser/Triple.h" #include +#include #include +#include #include #include #include @@ -26,7 +28,7 @@ using namespace llvm; -static const std::map> +static const std::map SPIRVExtensionMap = { {"SPV_EXT_shader_atomic_float_add", SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add}, @@ -181,57 +183,52 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, std::set &Vals) { SmallVector Tokens; ArgValue.split(Tokens, ",", -1, false); - llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) { - // We want to ensure that we handle "all" first, to ensure that any - // subsequent disablement actually behaves as expected i.e. given - // --spv-ext=all,-foo, we first enable all and then disable foo; this should - // be revisited and simplified. - if (LHS == "all") - return true; - if (RHS == "all") - return false; - return !(RHS < LHS); - }); std::set EnabledExtensions; - for (const auto &Token : Tokens) { - if (Token == "all") { - for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap) - EnabledExtensions.insert(ExtensionEnum); + auto M = partition(Tokens, [](auto &&T) { return T.starts_with('+'); }); + + if (std::any_of(M, Tokens.end(), [](auto &&T) { return T == "all"; })) + copy(make_second_range(SPIRVExtensionMap), std::inserter(Vals, Vals.end())); + + for (auto &&Token : make_range(Tokens.begin(), M)) { + StringRef ExtensionName = Token.substr(1); + auto NameValuePair = SPIRVExtensionMap.find(ExtensionName); + if (NameValuePair == SPIRVExtensionMap.end()) + return O.error("Unknown SPIR-V extension: " + Token.str()); + + EnabledExtensions.insert(NameValuePair->second); + } + + for (auto &&Token : make_range(M, Tokens.end())) { + if (Token == "all") continue; - } if (Token.size() == 3 && Token.upper() == "KHR") { for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap) if (StringRef(ExtensionName).starts_with("SPV_KHR_")) - EnabledExtensions.insert(ExtensionEnum); + Vals.insert(ExtensionEnum); continue; } if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-"))) - return O.error("Invalid extension list format: " + Token.str()); + return O.error("Invalid extension list format: " + Token); - StringRef ExtensionName = Token.substr(1); - auto NameValuePair = SPIRVExtensionMap.find(ExtensionName); + auto NameValuePair = SPIRVExtensionMap.find(Token.substr(1)); - if (NameValuePair == SPIRVExtensionMap.end()) + if (NameValuePair == SPIRVExtensionMap.cend()) return O.error("Unknown SPIR-V extension: " + Token.str()); + if (EnabledExtensions.count(NameValuePair->second)) + return O.error( + "Extension cannot be allowed and disallowed at the same time: " + + NameValuePair->first); - if (Token.starts_with("+")) { - EnabledExtensions.insert(NameValuePair->second); - } else if (EnabledExtensions.count(NameValuePair->second)) { - if (llvm::is_contained(Tokens, "+" + ExtensionName.str())) - return O.error( - "Extension cannot be allowed and disallowed at the same time: " + - ExtensionName.str()); - - EnabledExtensions.erase(NameValuePair->second); - } + Vals.erase(NameValuePair->second); } - Vals = std::move(EnabledExtensions); + Vals.insert(EnabledExtensions.cbegin(), EnabledExtensions.cend()); + return false; }