-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[HLSL] Update Frontend to support version 1.2 of root signature #160616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…llvm-project into obj2yaml/root-signature-1.2
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang Author: None (joaosaffran) ChangesThis patch updates the frontend to support version 1.2 of root signatures, it adds parsing, metadata generation and a few tests. Full diff: https://github.com/llvm/llvm-project/pull/160616.diff 10 Files Affected:
diff --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
index a5cfeb34b2b51..1d7f7adbe076f 100644
--- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
@@ -65,6 +65,9 @@
#ifndef STATIC_BORDER_COLOR_ENUM
#define STATIC_BORDER_COLOR_ENUM(NAME, LIT) ENUM(NAME, LIT)
#endif
+#ifndef STATIC_SAMPLER_FLAG_ENUM
+#define STATIC_SAMPLER_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT)
+#endif
// General Tokens:
TOK(invalid, "invalid identifier")
@@ -228,6 +231,10 @@ STATIC_BORDER_COLOR_ENUM(OpaqueWhite, "STATIC_BORDER_COLOR_OPAQUE_WHITE")
STATIC_BORDER_COLOR_ENUM(OpaqueBlackUint, "STATIC_BORDER_COLOR_OPAQUE_BLACK_UINT")
STATIC_BORDER_COLOR_ENUM(OpaqueWhiteUint, "STATIC_BORDER_COLOR_OPAQUE_WHITE_UINT")
+// Root Descriptor Flag Enums:
+STATIC_SAMPLER_FLAG_ENUM(UintBorderColor, "UINT_BORDER_COLOR")
+STATIC_SAMPLER_FLAG_ENUM(NonNormalizedCoordinates, "NON_NORMALIZED_COORDINATES")
+
#undef STATIC_BORDER_COLOR_ENUM
#undef COMPARISON_FUNC_ENUM
#undef TEXTURE_ADDRESS_MODE_ENUM
@@ -237,6 +244,7 @@ STATIC_BORDER_COLOR_ENUM(OpaqueWhiteUint, "STATIC_BORDER_COLOR_OPAQUE_WHITE_UINT
#undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
#undef DESCRIPTOR_RANGE_FLAG_ENUM_ON
#undef ROOT_DESCRIPTOR_FLAG_ENUM
+#undef STATIC_SAMPLER_FLAG_ENUM
#undef ROOT_FLAG_ENUM
#undef DESCRIPTOR_RANGE_OFFSET_ENUM
#undef UNBOUNDED_ENUM
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index b06846fd83c09..8f91d7cd7b031 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -130,6 +130,7 @@ class RootSignatureParser {
std::optional<float> MaxLOD;
std::optional<uint32_t> Space;
std::optional<llvm::dxbc::ShaderVisibility> Visibility;
+ std::optional<llvm::dxbc::StaticSamplerFlags> Flags;
};
std::optional<ParsedStaticSamplerParams> parseStaticSamplerParams();
@@ -153,6 +154,8 @@ class RootSignatureParser {
parseRootDescriptorFlags(RootSignatureToken::Kind Context);
std::optional<llvm::dxbc::DescriptorRangeFlags>
parseDescriptorRangeFlags(RootSignatureToken::Kind Context);
+ std::optional<llvm::dxbc::StaticSamplerFlags>
+ parseStaticSamplerFlags(RootSignatureToken::Kind Context);
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
/// 32-bit integer
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 3b16efb1f1199..5677365688413 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -485,6 +485,9 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
if (Params->Visibility.has_value())
Sampler.Visibility = Params->Visibility.value();
+ if (Params->Flags.has_value())
+ Sampler.Flags = Params->Flags.value();
+
return Sampler;
}
@@ -926,6 +929,20 @@ RootSignatureParser::parseStaticSamplerParams() {
if (!Visibility.has_value())
return std::nullopt;
Params.Visibility = Visibility;
+ } else if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+ // `flags` `=` UINT_BORDER_COLOR
+ if (Params.Flags.has_value()) {
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
+ return std::nullopt;
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_equal))
+ return std::nullopt;
+
+ auto Flags = parseStaticSamplerFlags(TokenKind::kw_flags);
+ if (!Flags.has_value())
+ return std::nullopt;
+ Params.Flags = Flags;
} else {
consumeNextToken(); // let diagnostic be at the start of invalid token
reportDiag(diag::err_hlsl_invalid_token)
@@ -1255,6 +1272,50 @@ RootSignatureParser::parseDescriptorRangeFlags(TokenKind Context) {
return Flags;
}
+std::optional<llvm::dxbc::StaticSamplerFlags>
+RootSignatureParser::parseStaticSamplerFlags(TokenKind Context) {
+ assert(CurToken.TokKind == TokenKind::pu_equal &&
+ "Expects to only be invoked starting at given keyword");
+
+ // Handle the edge-case of '0' to specify no flags set
+ if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+ if (!verifyZeroFlag()) {
+ reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
+ return std::nullopt;
+ }
+ return llvm::dxbc::StaticSamplerFlags::None;
+ }
+
+ TokenKind Expected[] = {
+#define STATIC_SAMPLER_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ };
+
+ std::optional<llvm::dxbc::StaticSamplerFlags> Flags;
+
+ do {
+ if (tryConsumeExpectedToken(Expected)) {
+ switch (CurToken.TokKind) {
+#define STATIC_SAMPLER_FLAG_ENUM(NAME, LIT) \
+ case TokenKind::en_##NAME: \
+ Flags = maybeOrFlag<llvm::dxbc::StaticSamplerFlags>( \
+ Flags, llvm::dxbc::StaticSamplerFlags::NAME); \
+ break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ default:
+ llvm_unreachable("Switch for consumed enum token was not provided");
+ }
+ } else {
+ consumeNextToken(); // consume token to point at invalid token
+ reportDiag(diag::err_hlsl_invalid_token)
+ << /*value=*/1 << /*value of*/ Context;
+ return std::nullopt;
+ }
+ } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+ return Flags;
+}
+
std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(
diff --git a/clang/test/CodeGenHLSL/RootSignature.hlsl b/clang/test/CodeGenHLSL/RootSignature.hlsl
index bc40bdd79ce59..eaff3a9e73305 100644
--- a/clang/test/CodeGenHLSL/RootSignature.hlsl
+++ b/clang/test/CodeGenHLSL/RootSignature.hlsl
@@ -82,8 +82,8 @@ void RootDescriptorsEntry() {}
// checking minLOD, maxLOD
// CHECK-SAME: float -1.280000e+02, float 1.280000e+02,
-// checking register, space and visibility
-// CHECK-SAME: i32 42, i32 0, i32 0}
+// checking register, space, visibility and flag
+// CHECK-SAME: i32 42, i32 0, i32 0, i32 1}
#define SampleStaticSampler \
"StaticSampler(s42, " \
@@ -96,6 +96,7 @@ void RootDescriptorsEntry() {}
" borderColor = STATIC_BORDER_COLOR_OPAQUE_WHITE, " \
" minLOD = -128.f, maxLOD = 128.f, " \
" space = 0, visibility = SHADER_VISIBILITY_ALL, " \
+ " flags = UINT_BORDER_COLOR" \
")"
[shader("compute"), RootSignature(SampleStaticSampler)]
[numthreads(1,1,1)]
diff --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
index 01f8d4f97b092..82f19686167da 100644
--- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
@@ -226,6 +226,9 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
STATIC_BORDER_COLOR_OPAQUE_WHITE
STATIC_BORDER_COLOR_OPAQUE_BLACK_UINT
STATIC_BORDER_COLOR_OPAQUE_WHITE_UINT
+
+ UINT_BORDER_COLOR
+ NON_NORMALIZED_COORDINATES
)cc";
hlsl::RootSignatureLexer Lexer(Source);
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 9b9f5dd8a63bb..f7e9d2d32c3f4 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -263,7 +263,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseStaticSamplerTest) {
filter = FILTER_MAXIMUM_MIN_POINT_MAG_LINEAR_MIP_POINT,
maxLOD = 9000, addressU = TEXTURE_ADDRESS_MIRROR,
comparisonFunc = COMPARISON_NOT_EQUAL,
- borderColor = STATIC_BORDER_COLOR_OPAQUE_BLACK_UINT
+ borderColor = STATIC_BORDER_COLOR_OPAQUE_BLACK_UINT,
+ flags = 0
)
)cc";
@@ -336,6 +337,37 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseStaticSamplerTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidStaticSamplerFlagsTest) {
+ const llvm::StringLiteral Source = R"cc(
+ StaticSampler(s0, flags = UINT_BORDER_COLOR | NON_NORMALIZED_COORDINATES)
+ )cc";
+
+ auto Ctx = createMinimalASTContext();
+ StringLiteral *Signature = wrapSource(Ctx, Source);
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+
+ hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ auto Elements = Parser.getElements();
+ ASSERT_EQ(Elements.size(), 1u);
+
+ RootElement Elem = Elements[0].getElement();
+ ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
+ auto ValidStaticSamplerFlags =
+ llvm::dxbc::StaticSamplerFlags::NonNormalizedCoordinates |
+ llvm::dxbc::StaticSamplerFlags::UintBorderColor;
+ ASSERT_EQ(std::get<StaticSampler>(Elem).Flags, ValidStaticSamplerFlags);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidParseFloatsTest) {
const llvm::StringLiteral Source = R"cc(
StaticSampler(s0, mipLODBias = 0),
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 87777fddc9157..37224d8a94527 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -131,6 +131,7 @@ struct StaticSampler {
float MaxLOD = std::numeric_limits<float>::max();
uint32_t Space = 0;
dxbc::ShaderVisibility Visibility = dxbc::ShaderVisibility::All;
+ dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags::None;
};
/// Models RootElement : RootFlags | RootConstants | RootParam
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
index 92c62b83fadb0..f9129adb4a4f9 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
@@ -172,7 +172,7 @@ raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
<< ", borderColor = " << Sampler.BorderColor
<< ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD
<< ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility
- << ")";
+ << ", flags = " << Sampler.Flags << ")";
return OS;
}
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index f29f2c7602fc6..e2a1f242bfc8e 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -212,6 +212,7 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Sampler.Visibility))),
+ ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Flags))),
};
return MDNode::get(Ctx, Operands);
}
diff --git a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
index 1eb03f16527ec..abdd8a6a21112 100644
--- a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
+++ b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
@@ -266,7 +266,8 @@ TEST(HLSLRootSignatureTest, DefaultStaticSamplerDump) {
"minLOD = 0.000000e+00, "
"maxLOD = 3.402823e+38, "
"space = 0, "
- "visibility = All"
+ "visibility = All, "
+ "flags = 0x0"
")";
EXPECT_EQ(Out, Expected);
}
@@ -287,6 +288,7 @@ TEST(HLSLRootSignatureTest, DefinedStaticSamplerDump) {
Sampler.MaxLOD = 32.0f;
Sampler.Space = 7;
Sampler.Visibility = llvm::dxbc::ShaderVisibility::Domain;
+ Sampler.Flags = llvm::dxbc::StaticSamplerFlags::NonNormalizedCoordinates;
std::string Out;
llvm::raw_string_ostream OS(Out);
@@ -305,7 +307,8 @@ TEST(HLSLRootSignatureTest, DefinedStaticSamplerDump) {
"minLOD = 1.000000e+00, "
"maxLOD = 3.200000e+01, "
"space = 7, "
- "visibility = Domain"
+ "visibility = Domain, "
+ "flags = 0x2"
")";
EXPECT_EQ(Out, Expected);
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Basically LGTM, just missing some pretty printing and a test-case
"space = 0, " | ||
"visibility = All" | ||
"visibility = All, " | ||
"flags = 0x0" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These flags should be output with their named values (and None
instead of 0
) to match the behaviour of the other flag types
return std::nullopt; | ||
Params.Visibility = Visibility; | ||
} else if (tryConsumeExpectedToken(TokenKind::kw_flags)) { | ||
// `flags` `=` UINT_BORDER_COLOR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
// `flags` `=` UINT_BORDER_COLOR | |
// `flags` `=` STATIC_SAMPLE_FLAGS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also update https://github.com/llvm/llvm-project/blob/main/clang/test/AST/HLSL/RootSignatures-AST.hlsl to ensure the flags field is generated in the ast
You can test this locally with the following command:git-clang-format --diff origin/main HEAD --extensions cpp,h -- clang/include/clang/Basic/LangOptions.h clang/include/clang/Parse/ParseHLSLRootSignature.h clang/lib/AST/TextNodeDumper.cpp clang/lib/Driver/ToolChains/HLSL.cpp clang/lib/Parse/ParseHLSLRootSignature.cpp clang/unittests/Lex/LexHLSLRootSignatureTest.cpp clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp llvm/include/llvm/BinaryFormat/DXContainer.h llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h llvm/include/llvm/MC/DXContainerRootSignature.h llvm/include/llvm/Object/DXContainer.h llvm/include/llvm/ObjectYAML/DXContainerYAML.h llvm/lib/BinaryFormat/DXContainer.cpp llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp llvm/lib/MC/DXContainerRootSignature.cpp llvm/lib/Object/DXContainer.cpp llvm/lib/ObjectYAML/DXContainerEmitter.cpp llvm/lib/ObjectYAML/DXContainerYAML.cpp llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp llvm/unittests/Object/DXContainerTest.cpp llvm/unittests/ObjectYAML/DXContainerYAMLTest.cpp
View the diff from clang-format here.diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index c5d89106a..edee6a7de 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -56,7 +56,8 @@ struct RootDescriptor {
return;
}
- assert((Version == llvm::dxbc::RootSignatureVersion::V1_1 || Version == llvm::dxbc::RootSignatureVersion::V1_2) &&
+ assert((Version == llvm::dxbc::RootSignatureVersion::V1_1 ||
+ Version == llvm::dxbc::RootSignatureVersion::V1_2) &&
"Specified an invalid root signature version");
switch (Type) {
case dxil::ResourceClass::CBuffer:
@@ -100,7 +101,8 @@ struct DescriptorTableClause {
return;
}
- assert((Version == dxbc::RootSignatureVersion::V1_1 || Version == dxbc::RootSignatureVersion::V1_2) &&
+ assert((Version == dxbc::RootSignatureVersion::V1_1 ||
+ Version == dxbc::RootSignatureVersion::V1_2) &&
"Specified an invalid root signature version");
switch (Type) {
case dxil::ResourceClass::CBuffer:
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 411838b9b..b8bafd688 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -40,7 +40,8 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
if (Version == 1)
return Flags == FlagT::DataVolatile;
- assert((Version == 2 || Version == 3) && "Provided invalid root signature version");
+ assert((Version == 2 || Version == 3) &&
+ "Provided invalid root signature version");
// The data-specific flags are mutually exclusive.
FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
|
This patch updates the frontend to support version 1.2 of root signatures, it adds parsing, metadata generation and a few tests.