From 62ec0f744eae409ac92d9f83b89586a2f4b44810 Mon Sep 17 00:00:00 2001 From: Bryan Bernhart Date: Tue, 28 Jun 2022 11:26:08 -0700 Subject: [PATCH] Fix overflow in SetMaxResource* related functions. If the number of bits exceeds UINT32_MAX bits, SetMaxResource* would overflow. --- src/gpgmm/d3d12/CapsD3D12.cpp | 19 +++++++++++-------- src/gpgmm/utils/Limits.h | 7 +++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/gpgmm/d3d12/CapsD3D12.cpp b/src/gpgmm/d3d12/CapsD3D12.cpp index da91d713c..ea4f65dd0 100644 --- a/src/gpgmm/d3d12/CapsD3D12.cpp +++ b/src/gpgmm/d3d12/CapsD3D12.cpp @@ -15,6 +15,7 @@ #include "gpgmm/d3d12/CapsD3D12.h" #include "gpgmm/d3d12/ErrorD3D12.h" +#include "gpgmm/utils/Limits.h" #include @@ -25,12 +26,13 @@ namespace gpgmm::d3d12 { ReturnIfFailed( device->CheckFeatureSupport(D3D12_FEATURE_GPU_VIRTUAL_ADDRESS_SUPPORT, &feature, sizeof(D3D12_FEATURE_DATA_GPU_VIRTUAL_ADDRESS_SUPPORT))); - // Prevent possible overflow. - if (feature.MaxGPUVirtualAddressBitsPerResource == 0) { - return E_INVALIDARG; + // Check for overflow. + if (feature.MaxGPUVirtualAddressBitsPerResource == 0 || + feature.MaxGPUVirtualAddressBitsPerResource > GetNumOfBits()) { + return E_FAIL; } - *sizeOut = (1 << (feature.MaxGPUVirtualAddressBitsPerResource - 1)) - 1; + *sizeOut = (1ull << (feature.MaxGPUVirtualAddressBitsPerResource - 1)) - 1; return S_OK; } @@ -39,12 +41,13 @@ namespace gpgmm::d3d12 { ReturnIfFailed( device->CheckFeatureSupport(D3D12_FEATURE_GPU_VIRTUAL_ADDRESS_SUPPORT, &feature, sizeof(D3D12_FEATURE_DATA_GPU_VIRTUAL_ADDRESS_SUPPORT))); - // Prevent possible overflow. - if (feature.MaxGPUVirtualAddressBitsPerResource == 0) { - return E_INVALIDARG; + // Check for overflow. + if (feature.MaxGPUVirtualAddressBitsPerResource == 0 || + feature.MaxGPUVirtualAddressBitsPerResource > GetNumOfBits()) { + return E_FAIL; } - *sizeOut = (1 << (feature.MaxGPUVirtualAddressBitsPerResource - 1)) - 1; + *sizeOut = (1ull << (feature.MaxGPUVirtualAddressBitsPerResource - 1)) - 1; return S_OK; } diff --git a/src/gpgmm/utils/Limits.h b/src/gpgmm/utils/Limits.h index ae8882569..23011d394 100644 --- a/src/gpgmm/utils/Limits.h +++ b/src/gpgmm/utils/Limits.h @@ -15,6 +15,7 @@ #ifndef GPGMM_UTILS_LIMITS_H_ #define GPGMM_UTILS_LIMITS_H_ +#include // CHAR_BIT #include #include @@ -24,6 +25,12 @@ namespace gpgmm { static constexpr uint64_t kInvalidSize = std::numeric_limits::max(); static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); + template + constexpr size_t GetNumOfBits() { + static_assert(CHAR_BIT == 8, "Size of a char is not 8 bits."); + return sizeof(T) * CHAR_BIT; + } + } // namespace gpgmm #endif // GPGMM_UTILS_LIMITS_H_