diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 2b375b93e3a48..a662b72c2a362 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1289,8 +1289,8 @@ bool SemaHLSL::handleRootSignatureElements( VerifyRegister(Loc, Descriptor->Reg.Number); VerifySpace(Loc, Descriptor->Space); - if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag( - Version, llvm::to_underlying(Descriptor->Flags))) + if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version, + Descriptor->Flags)) ReportFlagError(Loc); } else if (const auto *Constants = std::get_if(&Elem)) { diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h index 0b5646229e8b5..b9a08ce1ca14e 100644 --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -248,6 +248,12 @@ enum class StaticBorderColor : uint32_t { bool isValidBorderColor(uint32_t V); +bool isValidRootDesciptorFlags(uint32_t V); + +bool isValidDescriptorRangeFlags(uint32_t V); + +bool isValidStaticSamplerFlags(uint32_t V); + LLVM_ABI ArrayRef> getStaticBorderColors(); LLVM_ABI PartType parsePartType(StringRef S); diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 4dd18111b0c9d..7131980e9ff3a 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -28,12 +28,14 @@ LLVM_ABI bool verifyRootFlag(uint32_t Flags); LLVM_ABI bool verifyVersion(uint32_t Version); LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue); LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace); -LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal); +LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, + dxbc::RootDescriptorFlags Flags); LLVM_ABI bool verifyRangeType(uint32_t Type); LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, - dxbc::DescriptorRangeFlags FlagsVal); -LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber); + dxbc::DescriptorRangeFlags Flags); +LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, + dxbc::StaticSamplerFlags Flags); LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors); LLVM_ABI bool verifyMipLODBias(float MipLODBias); LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy); diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp index b334f86568acb..22f518067b318 100644 --- a/llvm/lib/BinaryFormat/DXContainer.cpp +++ b/llvm/lib/BinaryFormat/DXContainer.cpp @@ -82,6 +82,27 @@ bool llvm::dxbc::isValidBorderColor(uint32_t V) { return false; } +bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) { + using FlagT = dxbc::RootDescriptorFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) { + using FlagT = dxbc::DescriptorRangeFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) { + using FlagT = dxbc::StaticSamplerFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + dxbc::PartType dxbc::parsePartType(StringRef S) { #define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName) return StringSwitch(S) diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 7a0cf408968de..707f0c368e9d8 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature( "RegisterSpace", Descriptor.RegisterSpace)); if (RSD.Version > 1) { - if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) + bool IsValidFlag = + dxbc::isValidRootDesciptorFlags(Descriptor.Flags) && + hlsl::rootsig::verifyRootDescriptorFlag( + RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( @@ -676,9 +679,11 @@ Error MetadataParser::validateRootSignature( make_error>( "NumDescriptors", Range.NumDescriptors)); - if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, - dxbc::DescriptorRangeFlags(Range.Flags))) + bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) && + hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, + dxbc::DescriptorRangeFlags(Range.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( @@ -731,8 +736,11 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error>( "RegisterSpace", Sampler.RegisterSpace)); - - if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags)) + bool IsValidFlag = + dxbc::isValidStaticSamplerFlags(Sampler.Flags) && + hlsl::rootsig::verifyStaticSamplerFlags( + RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 8a2b03d9ede8b..30408dfda940d 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) { return !(RegisterSpace >= 0xFFFFFFF0); } -bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { +bool verifyRootDescriptorFlag(uint32_t Version, + dxbc::RootDescriptorFlags FlagsVal) { using FlagT = dxbc::RootDescriptorFlags; FlagT Flags = FlagT(FlagsVal); if (Version == 1) @@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, dxbc::DescriptorRangeFlags Flags) { using FlagT = dxbc::DescriptorRangeFlags; - const bool IsSampler = (Type == dxil::ResourceClass::Sampler); if (Version == 1) { @@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, return (Flags & ~Mask) == FlagT::None; } -bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) { - uint32_t LargestValue = llvm::to_underlying( - dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR); - if (FlagsNumber >= NextPowerOf2(LargestValue)) - return false; - - dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber); +bool verifyStaticSamplerFlags(uint32_t Version, + dxbc::StaticSamplerFlags Flags) { if (Version <= 2) return Flags == dxbc::StaticSamplerFlags::None; diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll new file mode 100644 index 0000000000000..c27c87ff057d5 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll @@ -0,0 +1,20 @@ +; RUN: not opt -passes='print' %s -S -o - 2>&1 | FileCheck %s + +target triple = "dxil-unknown-shadermodel6.0-compute" + +; CHECK: error: Invalid value for DescriptorFlag: 66666 +; CHECK-NOT: Root Signature Definitions + +define void @main() #0 { +entry: + ret void +} +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + + +!dx.rootsignatures = !{!2} ; list of function/root signature pairs +!2 = !{ ptr @main, !3, i32 2 } ; function, root signature +!3 = !{ !5 } ; list of root signature elements +!5 = !{ !"DescriptorTable", i32 0, !6, !7 } +!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 } +!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 } diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll new file mode 100644 index 0000000000000..898e197c7e0cc --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll @@ -0,0 +1,18 @@ +; RUN: not opt -passes='print' %s -S -o - 2>&1 | FileCheck %s + +target triple = "dxil-unknown-shadermodel6.0-compute" + + +; CHECK: error: Invalid value for RootDescriptorFlag: 666 +; CHECK-NOT: Root Signature Definitions +define void @main() #0 { +entry: + ret void +} +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + + +!dx.rootsignatures = !{!2} ; list of function/root signature pairs +!2 = !{ ptr @main, !3, i32 2 } ; function, root signature +!3 = !{ !5 } ; list of root signature elements +!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666 }