-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[NFC][SPIRV] Re-work extension parsing #171826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…v_be_staging_11
…v_be_staging_11
…v_be_staging_11
|
@llvm/pr-subscribers-backend-spir-v Author: Alex Voicu (AlexVlx) ChangesThis changes the extension parsing mechanism underpinning Full diff: https://github.com/llvm/llvm-project/pull/171826.diff 1 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 42edad255ce82..04c54f9b0e53d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -17,7 +17,9 @@
#include "llvm/TargetParser/Triple.h"
#include <functional>
+#include <iterator>
#include <map>
+#include <set>
#include <string>
#include <utility>
#include <vector>
@@ -26,7 +28,7 @@
using namespace llvm;
-static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
+static const std::map<StringRef, SPIRV::Extension::Extension>
SPIRVExtensionMap = {
{"SPV_EXT_shader_atomic_float_add",
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add},
@@ -181,57 +183,52 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
std::set<SPIRV::Extension::Extension> &Vals) {
SmallVector<StringRef, 10> Tokens;
ArgValue.split(Tokens, ",", -1, false);
- llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
- // We want to ensure that we handle "all" first, to ensure that any
- // subsequent disablement actually behaves as expected i.e. given
- // --spv-ext=all,-foo, we first enable all and then disable foo; this should
- // be revisited and simplified.
- if (LHS == "all")
- return true;
- if (RHS == "all")
- return false;
- return !(RHS < LHS);
- });
std::set<SPIRV::Extension::Extension> EnabledExtensions;
- for (const auto &Token : Tokens) {
- if (Token == "all") {
- for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
- EnabledExtensions.insert(ExtensionEnum);
+ auto M = partition(Tokens, [](auto &&T) { return T.starts_with('+'); });
+
+ if (std::any_of(M, Tokens.end(), [](auto &&T) { return T == "all"; }))
+ copy(make_second_range(SPIRVExtensionMap), std::inserter(Vals, Vals.end()));
+
+ for (auto &&Token : make_range(Tokens.begin(), M)) {
+ StringRef ExtensionName = Token.substr(1);
+ auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
+ if (NameValuePair == SPIRVExtensionMap.end())
+ return O.error("Unknown SPIR-V extension: " + Token.str());
+
+ EnabledExtensions.insert(NameValuePair->second);
+ }
+
+ for (auto &&Token : make_range(M, Tokens.end())) {
+ if (Token == "all")
continue;
- }
if (Token.size() == 3 && Token.upper() == "KHR") {
for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
if (StringRef(ExtensionName).starts_with("SPV_KHR_"))
- EnabledExtensions.insert(ExtensionEnum);
+ Vals.insert(ExtensionEnum);
continue;
}
if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-")))
- return O.error("Invalid extension list format: " + Token.str());
+ return O.error("Invalid extension list format: " + Token);
- StringRef ExtensionName = Token.substr(1);
- auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
+ auto NameValuePair = SPIRVExtensionMap.find(Token.substr(1));
- if (NameValuePair == SPIRVExtensionMap.end())
+ if (NameValuePair == SPIRVExtensionMap.cend())
return O.error("Unknown SPIR-V extension: " + Token.str());
+ if (EnabledExtensions.count(NameValuePair->second))
+ return O.error(
+ "Extension cannot be allowed and disallowed at the same time: " +
+ NameValuePair->first);
- if (Token.starts_with("+")) {
- EnabledExtensions.insert(NameValuePair->second);
- } else if (EnabledExtensions.count(NameValuePair->second)) {
- if (llvm::is_contained(Tokens, "+" + ExtensionName.str()))
- return O.error(
- "Extension cannot be allowed and disallowed at the same time: " +
- ExtensionName.str());
-
- EnabledExtensions.erase(NameValuePair->second);
- }
+ Vals.erase(NameValuePair->second);
}
- Vals = std::move(EnabledExtensions);
+ Vals.insert(EnabledExtensions.cbegin(), EnabledExtensions.cend());
+
return false;
}
|
This changes the extension parsing mechanism underpinning
--spirv-extto be more explicit about what it is doing and not rely on a sort. More specifically, we partition extensions into enabled (prefixed with+) and others, and then individually handle the resulting ranges.