Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 30 additions & 33 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include "llvm/TargetParser/Triple.h"

#include <functional>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -26,7 +28,7 @@

using namespace llvm;

static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
static const std::map<StringRef, SPIRV::Extension::Extension>
SPIRVExtensionMap = {
{"SPV_EXT_shader_atomic_float_add",
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add},
Expand Down Expand Up @@ -181,57 +183,52 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
std::set<SPIRV::Extension::Extension> &Vals) {
SmallVector<StringRef, 10> Tokens;
ArgValue.split(Tokens, ",", -1, false);
llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
// We want to ensure that we handle "all" first, to ensure that any
// subsequent disablement actually behaves as expected i.e. given
// --spv-ext=all,-foo, we first enable all and then disable foo; this should
// be revisited and simplified.
if (LHS == "all")
return true;
if (RHS == "all")
return false;
return !(RHS < LHS);
});

std::set<SPIRV::Extension::Extension> EnabledExtensions;

for (const auto &Token : Tokens) {
if (Token == "all") {
for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
EnabledExtensions.insert(ExtensionEnum);
auto M = partition(Tokens, [](auto &&T) { return T.starts_with('+'); });

if (std::any_of(M, Tokens.end(), [](auto &&T) { return T == "all"; }))
copy(make_second_range(SPIRVExtensionMap), std::inserter(Vals, Vals.end()));

for (auto &&Token : make_range(Tokens.begin(), M)) {
StringRef ExtensionName = Token.substr(1);
auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);

if (NameValuePair == SPIRVExtensionMap.end())
return O.error("Unknown SPIR-V extension: " + Token.str());

EnabledExtensions.insert(NameValuePair->second);
}

for (auto &&Token : make_range(M, Tokens.end())) {
if (Token == "all")
continue;
}

if (Token.size() == 3 && Token.upper() == "KHR") {
for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
if (StringRef(ExtensionName).starts_with("SPV_KHR_"))
EnabledExtensions.insert(ExtensionEnum);
Vals.insert(ExtensionEnum);
continue;
}

if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-")))
return O.error("Invalid extension list format: " + Token.str());
return O.error("Invalid extension list format: " + Token);

StringRef ExtensionName = Token.substr(1);
auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
auto NameValuePair = SPIRVExtensionMap.find(Token.substr(1));

if (NameValuePair == SPIRVExtensionMap.end())
if (NameValuePair == SPIRVExtensionMap.cend())
return O.error("Unknown SPIR-V extension: " + Token.str());
if (EnabledExtensions.count(NameValuePair->second))
return O.error(
"Extension cannot be allowed and disallowed at the same time: " +
NameValuePair->first);

if (Token.starts_with("+")) {
EnabledExtensions.insert(NameValuePair->second);
} else if (EnabledExtensions.count(NameValuePair->second)) {
if (llvm::is_contained(Tokens, "+" + ExtensionName.str()))
return O.error(
"Extension cannot be allowed and disallowed at the same time: " +
ExtensionName.str());

EnabledExtensions.erase(NameValuePair->second);
}
Vals.erase(NameValuePair->second);
}

Vals = std::move(EnabledExtensions);
Vals.insert(EnabledExtensions.cbegin(), EnabledExtensions.cend());

return false;
}

Expand Down