diff --git a/tools/clang/unittests/HLSLExec/LongVectorTestData.h b/tools/clang/unittests/HLSLExec/LongVectorTestData.h index de88991db1..2297e6eb3a 100644 --- a/tools/clang/unittests/HLSLExec/LongVectorTestData.h +++ b/tools/clang/unittests/HLSLExec/LongVectorTestData.h @@ -2,8 +2,10 @@ #define LONGVECTORTESTDATA_H #include + #include #include +#include #include #include diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp index ba1b215163..6d5fc81ea7 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.cpp +++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp @@ -1,6 +1,18 @@ #include "LongVectors.h" +#include "LongVectorTestData.h" + +#include "ShaderOpTest.h" +#include "dxc/Support/Global.h" + #include "HlslExecTestUtils.h" +#include "TableParameterHandler.h" + +#include #include +#include +#include +#include +#include namespace LongVector { @@ -27,6 +39,33 @@ getOpType(const OpTypeMetaData (&Values)[Length], std::abort(); } +template +OpTypeMetaData +getOpTypeMetaData(const OpTypeMetaData (&Values)[N], OP_TYPE OpType) { + for (size_t I = 0; I < N; ++I) { + if (Values[I].OpType == OpType) + return Values[I]; + } + + DXASSERT(false, "Missing OpType metadata"); + std::abort(); +} + +template +OpTypeMetaData getOpTypeMetaData(OP_TYPE OpType); + +#define OP_TYPE_META_DATA(TYPE, ARRAY) \ + template <> OpTypeMetaData getOpTypeMetaData(TYPE OpType) { \ + return getOpTypeMetaData(ARRAY, OpType); \ + } + +OP_TYPE_META_DATA(UnaryOpType, unaryOpTypeStringToOpMetaData); +OP_TYPE_META_DATA(AsTypeOpType, asTypeOpTypeStringToOpMetaData); +OP_TYPE_META_DATA(TrigonometricOpType, trigonometricOpTypeStringToOpMetaData); +OP_TYPE_META_DATA(UnaryMathOpType, unaryMathOpTypeStringToOpMetaData); +OP_TYPE_META_DATA(BinaryMathOpType, binaryMathOpTypeStringToOpMetaData); +OP_TYPE_META_DATA(TernaryMathOpType, ternaryMathOpTypeStringToOpMetaData); + // Helper to fill the test data from the shader buffer based on type. 
Convenient // to be used when copying HLSL*_t types so we can use the underlying type. template @@ -157,19 +196,6 @@ bool doVectorsMatch(const std::vector &ActualValues, return false; } -// A helper to create and fill the expected vector with computed values. -// Also helps us factor out the generic fill loop via a passed in ComputeFn. -template -VariantVector generateExpectedVector(size_t Count, ComputeFnT ComputeFn) { - - std::vector Values; - - for (size_t Index = 0; Index < Count; ++Index) - Values.push_back(ComputeFn(Index)); - - return Values; -} - template void logLongVector(const std::vector &Values, const std::wstring &Name) { WEX::Logging::Log::Comment( @@ -219,41 +245,6 @@ template std::string getHLSLTypeString() { return "UnknownType"; } -// These are helper arrays to be used with the TableParameterHandler that parses -// the LongVectorOpTable.xml file for us. -static TableParameter UnaryOpParameters[] = { - {L"DataType", TableParameter::STRING, true}, - {L"OpTypeEnum", TableParameter::STRING, true}, - {L"InputValueSetName1", TableParameter::STRING, false}, -}; - -static TableParameter BinaryOpParameters[] = { - {L"DataType", TableParameter::STRING, true}, - {L"OpTypeEnum", TableParameter::STRING, true}, - {L"InputValueSetName1", TableParameter::STRING, false}, - {L"InputValueSetName2", TableParameter::STRING, false}, - {L"ScalarInputFlags", TableParameter::STRING, false}, -}; - -static TableParameter TernaryOpParameters[] = { - {L"DataType", TableParameter::STRING, true}, - {L"OpTypeEnum", TableParameter::STRING, true}, - {L"InputValueSetName1", TableParameter::STRING, false}, - {L"InputValueSetName2", TableParameter::STRING, false}, - {L"InputValueSetName3", TableParameter::STRING, false}, - {L"ScalarInputFlags", TableParameter::STRING, false}, -}; - -static TableParameter AsTypeOpParameters[] = { - // DataTypeOut is determined at runtime based on the OpType. - // For example...AsUint has an output type of uint32_t. 
- {L"DataTypeIn", TableParameter::STRING, true}, - {L"OpTypeEnum", TableParameter::STRING, true}, - {L"InputValueSetName1", TableParameter::STRING, false}, - {L"InputValueSetName2", TableParameter::STRING, false}, - {L"ScalarInputFlags", TableParameter::STRING, false}, -}; - bool OpTest::classSetup() { // Run this only once. if (!Initialized) { @@ -304,265 +295,91 @@ bool OpTest::classSetup() { return true; } -TEST_F(OpTest, trigonometricOpTest) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - - const size_t TableSize = sizeof(UnaryOpParameters) / sizeof(TableParameter); - TableParameterHandler Handler(UnaryOpParameters, TableSize); - - std::wstring DataType(Handler.GetTableParamByName(L"DataType")->m_str); - std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str); - - auto OpTypeMD = getTrigonometricOpType(OpTypeString); - dispatchTrigonometricOpTestByDataType(OpTypeMD, DataType, Handler); -} - -TEST_F(OpTest, unaryOpTest) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - - const size_t TableSize = sizeof(UnaryOpParameters) / sizeof(TableParameter); - TableParameterHandler Handler(UnaryOpParameters, TableSize); - - std::wstring DataType(Handler.GetTableParamByName(L"DataType")->m_str); - std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str); +static uint16_t GetScalarInputFlags() { + using WEX::Common::String; + using WEX::TestExecution::TestData; - auto OpTypeMD = getUnaryOpType(OpTypeString); - dispatchTestByDataType(OpTypeMD, DataType, Handler); -} - -TEST_F(OpTest, asTypeOpTest) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - - const size_t TableSize = sizeof(AsTypeOpParameters) / sizeof(TableParameter); - TableParameterHandler Handler(AsTypeOpParameters, TableSize); - - std::wstring 
DataTypeIn(Handler.GetTableParamByName(L"DataTypeIn")->m_str); - std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str); + String ScalarInputFlagsString; + if (FAILED( + TestData::TryGetValue(L"ScalarInputFlags", ScalarInputFlagsString))) + return 0; - auto OpTypeMD = getAsTypeOpType(OpTypeString); - std::wstring ScalarInputFlags( - Handler.GetTableParamByName(L"ScalarInputFlags")->m_str); - if (!ScalarInputFlags.empty()) - VERIFY_IS_TRUE( - IsHexString(ScalarInputFlags.c_str(), &OpTypeMD.ScalarInputFlags), - L"ScalarInputFlags must be a hex string if provided."); + if (ScalarInputFlagsString.IsEmpty()) + return 0; - dispatchTestByDataType(OpTypeMD, DataTypeIn, Handler); + uint16_t ScalarInputFlags; + VERIFY_IS_TRUE(IsHexString(ScalarInputFlagsString, &ScalarInputFlags)); + return ScalarInputFlags; } -TEST_F(OpTest, unaryMathOpTest) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); +static WEX::Common::String getInputValueSetName(size_t Index) { + using WEX::Common::String; + using WEX::TestExecution::TestData; - const int TableSize = sizeof(UnaryOpParameters) / sizeof(TableParameter); - TableParameterHandler Handler(UnaryOpParameters, TableSize); + DXASSERT(Index >= 0 && Index <= 9, "Only single digit indices supported"); - std::wstring DataTypeIn(Handler.GetTableParamByName(L"DataType")->m_str); - std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str); + String ParameterName = L"InputValueSetName"; + ParameterName.Append((wchar_t)(L'1' + Index)); - auto OpTypeMD = getUnaryMathOpType(OpTypeString); - dispatchUnaryMathOpTestByDataType(OpTypeMD, DataTypeIn, Handler); -} - -TEST_F(OpTest, binaryMathOpTest) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - - using namespace WEX::Common; - - const size_t TableSize = sizeof(BinaryOpParameters) / sizeof(TableParameter); - 
TableParameterHandler Handler(BinaryOpParameters, TableSize); - - std::wstring DataType(Handler.GetTableParamByName(L"DataType")->m_str); - std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str); - - auto OpTypeMD = getBinaryMathOpType(OpTypeString); - - std::wstring ScalarInputFlags( - Handler.GetTableParamByName(L"ScalarInputFlags")->m_str); - if (!ScalarInputFlags.empty()) - VERIFY_IS_TRUE( - IsHexString(ScalarInputFlags.c_str(), &OpTypeMD.ScalarInputFlags), - L"ScalarInputFlags must be a hex string if provided."); + String ValueSetName; + if (FAILED(TestData::TryGetValue(ParameterName, ValueSetName))) { + String Name = L"DefaultInputValueSet"; + Name.Append((wchar_t)(L'1' + Index)); + return Name; + } - dispatchTestByDataType(OpTypeMD, DataType, Handler); + return ValueSetName; } -TEST_F(OpTest, ternaryMathOpTest) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); +struct TestConfig { + using String = WEX::Common::String; - const size_t TableSize = sizeof(TernaryOpParameters) / sizeof(TableParameter); - TableParameterHandler Handler(TernaryOpParameters, TableSize); + String DataType; + String OpTypeEnum; + String InputValueSetNames[3]; + uint16_t ScalarInputFlags = 0; + size_t LongVectorInputSize = 0; + bool VerboseLogging = false; - std::wstring DataType(Handler.GetTableParamByName(L"DataType")->m_str); - std::wstring OpTypeString(Handler.GetTableParamByName(L"OpTypeEnum")->m_str); + static std::optional Create(bool VerboseLogging) { + using WEX::TestExecution::RuntimeParameters; + using WEX::TestExecution::TestData; - auto OpTypeMD = getTernaryMathOpType(OpTypeString); + TestConfig Values; - std::wstring ScalarInputFlags( - Handler.GetTableParamByName(L"ScalarInputFlags")->m_str); - if (!ScalarInputFlags.empty()) - VERIFY_IS_TRUE( - IsHexString(ScalarInputFlags.c_str(), &OpTypeMD.ScalarInputFlags), - L"ScalarInputFlags must be a hex string if provided."); + if 
(FAILED(TestData::TryGetValue(L"DataType", Values.DataType)) && + FAILED(TestData::TryGetValue(L"DataTypeIn", Values.DataType))) { + LOG_ERROR_FMT_THROW(L"TestData missing 'DataType' or 'DataTypeIn'."); + return std::nullopt; + } - dispatchTestByDataType(OpTypeMD, DataType, Handler); -} + if (FAILED(TestData::TryGetValue(L"OpTypeEnum", Values.OpTypeEnum))) { + LOG_ERROR_FMT_THROW(L"TestData missing 'OpTypeEnum'."); + return std::nullopt; + } -// Generic dispatch that dispatchs all DataTypes recognized in these tests -template -void OpTest::dispatchTestByDataType(const OpTypeMetaData &OpTypeMd, - std::wstring DataType, - TableParameterHandler &Handler) { - switch (Hash_djb2a(DataType)) { - case Hash_djb2a(L"bool"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"int16"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"int32"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"int64"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"uint16"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"uint32"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"uint64"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"float16"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"float32"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"float64"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - default: - LOG_ERROR_FMT_THROW(L"Unrecognized DataType: %ls for OpType: %ls.", - DataType.c_str(), OpTypeMd.OpTypeString.c_str()); - } -} + for (size_t I = 0; I < std::size(Values.InputValueSetNames); ++I) + Values.InputValueSetNames[I] = getInputValueSetName(I); -// Unary math ops don't support HLSLBool_t. 
If we included a dispatcher for -// them by allowing the generic dispatchTestByDataType then we would get -// compile errors for a bunch of the templated std lib functions we call to -// compute unary math ops. This is easier and cleaner than guarding against in -// at that point. -void OpTest::dispatchUnaryMathOpTestByDataType( - const OpTypeMetaData &OpTypeMd, std::wstring DataType, - TableParameterHandler &Handler) { - - switch (Hash_djb2a(DataType)) { - case Hash_djb2a(L"int16"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"int32"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"int64"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"uint16"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"uint32"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"uint64"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"float16"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"float32"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - case Hash_djb2a(L"float64"): - dispatchTestByVectorLength(OpTypeMd, Handler); - return; - default: - LOG_ERROR_FMT_THROW(L"Invalid UnaryMathOpType DataType: %ls.", - DataType.c_str()); - } -} + Values.ScalarInputFlags = GetScalarInputFlags(); -// Specialized dispatch for Trigonometric op tests (tan, sin, etc) -// Trig ops only support fp16, fp32, and fp64. So we don't want to -// to generate code paths for any other types. Emit a runtime error via -// LOG_ERROR_FMT_THROW if someone accidentally trys to add support for -// a different DataType. 
-void OpTest::dispatchTrigonometricOpTestByDataType( - const OpTypeMetaData &OpTypeMd, std::wstring DataType, - TableParameterHandler &Handler) { - - if (DataType == L"float16") - dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float32") - dispatchTestByVectorLength(OpTypeMd, Handler); - else if (DataType == L"float64") - dispatchTestByVectorLength(OpTypeMd, Handler); - else - LOG_ERROR_FMT_THROW( - L"Trigonometric ops are only supported for floating point types. " - L"DataType: %ls is not recognized.", - DataType.c_str()); -} + RuntimeParameters::TryGetValue(L"LongVectorInputSize", + Values.LongVectorInputSize); -template -void OpTest::dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMd, - TableParameterHandler &Handler) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + Values.VerboseLogging = VerboseLogging; - auto TestConfig = makeTestConfig(OpTypeMd); - TestConfig->setVerboseLogging(VerboseLogging); - auto OperandCount = TestConfig->getNumOperands(); - - std::wstring Name = L"InputValueSetName"; - for (size_t I = 0; I < OperandCount; I++) { - auto NameI = Name + std::to_wstring(I + 1); - std::wstring InputValueSetName( - Handler.GetTableParamByName(NameI.c_str())->m_str); - if (!InputValueSetName.empty()) - TestConfig->setInputValueSetKey(InputValueSetName, I); + return Values; } +}; - // Manual override to test a specific vector size. Convenient for debugging - // issues. - size_t InputSizeToTestOverride = 0; - WEX::TestExecution::RuntimeParameters::TryGetValue(L"LongVectorInputSize", - InputSizeToTestOverride); - - std::vector InputVectorSizes; - if (InputSizeToTestOverride) - InputVectorSizes.push_back(InputSizeToTestOverride); - else - InputVectorSizes = {3, 4, 5, 16, 17, 35, 100, 256, 1024}; - - for (auto SizeToTest : InputVectorSizes) { - // We could create a new config for each test case with the new length, but - // that feels wasteful. 
Instead, we just update the length to test. - TestConfig->setLengthToTest(SizeToTest); - testBaseMethod(TestConfig); - } -} +template +using InputSets = std::array, ARITY>; -template -void OpTest::testBaseMethod(std::unique_ptr> &TestConfig) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); +template +std::optional> +runTest(const TestConfig &Config, OP_TYPE OpType, + const InputSets &Inputs, size_t ExpectedOutputSize, + std::string ExtraDefines) { CComPtr D3DDevice; if (!createDevice(&D3DDevice, ExecTestUtils::D3D_SHADER_MODEL_6_9, false)) { @@ -573,26 +390,26 @@ void OpTest::testBaseMethod(std::unique_ptr> &TestConfig) { WEX::Logging::Log::Comment( "Device does not support SM 6.9. Can't run these tests."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - return; + return std::nullopt; #endif } - TestInputs Inputs = TestInputs(); - TestConfig->fillInputs(Inputs); - - TestConfig->computeExpectedValues(Inputs); - - if (VerboseLogging) { - logLongVector(Inputs.InputVector1, L"InputVector1"); - if (Inputs.InputVector2.has_value()) - logLongVector(Inputs.InputVector2.value(), L"InputVector2"); - if (Inputs.InputVector3.has_value()) - logLongVector(Inputs.InputVector3.value(), L"InputVector3"); + if (Config.VerboseLogging) { + for (size_t I = 0; I < ARITY; ++I) { + std::wstring Name = L"InputVector"; + Name += (wchar_t)(L'1' + I); + logLongVector(Inputs[I], Name); + } } // We have to construct the string outside of the lambda. Otherwise it's // cleaned up when the lambda finishes executing but before the shader runs. - std::string CompilerOptionsString = TestConfig->getCompilerOptionsString(); + std::string CompilerOptionsString = + getCompilerOptionsString(OpType, Inputs[0].size(), + Config.ScalarInputFlags, + std::move(ExtraDefines)); + + dxc::SpecificDllLoader DxilDllLoader; // The name of the shader we want to use in ShaderOpArith.xml. 
Could also add // logic to set this name in ShaderOpArithTable.xml so we can use different @@ -609,8 +426,9 @@ void OpTest::testBaseMethod(std::unique_ptr> &TestConfig) { std::shared_ptr TestResult = st::RunShaderOpTest( D3DDevice, DxilDllLoader, TestXML, ShaderName, [&](LPCSTR Name, std::vector &ShaderData, st::ShaderOp *ShaderOp) { - hlsl_test::LogCommentFmt(L"RunShaderOpTest CallBack. Resource Name: %S", - Name); + if (Config.VerboseLogging) + hlsl_test::LogCommentFmt( + L"RunShaderOpTest CallBack. Resource Name: %S", Name); // This callback is called once for each resource defined for // "LongVectorOp" in ShaderOpArith.xml. All callbacks are fired for each @@ -618,7 +436,7 @@ void OpTest::testBaseMethod(std::unique_ptr> &TestConfig) { // when they run. // Process the callback for the OutputVector resource. - if (0 == _stricmp(Name, "OutputVector")) { + if (_stricmp(Name, "OutputVector") == 0) { // We only need to set the compiler options string once. So this is a // convenient place to do it. ShaderOp->Shaders.at(0).Arguments = CompilerOptionsString.c_str(); @@ -626,35 +444,30 @@ void OpTest::testBaseMethod(std::unique_ptr> &TestConfig) { return; } - // Process the callback for the InputVector1 resource. - if (0 == _stricmp(Name, "InputVector1")) { - fillShaderBufferFromLongVectorData(ShaderData, Inputs.InputVector1); - return; - } - - // Process the callback for the InputVector2 resource. - if (0 == _stricmp(Name, "InputVector2")) { - if (Inputs.InputVector2.has_value()) - fillShaderBufferFromLongVectorData(ShaderData, - Inputs.InputVector2.value()); - return; - } - - // Process the callback for the InputVector3 resource. 
- if (0 == _stricmp(Name, "InputVector3")) { - if (Inputs.InputVector3.has_value()) - fillShaderBufferFromLongVectorData(ShaderData, - Inputs.InputVector3.value()); - return; + // Process the callback for the InputVector[1-3] resources + for (size_t I = 0; I < 3; ++I) { + std::string BufferName = "InputVector"; + BufferName += (char)('1' + I); + if (_stricmp(Name, BufferName.c_str()) == 0) { + if (I < ARITY) + fillShaderBufferFromLongVectorData(ShaderData, Inputs[I]); + return; + } } LOG_ERROR_FMT_THROW( L"RunShaderOpTest CallBack. Unexpected Resource Name: %S", Name); }); - // The TestConfig object handles the logic for extracting the shader output - // based on the op type. - VERIFY_SUCCEEDED(TestConfig->verifyOutput(TestResult)); + // Extract the data from the shader result + MappedData ShaderOutData; + TestResult->Test->GetReadBackData("OutputVector", &ShaderOutData); + + std::vector OutData; + fillLongVectorDataFromShaderBuffer(ShaderOutData, OutData, + ExpectedOutputSize); + + return OutData; } // Helper to fill the shader buffer based on type. Convenient to be used when @@ -692,35 +505,33 @@ void fillShaderBufferFromLongVectorData(std::vector &ShaderBuffer, return; } -// Returns the compiler options string to be used for the shader compilation. -// Reference ShaderOpArith.xml and the 'LongVectorOp' shader source to see how -// the defines are used in the shader code. 
-template -std::string TestConfig::getCompilerOptionsString() const { +template +std::string getCompilerOptionsString(OP_TYPE OpType, size_t VectorSize, + uint16_t ScalarInputFlags, + std::string ExtraDefines) { + OpTypeMetaData OpTypeMetaData = getOpTypeMetaData(OpType); - std::stringstream CompilerOptions(""); + std::stringstream CompilerOptions; if (is16BitType()) CompilerOptions << " -enable-16bit-types"; - CompilerOptions << " -DTYPE=" << getHLSLInputTypeString(); - CompilerOptions << " -DNUM=" << LengthToTest; + CompilerOptions << " -DTYPE=" << getHLSLTypeString(); + CompilerOptions << " -DNUM=" << VectorSize; CompilerOptions << " -DOPERATOR="; - if (Operator) - CompilerOptions << *Operator; + if (OpTypeMetaData.Operator) + CompilerOptions << *OpTypeMetaData.Operator; CompilerOptions << " -DFUNC="; - if (Intrinsic) - CompilerOptions << *Intrinsic; + if (OpTypeMetaData.Intrinsic) + CompilerOptions << *OpTypeMetaData.Intrinsic; - // For most of the ops this string is std::nullopt. - if (SpecialDefines) - CompilerOptions << " " << *SpecialDefines; + CompilerOptions << " " << ExtraDefines; - CompilerOptions << " -DOUT_TYPE=" << getHLSLOutputTypeString(); + CompilerOptions << " -DOUT_TYPE=" << getHLSLTypeString(); - CompilerOptions << " -DBASIC_OP_TYPE=" << getBasicOpTypeHexString(); + CompilerOptions << " -DBASIC_OP_TYPE=0x" << std::hex << ARITY; CompilerOptions << " -DOPERAND_IS_SCALAR_FLAGS="; CompilerOptions << "0x" << std::hex << ScalarInputFlags; @@ -728,572 +539,954 @@ std::string TestConfig::getCompilerOptionsString() const { return CompilerOptions.str(); } -template -std::string TestConfig::getBasicOpTypeHexString() const { - - if (BasicOpType == BasicOpType_Unary) - return "0x1"; - if (BasicOpType == BasicOpType_Binary) - return "0x2"; - if (BasicOpType == BasicOpType_Ternary) - return "0x3"; - - LOG_ERROR_FMT_THROW(L"Invalid BasicOpType: %d", - static_cast(BasicOpType)); - return "0x0"; +// +// asFloat +// + +template float asFloat(T); +template <> 
float asFloat(float A) { return float(A); } +template <> float asFloat(int32_t A) { return bit_cast(A); } +template <> float asFloat(uint32_t A) { return bit_cast(A); } + +// +// asFloat16 +// +template HLSLHalf_t asFloat16(T); +template <> HLSLHalf_t asFloat16(HLSLHalf_t A) { + return HLSLHalf_t(A.Val); } +template <> HLSLHalf_t asFloat16(int16_t A) { + return HLSLHalf_t(bit_cast(A)); +} +template <> HLSLHalf_t asFloat16(uint16_t A) { + return HLSLHalf_t(bit_cast(A)); +} + +// +// asInt +// -template size_t TestConfig::getNumOperands() const { - if (BasicOpType == BasicOpType_Unary) - return 1; +template int32_t asInt(T); +template <> int32_t asInt(float A) { return bit_cast(A); } +template <> int32_t asInt(int32_t A) { return A; } +template <> int32_t asInt(uint32_t A) { return bit_cast(A); } - if (BasicOpType == BasicOpType_Binary) - return 2; +// +// asInt16 +// - if (BasicOpType == BasicOpType_Ternary) - return 3; +template int16_t asInt16(T); +template <> int16_t asInt16(HLSLHalf_t A) { return bit_cast(A.Val); } +template <> int16_t asInt16(int16_t A) { return A; } +template <> int16_t asInt16(uint16_t A) { return bit_cast(A); } - LOG_ERROR_FMT_THROW(L"Invalid BasicOpType: %d", - static_cast(BasicOpType)); - return 0; +// +// asUint16 +// + +template uint16_t asUint16(T); +template <> uint16_t asUint16(HLSLHalf_t A) { + return bit_cast(A.Val); +} +template <> uint16_t asUint16(uint16_t A) { return A; } +template <> uint16_t asUint16(int16_t A) { return bit_cast(A); } + +// +// asUint +// + +template unsigned int asUint(T); +template <> unsigned int asUint(unsigned int A) { return A; } +template <> unsigned int asUint(float A) { return bit_cast(A); } +template <> unsigned int asUint(int A) { return bit_cast(A); } + +// +// splitDouble +// + +static void splitDouble(const double A, uint32_t &LowBits, uint32_t &HighBits) { + uint64_t Bits = 0; + std::memcpy(&Bits, &A, sizeof(Bits)); + LowBits = static_cast(Bits & 0xFFFFFFFF); + HighBits = static_cast(Bits >> 
32); } -template -std::vector TestConfig::getInputValueSet(size_t ValueSetIndex) const { - if (BasicOpType == BasicOpType_Unary && ValueSetIndex == 0) - return getInputValueSetByKey(InputValueSetKeys[ValueSetIndex]); +// +// asDouble +// - if (BasicOpType == BasicOpType_Binary && ValueSetIndex <= 1) - return getInputValueSetByKey(InputValueSetKeys[ValueSetIndex]); +static double asDouble(const uint32_t LowBits, const uint32_t HighBits) { + uint64_t Bits = (static_cast(HighBits) << 32) | LowBits; + double Result; + std::memcpy(&Result, &Bits, sizeof(Result)); + return Result; +} - if (BasicOpType == BasicOpType_Ternary && ValueSetIndex <= 2) - return getInputValueSetByKey(InputValueSetKeys[ValueSetIndex]); +template struct TrigonometricOperation { + static T acos(T Val) { return std::acos(Val); } + static T asin(T Val) { return std::asin(Val); } + static T atan(T Val) { return std::atan(Val); } + static T cos(T Val) { return std::cos(Val); } + static T cosh(T Val) { return std::cosh(Val); } + static T sin(T Val) { return std::sin(Val); } + static T sinh(T Val) { return std::sinh(Val); } + static T tan(T Val) { return std::tan(Val); } + static T tanh(T Val) { return std::tanh(Val); } +}; - LOG_ERROR_FMT_THROW(L"Invalid ValueSetIndex: %d for OpType: %ls", - ValueSetIndex, OpTypeName.c_str()); - return std::vector(); +template const wchar_t *DataTypeName() { + static_assert(false && "Missing data type name"); } -template -std::string TestConfig::getHLSLOutputTypeString() const { - // std::visit allows us to dispatch a call to getHLSLTypeString() with the - // the current underlying element type of ExpectedVector. 
- return std::visit( - [](const auto &Vec) { - using ElementType = typename std::decay_t::value_type; - return getHLSLTypeString(); - }, - ExpectedVector); +#define DATA_TYPE_NAME(TYPE, NAME) \ + template <> const wchar_t *DataTypeName() { return NAME; } + +DATA_TYPE_NAME(HLSLBool_t, L"bool"); +DATA_TYPE_NAME(int16_t, L"int16"); +DATA_TYPE_NAME(int32_t, L"int32"); +DATA_TYPE_NAME(int64_t, L"int64"); +DATA_TYPE_NAME(uint16_t, L"uint16"); +DATA_TYPE_NAME(uint32_t, L"uint32"); +DATA_TYPE_NAME(uint64_t, L"uint64"); +DATA_TYPE_NAME(HLSLHalf_t, L"float16"); +DATA_TYPE_NAME(float, L"float32"); +DATA_TYPE_NAME(double, L"float64"); + +#undef DATA_TYPE_NAME + +template +std::vector buildTestInput(const wchar_t *InputValueSetName, + size_t SizeToTest) { + // TODO: remove the need to build up a RawValueSet, only to use that to build + // ValueSet. + std::vector RawValueSet = + getInputValueSetByKey(InputValueSetName); + + std::vector ValueSet; + ValueSet.reserve(SizeToTest); + for (size_t I = 0; I < SizeToTest; ++I) + ValueSet.push_back(RawValueSet[I % RawValueSet.size()]); + + return ValueSet; } -template -bool TestConfig::verifyOutput( - const std::shared_ptr &TestResult) { - - // std::visit allows us to dispatch a call to the private version of - // verifyOutput using a std::vector that matches the type currently held in - // ExpectedVector. This works because ExpectedVector is a std::variant of - // vector types, and the lambda receives the active type at runtime. It's - // important that the TestConfig instance has correctly assigned the expected - // output type to ExpectedVector. By default, this is std::vector, - // but ops like AsTypeOpType must override it. For example, - // AsTypeOpType_AsFloat16 sets ExpectedVector to std::vector. 
- return std::visit( - [this, &TestResult](const auto &Vec) { - using ElementType = typename std::decay_t::value_type; - return this->verifyOutput(TestResult, Vec); - }, - ExpectedVector); +template +InputSets buildTestInputs(const TestConfig &Config, + size_t SizeToTest) { + InputSets Inputs; + for (size_t I = 0; I < ARITY; ++I) { + + uint16_t OperandScalarFlag = 1 << I; + bool IsOperandScalar = Config.ScalarInputFlags & OperandScalarFlag; + + if (Config.InputValueSetNames[I].IsEmpty()) { + LOG_ERROR_FMT_THROW( + L"Expected parameter 'InputValueSetName%d' not found.", I + 1); + continue; + } + + Inputs[I] = buildTestInput(Config.InputValueSetNames[I], + IsOperandScalar ? 1 : SizeToTest); + } + + return Inputs; } -// Private version of verifyOutput. Called by the public version of verifyOutput -// which resolves OutT based on the ExpectedVector type. Most intrinsics will -// have an OutT that matches the input type being tested (which is T). But some, -// such as the 'AsType' ops, i.e 'AsUint' have an OutT that doesn't match T. 
-template // Primary template from TestConfig -template // Secondary template for verifyOutput -bool TestConfig::verifyOutput( - const std::shared_ptr &TestResult, - const std::vector &ExpectedVector) { +struct ValidationConfig { + float Tolerance = 0.0f; + ValidationType Type = ValidationType_Epsilon; - WEX::Logging::Log::Comment(WEX::Common::String().Format( - L"verifyOutput with OpType: %ls ExpectedVector<%S>", OpTypeName.c_str(), - typeid(OutT).name())); + static ValidationConfig Epsilon(float Tolerance) { + return ValidationConfig{Tolerance, ValidationType_Epsilon}; + } - DXASSERT(!ExpectedVector.empty(), - "Programmer Error: ExpectedVector is empty."); + static ValidationConfig Ulp(float Tolerance) { + return ValidationConfig{Tolerance, ValidationType_Ulp}; + } +}; - MappedData ShaderOutData; - TestResult->Test->GetReadBackData("OutputVector", &ShaderOutData); +template +void runAndVerify(const TestConfig &Config, OP_TYPE OpType, + const InputSets &Inputs, + const std::vector &Expected, + std::string ExtraDefines, + const ValidationConfig &ValidationConfig) { - const size_t OutputVectorSize = ExpectedVector.size(); + std::optional> Actual = + runTest(Config, OpType, Inputs, Expected.size(), ExtraDefines); - std::vector ActualValues; - fillLongVectorDataFromShaderBuffer(ShaderOutData, ActualValues, - OutputVectorSize); + // If the test didn't run, don't verify anything. + if (!Actual) + return; - return doVectorsMatch(ActualValues, ExpectedVector, Tolerance, ValidationType, - VerboseLogging); + VERIFY_IS_TRUE(doVectorsMatch(*Actual, Expected, ValidationConfig.Tolerance, + ValidationConfig.Type, Config.VerboseLogging)); } -// Generic fillInput. Fill the inputs based on the OpType and the -// ScalarInputFlags. 
-template -void TestConfig::fillInputs(TestInputs &Inputs) const { +template +void dispatchUnaryTest(const TestConfig &Config, + const ValidationConfig &ValidationConfig, OP_TYPE OpType, + size_t VectorSize, OUT_TYPE (*Calc)(T), + std::string ExtraDefines) { - auto fillVecFromValueSet = [this](std::vector &Vec, size_t ValueSetIndex, - size_t Count) { - std::vector ValueSet = getInputValueSet(ValueSetIndex); - for (size_t Index = 0; Index < Count; Index++) - Vec.push_back(ValueSet[Index % ValueSet.size()]); - }; - - size_t ValueSetIndex = 0; + InputSets Inputs = buildTestInputs(Config, VectorSize); - fillVecFromValueSet(Inputs.InputVector1, ValueSetIndex++, LengthToTest); + std::vector Expected; + Expected.reserve(Inputs[0].size()); - if (BasicOpType == BasicOpType_Unary) - return; + for (size_t I = 0; I < Inputs[0].size(); ++I) + Expected.push_back(Calc(Inputs[0][I])); - DXASSERT_NOMSG(BasicOpType == BasicOpType_Binary || - BasicOpType == BasicOpType_Ternary); + runAndVerify(Config, OpType, Inputs, Expected, ExtraDefines, + ValidationConfig); +} - const size_t Input2Length = - (ScalarInputFlags & SCALAR_INPUT_FLAGS_OPERAND_2_IS_SCALAR) - ? 1 - : LengthToTest; - Inputs.InputVector2 = std::vector(); - fillVecFromValueSet(*Inputs.InputVector2, ValueSetIndex++, Input2Length); +template +void dispatchBinaryTest(const TestConfig &Config, + const ValidationConfig &ValidationConfig, + OP_TYPE OpType, size_t VectorSize, + OUT_TYPE (*Calc)(T, T)) { + InputSets Inputs = buildTestInputs(Config, VectorSize); - if (BasicOpType == BasicOpType_Binary) - return; + std::vector Expected; + Expected.reserve(Inputs[0].size()); - DXASSERT_NOMSG(BasicOpType == BasicOpType_Ternary); + for (size_t I = 0; I < Inputs[0].size(); ++I) { + size_t Index1 = (Config.ScalarInputFlags & (1 << 1)) ? 0 : I; + Expected.push_back(Calc(Inputs[0][I], Inputs[1][Index1])); + } - const size_t Input3Length = - (ScalarInputFlags & SCALAR_INPUT_FLAGS_OPERAND_3_IS_SCALAR) - ? 
1 - : LengthToTest; - Inputs.InputVector3 = std::vector(); - fillVecFromValueSet(*Inputs.InputVector3, ValueSetIndex++, Input3Length); + runAndVerify(Config, OpType, Inputs, Expected, "", ValidationConfig); } -template -AsTypeOpTestConfig::AsTypeOpTestConfig( - const OpTypeMetaData &OpTypeMd) - : TestConfig(OpTypeMd), OpType(OpTypeMd.OpType) { +// +// TrigonometricTest +// - BasicOpType = BasicOpType_Unary; +template +void dispatchTrigonometricTest(const TestConfig &Config, + ValidationConfig ValidationConfig, + TrigonometricOpType OpType, size_t VectorSize) { +#define DISPATCH(OP, NAME) \ + case OP: \ + return dispatchUnaryTest(Config, ValidationConfig, OP, VectorSize, \ + TrigonometricOperation::NAME, "") switch (OpType) { - case AsTypeOpType_AsFloat16: { - auto ComputeFunc = [this](const T &Val) { return asFloat16(Val); }; - InitUnaryOpValueComputer(ComputeFunc); - break; - } - case AsTypeOpType_AsFloat: { - auto ComputeFunc = [this](const T &Val) { return asFloat(Val); }; - InitUnaryOpValueComputer(ComputeFunc); - break; - } - case AsTypeOpType_AsInt: { - auto ComputeFunc = [this](const T &Val) { return asInt(Val); }; - InitUnaryOpValueComputer(ComputeFunc); - break; - } - case AsTypeOpType_AsInt16: { - auto ComputeFunc = [this](const T &Val) { return asInt16(Val); }; - InitUnaryOpValueComputer(ComputeFunc); - break; - } - case AsTypeOpType_AsUint: { - auto ComputeFunc = [this](const T &Val) { return asUint(Val); }; - InitUnaryOpValueComputer(ComputeFunc); - break; - } - case AsTypeOpType_AsUint_SplitDouble: { - SpecialDefines = " -DFUNC_ASUINT_SPLITDOUBLE=1"; - break; - } - case AsTypeOpType_AsUint16: { - auto ComputeFunc = [this](const T &Val) { return asUint16(Val); }; - InitUnaryOpValueComputer(ComputeFunc); - break; - } - case AsTypeOpType_AsDouble: { - BasicOpType = BasicOpType_Binary; - auto ComputeFunc = [this](const T &A, const T &B) { - return asDouble(A, B); - }; - InitBinaryOpValueComputer(ComputeFunc); + DISPATCH(TrigonometricOpType_Acos, acos); + 
DISPATCH(TrigonometricOpType_Asin, asin); + DISPATCH(TrigonometricOpType_Atan, atan); + DISPATCH(TrigonometricOpType_Cos, cos); + DISPATCH(TrigonometricOpType_Cosh, cosh); + DISPATCH(TrigonometricOpType_Sin, sin); + DISPATCH(TrigonometricOpType_Sinh, sinh); + DISPATCH(TrigonometricOpType_Tan, tan); + DISPATCH(TrigonometricOpType_Tanh, tanh); + case TrigonometricOpType_EnumValueCount: break; } - default: - LOG_ERROR_FMT_THROW(L"Unsupported AsTypeOpType: %ls", OpTypeName.c_str()); - } + +#undef DISPATCH + + LOG_ERROR_FMT_THROW(L"Unexpected TrigonometricOpType: %d.", OpType); } -template -void TestConfig::computeExpectedValues(const TestInputs &Inputs) { +void dispatchTestByOpTypeAndVectorSize(const TestConfig &Config, + TrigonometricOpType OpType, + size_t VectorSize) { - // Either a ExpectedValueComputer member should be set or the deriving class - // should have overridden computeExpectedValues. - if (!ExpectedValueComputer) - LOG_ERROR_FMT_THROW( - L"Programmer Error: ExpectedValueComputer is not set for OpType: %ls.", - OpTypeName.c_str()); + // All trigonometric ops are floating point types. + // These trig functions are defined to have a max absolute error of 0.0008 + // as per the D3D functional specs. 
An example with this spec for sin and + // cos is available here: + // https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#22.10.20 + + if (Config.DataType == DataTypeName()) + return dispatchTrigonometricTest( + Config, ValidationConfig::Epsilon(0.0010f), OpType, VectorSize); - ExpectedVector = ExpectedValueComputer->computeExpectedValues(Inputs); + if (Config.DataType == DataTypeName()) + return dispatchTrigonometricTest( + Config, ValidationConfig::Epsilon(0.0008f), OpType, VectorSize); + + LOG_ERROR_FMT_THROW( + L"DataType '%s' not supported for trigonometric operations.", + (const wchar_t *)Config.DataType); } -template -void AsTypeOpTestConfig::computeExpectedValues(const TestInputs &Inputs) { +// +// AsTypeOp +// - if (BasicOpType != BasicOpType_Unary && BasicOpType != BasicOpType_Binary) - LOG_ERROR_FMT_THROW(L"Programmer Error: computeExpectedValue called with " - L"unexpected BasicOpType: %d", - static_cast(BasicOpType)); +void dispatchAsUintSplitDoubleTest(const TestConfig &Config, + size_t VectorSize) { - if (ExpectedValueComputer) - ExpectedVector = ExpectedValueComputer->computeExpectedValues(Inputs); - else - // Only SplitDouble has special handling. All other ops will have an - // ExpectedValueComputer set. - computeExpectedValues_SplitDouble(Inputs.InputVector1); -} + InputSets Inputs = buildTestInputs(Config, VectorSize); -template -void AsTypeOpTestConfig::computeExpectedValues_SplitDouble( - const std::vector &InputVector) { - - DXASSERT_NOMSG(OpType == AsTypeOpType_AsUint_SplitDouble); - - // SplitDouble is a special case. We fill the first half of the expected - // vector with the expected low bits of each input double and the second - // half with the high bits of each input double. Doing things this way - // helps keep the rest of the generic logic in the LongVector test code - // simple. 
- std::vector Values; - Values.resize(InputVector.size() * 2); - - uint32_t LowBits, HighBits; - const size_t InputSize = InputVector.size(); - - for (size_t Index = 0; Index < InputSize; ++Index) { - splitDouble(InputVector[Index], LowBits, HighBits); - Values[Index] = LowBits; - Values[Index + InputSize] = HighBits; + std::vector Expected; + Expected.resize(Inputs[0].size() * 2); + + for (size_t I = 0; I < Inputs[0].size(); ++I) { + uint32_t Low, High; + splitDouble(Inputs[0][I], Low, High); + Expected[I] = Low; + Expected[I + Inputs[0].size()] = High; } - ExpectedVector = std::move(Values); + ValidationConfig ValidationConfig{}; + runAndVerify(Config, AsTypeOpType_AsUint_SplitDouble, Inputs, Expected, + " -DFUNC_ASUINT_SPLITDOUBLE=1", ValidationConfig); } -template -TrigonometricOpTestConfig::TrigonometricOpTestConfig( - const OpTypeMetaData &OpTypeMd) - : TestConfig(OpTypeMd), OpType(OpTypeMd.OpType) { +void dispatchTestByOpTypeAndVectorSize(const TestConfig &Config, + AsTypeOpType OpType, size_t VectorSize) { - static_assert( - isFloatingPointType(), - "Trigonometric ops are only supported for floating point types."); + // Different AsType* operations are supported for different data types, so we + dispatch on operation first. - BasicOpType = BasicOpType_Unary; +#define DISPATCH(TYPE, FN) \ + if (Config.DataType == DataTypeName()) \ + return dispatchUnaryTest(Config, ValidationConfig{}, OpType, \ + VectorSize, FN, "") - // All trigonometric ops are floating point types. - // These trig functions are defined to have a max absolute error of 0.0008 - // as per the D3D functional specs. 
An example with this spec for sin and - // cos is available here: - // https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#22.10.20 - ValidationType = ValidationType_Epsilon; - if (std::is_same_v) - Tolerance = 0.0010f; - else if (std::is_same_v) - Tolerance = 0.0008f; + switch (OpType) { + case AsTypeOpType_AsFloat: + DISPATCH(float, asFloat); + DISPATCH(int32_t, asFloat); + DISPATCH(uint32_t, asFloat); + break; - auto ComputeFunc = [this](const T &A) { - return this->computeExpectedValue(A); - }; - InitUnaryOpValueComputer(ComputeFunc); -} + case AsTypeOpType_AsInt: + DISPATCH(float, asInt); + DISPATCH(int32_t, asInt); + DISPATCH(uint32_t, asInt); + break; -// computeExpectedValue Trigonometric -template -T TrigonometricOpTestConfig::computeExpectedValue(const T &A) const { + case AsTypeOpType_AsUint: + DISPATCH(int32_t, asUint); + DISPATCH(uint32_t, asUint); + break; - switch (OpType) { - case TrigonometricOpType_Acos: - return std::acos(A); - case TrigonometricOpType_Asin: - return std::asin(A); - case TrigonometricOpType_Atan: - return std::atan(A); - case TrigonometricOpType_Cos: - return std::cos(A); - case TrigonometricOpType_Cosh: - return std::cosh(A); - case TrigonometricOpType_Sin: - return std::sin(A); - case TrigonometricOpType_Sinh: - return std::sinh(A); - case TrigonometricOpType_Tan: - return std::tan(A); - case TrigonometricOpType_Tanh: - return std::tanh(A); - default: - LOG_ERROR_FMT_THROW(L"Unknown TrigonometricOpType: %ls", - OpTypeName.c_str()); - return T(); + case AsTypeOpType_AsFloat16: + DISPATCH(HLSLHalf_t, asFloat16); + DISPATCH(int16_t, asFloat16); + DISPATCH(uint16_t, asFloat16); + break; + + case AsTypeOpType_AsInt16: + DISPATCH(HLSLHalf_t, asInt16); + DISPATCH(int16_t, asInt16); + DISPATCH(uint16_t, asInt16); + break; + + case AsTypeOpType_AsUint16: + DISPATCH(HLSLHalf_t, asUint16); + DISPATCH(int16_t, asUint16); + DISPATCH(uint16_t, asUint16); + break; + + case AsTypeOpType_AsUint_SplitDouble: + if 
(Config.DataType == DataTypeName()) + return dispatchAsUintSplitDoubleTest(Config, VectorSize); + break; + + case AsTypeOpType_AsDouble: + if (Config.DataType == DataTypeName()) + return dispatchBinaryTest(Config, ValidationConfig{}, + AsTypeOpType_AsDouble, VectorSize, + asDouble); + break; + + case AsTypeOpType_EnumValueCount: + break; } + +#undef DISPATCH + + LOG_ERROR_FMT_THROW(L"DataType '%s' not supported for AsTypeOp '%s'", + (const wchar_t *)Config.DataType, + (const wchar_t *)Config.OpTypeEnum); } -template -UnaryOpTestConfig::UnaryOpTestConfig( - const OpTypeMetaData &OpTypeMd) - : TestConfig(OpTypeMd), OpType(OpTypeMd.OpType) { +// +// UnaryOp +// - BasicOpType = BasicOpType_Unary; +template T Initialize(T V) { return V; } + +void dispatchTestByOpTypeAndVectorSize(const TestConfig &Config, + UnaryOpType OpType, size_t VectorSize) { +#define DISPATCH(TYPE, FUNC, EXTRA_DEFINES) \ + if (Config.DataType == DataTypeName()) \ + return dispatchUnaryTest(Config, ValidationConfig{}, OpType, VectorSize, \ + FUNC, EXTRA_DEFINES) + +#define DISPATCH_INITIALIZE(TYPE) \ + DISPATCH(TYPE, Initialize, " -DFUNC_INITIALIZE=1") switch (OpType) { case UnaryOpType_Initialize: - SpecialDefines = " -DFUNC_INITIALIZE=1"; + DISPATCH_INITIALIZE(HLSLBool_t); + DISPATCH_INITIALIZE(int16_t); + DISPATCH_INITIALIZE(int32_t); + DISPATCH_INITIALIZE(int64_t); + DISPATCH_INITIALIZE(uint16_t); + DISPATCH_INITIALIZE(uint32_t); + DISPATCH_INITIALIZE(uint64_t); + DISPATCH_INITIALIZE(HLSLHalf_t); + DISPATCH_INITIALIZE(float); + DISPATCH_INITIALIZE(double); + break; + case UnaryOpType_EnumValueCount: break; - default: - LOG_ERROR_FMT_THROW(L"Unsupported UnaryOpType: %ls", OpTypeName.c_str()); } - auto ComputeFunc = [this](const T &A) { - return this->computeExpectedValue(A); - }; - InitUnaryOpValueComputer(ComputeFunc); -} +#undef DISPATCH_INITIALIZE +#undef DISPATCH -template -T UnaryOpTestConfig::computeExpectedValue(const T &A) const { - if (OpType != UnaryOpType_Initialize) { - 
LOG_ERROR_FMT_THROW(L"computeExpectedValue(const T &A, " - L"UnaryOpType OpType) called on an " - L"unrecognized unary op: %ls", - OpTypeName.c_str()); - return T(); - } - - return T(A); + LOG_ERROR_FMT_THROW(L"DataType '%s' not supported for UnaryOpType '%s'", + (const wchar_t *)Config.DataType, + (const wchar_t *)Config.OpTypeEnum); } -template -UnaryMathOpTestConfig::UnaryMathOpTestConfig( - const OpTypeMetaData &OpTypeMd) - : TestConfig(OpTypeMd), OpType(OpTypeMd.OpType) { +// +// UnaryMathOp +// - BasicOpType = BasicOpType_Unary; +template +void dispatchUnaryMathOpTest(const TestConfig &Config, UnaryMathOpType OpType, + size_t VectorSize, OUT_TYPE (*Calc)(T)) { + + ValidationConfig ValidationConfig; if (isFloatingPointType()) { - Tolerance = 1; - ValidationType = ValidationType_Ulp; + ValidationConfig = ValidationConfig::Ulp(1.0); } - switch (OpType) { - case UnaryMathOpType_Sign: { - // Sign has overridden special logic. - auto ComputeFunc = [this](const T &A) { return this->sign(A); }; - InitUnaryOpValueComputer(ComputeFunc); - break; - } - case UnaryMathOpType_Frexp: - // Don't initialize a ValueComputer, Frexp has special logic for handling - // its output - SpecialDefines = " -DFUNC_FREXP=1"; - break; - default: { - auto ComputeFunc = [this](const T &A) { - return this->computeExpectedValue(A); - }; - InitUnaryOpValueComputer(ComputeFunc); + dispatchUnaryTest(Config, ValidationConfig, OpType, VectorSize, Calc, ""); +} + +template struct UnaryMathOps { + static T Abs(T V) { + if constexpr (std::is_unsigned_v) + return V; + else + return static_cast(std::abs(V)); } + + static int32_t Sign(T V) { + if (V > static_cast(0)) + return 1; + + if (V < static_cast(0)) + return -1; + + return 0; } -} -template -void UnaryMathOpTestConfig::computeExpectedValues( - const TestInputs &Inputs) { + static T Ceil(T V) { return std::ceil(V); } + static T Floor(T V) { return std::floor(V); } + static T Trunc(T V) { return std::trunc(V); } + static T Round(T V) { return 
std::round(V); } + static T Frac(T V) { return V - static_cast(std::floor(V)); } + static T Sqrt(T V) { return std::sqrt(V); } - // Base case - if (ExpectedValueComputer) { - ExpectedVector = ExpectedValueComputer->computeExpectedValues(Inputs); - return; + static T Rsqrt(T V) { + return static_cast(1.0) / static_cast(std::sqrt(V)); } - computeExpectedValues_Frexp(Inputs.InputVector1); -} + static T Exp(T V) { return std::exp(V); } + static T Exp2(T V) { return std::exp2(V); } + static T Log(T V) { return std::log(V); } + static T Log2(T V) { return std::log2(V); } + static T Log10(T V) { return std::log10(V); } + static T Rcp(T V) { return static_cast(1.0) / V; } +}; -// Frexp has a return value as well as an output paramater. So we handle it -// with special logic. Frexp is only supported for fp32 values. -template -void UnaryMathOpTestConfig::computeExpectedValues_Frexp( - const std::vector &InputVector) { +void dispatchFrexpTest(const TestConfig &Config, size_t VectorSize) { + // Frexp has a return value as well as an output paramater. So we handle it + // with special logic. Frexp is only supported for fp32 values. - DXASSERT_NOMSG(OpType == UnaryMathOpType_Frexp); + InputSets Inputs = buildTestInputs(Config, VectorSize); - std::vector Values; + std::vector Expected; // Expected values size is doubled. In the first half we store the Mantissas // and in the second half we store the Exponents. This way we can leverage the // existing logic which verify expected values in a single vector. We just // need to make sure that we organize the output in the same way in the shader // and when we read it back. 
- const size_t InputSize = InputVector.size(); - Values.resize(InputSize * 2); - float Exp = 0; - float Man = 0; - - for (size_t Index = 0; Index < InputSize; ++Index) { - Man = frexp(InputVector[Index], &Exp); - Values[Index] = Man; - Values[Index + InputSize] = Exp; + + Expected.resize(VectorSize * 2); + + for (size_t I = 0; I < VectorSize; ++I) { + int Exp = 0; + float Man = std::frexp(Inputs[0][I], &Exp); + + // std::frexp returns a signed mantissa. But the HLSL implementation returns + // an unsigned mantissa. + Man = std::abs(Man); + + Expected[I] = Man; + + // std::frexp returns the exponent as an int, but HLSL stores it as a float. + // However, the HLSL exponent's fractional component is always 0, so the + // conversion between float and int is safe. + Expected[I + VectorSize] = static_cast(Exp); } - ExpectedVector = std::move(Values); + runAndVerify(Config, UnaryMathOpType_Frexp, Inputs, Expected, + " -DFUNC_FREXP=1", ValidationConfig{}); } -template -T UnaryMathOpTestConfig::computeExpectedValue(const T &A) const { +void dispatchTestByOpTypeAndVectorSize(const TestConfig &Config, + UnaryMathOpType OpType, + size_t VectorSize) { +#define DISPATCH(TYPE, FUNC) \ + if (Config.DataType == DataTypeName()) \ + return dispatchUnaryMathOpTest(Config, OpType, VectorSize, \ + UnaryMathOps::FUNC) - if constexpr (std::is_integral::value) { - // Abs and Sign are the only UnaryMathOps thats support integral types. - // Sign always returns int32 values, so its handled elsewhere. 
- DXASSERT_NOMSG(OpType == UnaryMathOpType_Abs); - return abs(A); - } + switch (OpType) { + case UnaryMathOpType_Abs: + DISPATCH(HLSLHalf_t, Abs); + DISPATCH(float, Abs); + DISPATCH(double, Abs); + DISPATCH(int16_t, Abs); + DISPATCH(int32_t, Abs); + DISPATCH(int64_t, Abs); + DISPATCH(uint16_t, Abs); + DISPATCH(uint32_t, Abs); + DISPATCH(uint64_t, Abs); + break; - if constexpr (!isFloatingPointType()) { - LOG_ERROR_FMT_THROW(L"Programmer error: UnaryMathOpType OpType: %ls only " - L"supports floating point types", - OpTypeName.c_str()); - return T(); - } + case UnaryMathOpType_Sign: + DISPATCH(HLSLHalf_t, Sign); + DISPATCH(float, Sign); + DISPATCH(double, Sign); + DISPATCH(int16_t, Sign); + DISPATCH(int32_t, Sign); + DISPATCH(int64_t, Sign); + DISPATCH(uint16_t, Sign); + DISPATCH(uint32_t, Sign); + DISPATCH(uint64_t, Sign); + break; - // Most of the std math functions here are only defined for floating point - // types. If we don't use a mechanism to ensure that we're only using floating - // point types then the compiler will complain about implicit conversions. 
- if constexpr (isFloatingPointType()) { - // A bunch of the std match functions here are wrapped in () to avoid - // collisions with the macro defitions for various functions in windows.h - switch (OpType) { - case UnaryMathOpType_Abs: - return abs(A); - case UnaryMathOpType_Ceil: - return (std::ceil)(A); - case UnaryMathOpType_Floor: - // float only - return (std::floor)(A); - case UnaryMathOpType_Trunc: - // float only - return (std::trunc)(A); - case UnaryMathOpType_Round: - // float only - return (std::round)(A); - case UnaryMathOpType_Frac: - // std::frac is not a standard C++ function, but we can implement it as - return A - T((std::floor)(A)); - case UnaryMathOpType_Sqrt: - return (std::sqrt)(A); - case UnaryMathOpType_Rsqrt: - // std::rsqrt is not a standard C++ function, but we can implement it as - return T(1.0) / T((std::sqrt)(A)); - case UnaryMathOpType_Exp: - return (std::exp)(A); - case UnaryMathOpType_Exp2: - return (std::exp2)(A); - case UnaryMathOpType_Log: - return (std::log)(A); - case UnaryMathOpType_Log2: - return (std::log2)(A); - case UnaryMathOpType_Log10: - return (std::log10)(A); - case UnaryMathOpType_Rcp: - // std::.rcp is not a standard C++ function, but we can implement it as - return T(1.0) / A; - default: - LOG_ERROR_FMT_THROW(L"computeExpectedValue(const T &A)" - L"called on an unrecognized unary math op: %ls", - OpTypeName.c_str()); - return T(); - } + case UnaryMathOpType_Ceil: + DISPATCH(HLSLHalf_t, Ceil); + DISPATCH(float, Ceil); + break; + + case UnaryMathOpType_Floor: + DISPATCH(HLSLHalf_t, Floor); + DISPATCH(float, Floor); + break; + + case UnaryMathOpType_Trunc: + DISPATCH(HLSLHalf_t, Trunc); + DISPATCH(float, Trunc); + break; + + case UnaryMathOpType_Round: + DISPATCH(HLSLHalf_t, Round); + DISPATCH(float, Round); + break; + + case UnaryMathOpType_Frac: + DISPATCH(HLSLHalf_t, Frac); + DISPATCH(float, Frac); + break; + + case UnaryMathOpType_Sqrt: + DISPATCH(HLSLHalf_t, Sqrt); + DISPATCH(float, Sqrt); + break; + + case 
UnaryMathOpType_Rsqrt: + DISPATCH(HLSLHalf_t, Rsqrt); + DISPATCH(float, Rsqrt); + break; + + case UnaryMathOpType_Exp: + DISPATCH(HLSLHalf_t, Exp); + DISPATCH(float, Exp); + break; + + case UnaryMathOpType_Exp2: + DISPATCH(HLSLHalf_t, Exp2); + DISPATCH(float, Exp2); + break; + + case UnaryMathOpType_Log: + DISPATCH(HLSLHalf_t, Log); + DISPATCH(float, Log); + break; + + case UnaryMathOpType_Log2: + DISPATCH(HLSLHalf_t, Log2); + DISPATCH(float, Log2); + break; + + case UnaryMathOpType_Log10: + DISPATCH(HLSLHalf_t, Log10); + DISPATCH(float, Log10); + break; + + case UnaryMathOpType_Rcp: + DISPATCH(HLSLHalf_t, Rcp); + DISPATCH(float, Rcp); + break; + + case UnaryMathOpType_Frexp: + if (Config.DataType == DataTypeName()) + return dispatchFrexpTest(Config, VectorSize); + break; + + case UnaryMathOpType_EnumValueCount: + break; } + +#undef DISPATCH + + LOG_ERROR_FMT_THROW(L"DataType '%s' not supported for UnaryOpType '%s'", + (const wchar_t *)Config.DataType, + (const wchar_t *)Config.OpTypeEnum); } -template -BinaryMathOpTestConfig::BinaryMathOpTestConfig( - const OpTypeMetaData &OpTypeMd) - : TestConfig(OpTypeMd), OpType(OpTypeMd.OpType) { +// +// BinaryMathOp +// - if (isFloatingPointType()) { - Tolerance = 1; - ValidationType = ValidationType_Ulp; - } +template +void dispatchBinaryMathOpTest(const TestConfig &Config, BinaryMathOpType OpType, + size_t VectorSize, OUT_TYPE (*Calc)(T, T)) { - BasicOpType = BasicOpType_Binary; + ValidationConfig ValidationConfig; - auto ComputeFunc = [this](const T &A, const T &B) { - return this->computeExpectedValue(A, B); - }; - InitBinaryOpValueComputer(ComputeFunc); + if (isFloatingPointType()) + ValidationConfig = ValidationConfig::Ulp(1.0); + + dispatchBinaryTest(Config, ValidationConfig, OpType, VectorSize, Calc); } -template -T BinaryMathOpTestConfig::computeExpectedValue(const T &A, - const T &B) const { +template struct BinaryMathOps { + static T Multiply(T A, T B) { return A * B; } + static T Add(T A, T B) { return A + B; } + 
static T Subtract(T A, T B) { return A - B; } + static T Divide(T A, T B) { return A / B; } + + static T FmodModulus(T A, T B) { + static_assert(isFloatingPointType()); + return std::fmod(A, B); + } + + static T OperatorModulus(T A, T B) { + // note: as well as integral types, HLSLHalf_t go through this code path + return A % B; + } + + // std::max and std::min are wrapped in () to avoid collisions with the macro + // defintions for min and max in windows.h + + static T Min(T A, T B) { return (std::min)(A, B); } + static T Max(T A, T B) { return (std::max)(A, B); } + + static T Ldexp(T A, T B) { return A * static_cast(std::pow(2.0f, B)); } +}; + +void dispatchTestByOpTypeAndVectorSize(const TestConfig &Config, + BinaryMathOpType OpType, + size_t VectorSize) { + +#define DISPATCH(TYPE, FUNC) \ + if (Config.DataType == DataTypeName()) \ + return dispatchBinaryMathOpTest(Config, OpType, VectorSize, \ + BinaryMathOps::FUNC) switch (OpType) { case BinaryMathOpType_Multiply: - return A * B; + DISPATCH(HLSLHalf_t, Multiply); + DISPATCH(float, Multiply); + DISPATCH(double, Multiply); + DISPATCH(int16_t, Multiply); + DISPATCH(int32_t, Multiply); + DISPATCH(int64_t, Multiply); + DISPATCH(uint16_t, Multiply); + DISPATCH(uint32_t, Multiply); + DISPATCH(uint64_t, Multiply); + break; + case BinaryMathOpType_Add: - return A + B; + DISPATCH(HLSLBool_t, Add); + DISPATCH(HLSLHalf_t, Add); + DISPATCH(float, Add); + DISPATCH(double, Add); + DISPATCH(int16_t, Add); + DISPATCH(int32_t, Add); + DISPATCH(int64_t, Add); + DISPATCH(uint16_t, Add); + DISPATCH(uint32_t, Add); + DISPATCH(uint64_t, Add); + break; + case BinaryMathOpType_Subtract: - return A - B; + DISPATCH(HLSLBool_t, Subtract); + DISPATCH(HLSLHalf_t, Subtract); + DISPATCH(float, Subtract); + DISPATCH(double, Subtract); + DISPATCH(int16_t, Subtract); + DISPATCH(int32_t, Subtract); + DISPATCH(int64_t, Subtract); + DISPATCH(uint16_t, Subtract); + DISPATCH(uint32_t, Subtract); + DISPATCH(uint64_t, Subtract); + break; + case 
BinaryMathOpType_Divide: - return A / B; + DISPATCH(HLSLHalf_t, Divide); + DISPATCH(float, Divide); + DISPATCH(double, Divide); + DISPATCH(int16_t, Divide); + DISPATCH(int32_t, Divide); + DISPATCH(int64_t, Divide); + DISPATCH(uint16_t, Divide); + DISPATCH(uint32_t, Divide); + DISPATCH(uint64_t, Divide); + break; + case BinaryMathOpType_Modulus: - return mod(A, B); + DISPATCH(HLSLHalf_t, OperatorModulus); + DISPATCH(float, FmodModulus); + DISPATCH(int16_t, OperatorModulus); + DISPATCH(int32_t, OperatorModulus); + DISPATCH(int64_t, OperatorModulus); + DISPATCH(uint16_t, OperatorModulus); + DISPATCH(uint32_t, OperatorModulus); + DISPATCH(uint64_t, OperatorModulus); + break; + case BinaryMathOpType_Min: - // std::max and std::min are wrapped in () to avoid collisions with the // - // macro defintions for min and max in windows.h - return (std::min)(A, B); + DISPATCH(HLSLHalf_t, Min); + DISPATCH(float, Min); + DISPATCH(double, Min); + DISPATCH(int16_t, Min); + DISPATCH(int32_t, Min); + DISPATCH(int64_t, Min); + DISPATCH(uint16_t, Min); + DISPATCH(uint32_t, Min); + DISPATCH(uint64_t, Min); + break; + case BinaryMathOpType_Max: - return (std::max)(A, B); + DISPATCH(HLSLHalf_t, Max); + DISPATCH(float, Max); + DISPATCH(double, Max); + DISPATCH(int16_t, Max); + DISPATCH(int32_t, Max); + DISPATCH(int64_t, Max); + DISPATCH(uint16_t, Max); + DISPATCH(uint32_t, Max); + DISPATCH(uint64_t, Max); + break; + case BinaryMathOpType_Ldexp: - return ldexp(A, B); - default: - LOG_ERROR_FMT_THROW(L"Unknown BinaryMathOpType: %ls", OpTypeName.c_str()); - return T(); + DISPATCH(HLSLHalf_t, Ldexp); + DISPATCH(float, Ldexp); + break; + + case BinaryMathOpType_EnumValueCount: + break; } + +#undef DISPATCH + + LOG_ERROR_FMT_THROW(L"DataType '%s' not supported for BinaryMathOpType '%s'", + (const wchar_t *)Config.DataType, + (const wchar_t *)Config.OpTypeEnum); } -template -TernaryMathOpTestConfig::TernaryMathOpTestConfig( - const OpTypeMetaData &OpTypeMd) - : TestConfig(OpTypeMd), 
OpType(OpTypeMd.OpType) { +// +// TernaryMathOp +// - if (isFloatingPointType()) { - Tolerance = 1; - ValidationType = ValidationType_Ulp; +template +void dispatchTernaryMathOpTest(const TestConfig &Config, + TernaryMathOpType OpType, size_t VectorSize, + OUT_TYPE (*Calc)(T, T, T)) { + + ValidationConfig ValidationConfig; + + if (isFloatingPointType()) + ValidationConfig = ValidationConfig::Ulp(1.0); + + InputSets Inputs = buildTestInputs(Config, VectorSize); + + std::vector Expected; + Expected.reserve(Inputs[0].size()); + + for (size_t I = 0; I < Inputs[0].size(); ++I) { + size_t Index1 = (Config.ScalarInputFlags & (1 << 1)) ? 0 : I; + size_t Index2 = (Config.ScalarInputFlags & (1 << 2)) ? 0 : I; + Expected.push_back( + Calc(Inputs[0][I], Inputs[1][Index1], Inputs[2][Index2])); } - BasicOpType = BasicOpType_Ternary; + runAndVerify(Config, OpType, Inputs, Expected, "", ValidationConfig); +} + +namespace TernaryMathOps { + +template T Fma(T, T, T); +template <> double Fma(double A, double B, double C) { return A * B + C; } + +template T Mad(T A, T B, T C) { return A * B + C; } + +template T SmoothStep(T Min, T Max, T X) { + DXASSERT_NOMSG(Min < Max); + + if (X <= Min) + return T(0); + if (X >= Max) + return T(1); + + T NormalizedX = (X - Min) / (Max - Min); + NormalizedX = std::clamp(NormalizedX, T(0), T(1)); + return NormalizedX * NormalizedX * (T(3) - T(2) * NormalizedX); +} + +} // namespace TernaryMathOps + +void dispatchTestByOpTypeAndVectorSize(const TestConfig &Config, + TernaryMathOpType OpType, + size_t VectorSize) { + +#define DISPATCH(TYPE, FUNC) \ + if (Config.DataType == DataTypeName()) \ + return dispatchTernaryMathOpTest(Config, OpType, VectorSize, \ + TernaryMathOps::FUNC) switch (OpType) { case TernaryMathOpType_Fma: + DISPATCH(double, Fma); + break; + case TernaryMathOpType_Mad: + DISPATCH(HLSLHalf_t, Mad); + DISPATCH(float, Mad); + DISPATCH(double, Mad); + DISPATCH(int16_t, Mad); + DISPATCH(int32_t, Mad); + DISPATCH(int64_t, Mad); + 
DISPATCH(uint16_t, Mad); + DISPATCH(uint32_t, Mad); + DISPATCH(uint64_t, Mad); + break; + case TernaryMathOpType_SmoothStep: + DISPATCH(HLSLHalf_t, SmoothStep); + DISPATCH(float, SmoothStep); + break; + + case TernaryMathOpType_EnumValueCount: break; - default: - LOG_ERROR_FMT_THROW(L"Invalid TernaryMathOpType: %ls", OpTypeName.c_str()); } - auto ComputeFunc = [this](const T &A, const T &B, const T &C) { - return this->computeExpectedValue(A, B, C); - }; - InitTernaryOpValueComputer(ComputeFunc); + LOG_ERROR_FMT_THROW(L"DataType '%s' not supported for TernaryMathOpType '%s'", + (const wchar_t *)Config.DataType, + (const wchar_t *)Config.OpTypeEnum); +} + +// +// dispatchTest +// + +template OP_TYPE GetOpType(const wchar_t *OpTypeString); + +template <> TrigonometricOpType GetOpType(const wchar_t *OpTypeString) { + return getTrigonometricOpType(OpTypeString).OpType; +} + +template <> UnaryOpType GetOpType(const wchar_t *OpTypeString) { + return getUnaryOpType(OpTypeString).OpType; +} + +template <> AsTypeOpType GetOpType(const wchar_t *OpTypeString) { + return getAsTypeOpType(OpTypeString).OpType; +} + +template <> UnaryMathOpType GetOpType(const wchar_t *OpTypeString) { + return getUnaryMathOpType(OpTypeString).OpType; +} + +template <> BinaryMathOpType GetOpType(const wchar_t *OpTypeString) { + return getBinaryMathOpType(OpTypeString).OpType; +} + +template <> TernaryMathOpType GetOpType(const wchar_t *OpTypeString) { + return getTernaryMathOpType(OpTypeString).OpType; +} + +template void dispatchTest(const TestConfig &Config) { + OP_TYPE OpType = GetOpType(Config.OpTypeEnum); + + std::vector InputVectorSizes; + if (Config.LongVectorInputSize) + InputVectorSizes.push_back(Config.LongVectorInputSize); + else + InputVectorSizes = {3, 4, 5, 16, 17, 35, 100, 256, 1024}; + + for (size_t VectorSize : InputVectorSizes) + dispatchTestByOpTypeAndVectorSize(Config, OpType, VectorSize); +} + +// TAEF test entry points + +TEST_F(OpTest, trigonometricOpTest) { + 
WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + if (auto Config = TestConfig::Create(VerboseLogging)) + dispatchTest(*Config); +} + +TEST_F(OpTest, unaryOpTest) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + if (auto Config = TestConfig::Create(VerboseLogging)) + dispatchTest(*Config); +} + +TEST_F(OpTest, asTypeOpTest) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + if (auto Config = TestConfig::Create(VerboseLogging)) + dispatchTest(*Config); +} + +TEST_F(OpTest, unaryMathOpTest) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + if (auto Config = TestConfig::Create(VerboseLogging)) + dispatchTest(*Config); +} + +TEST_F(OpTest, binaryMathOpTest) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + if (auto Config = TestConfig::Create(VerboseLogging)) + dispatchTest(*Config); +} + +TEST_F(OpTest, ternaryMathOpTest) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + + if (auto Config = TestConfig::Create(VerboseLogging)) + dispatchTest(*Config); } -}; // namespace LongVector +} // namespace LongVector \ No newline at end of file diff --git a/tools/clang/unittests/HLSLExec/LongVectors.h b/tools/clang/unittests/HLSLExec/LongVectors.h index 3b03d501d4..fe7c2d6f1a 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.h +++ b/tools/clang/unittests/HLSLExec/LongVectors.h @@ -1,41 +1,18 @@ #ifndef LONGVECTORS_H #define LONGVECTORS_H -#include +#include + #include -#include -#include -#include #include -#include -#include #include #include -#include - #include "LongVectorTestData.h" -#include "ShaderOpTest.h" -#include "TableParameterHandler.h" 
-#include "dxc/Support/WinIncludes.h" -#include "dxc/Support/dxcapi.use.h" -#include "dxc/Test/HlslTestUtils.h" namespace LongVector { -// Used to compute the hash of a std::wstring at compile time. Gives us a way to -// create switch statements with a std::wstring. -// Note: Because this is evaluated at compile time the compiler detects hash -// collisions via an duplicate case statement error. -inline constexpr auto Hash_djb2a(const std::wstring_view String) { - unsigned long Hash{1337}; - for (wchar_t c : String) { - Hash = ((Hash << 5) + Hash) ^ static_cast(c); - } - return Hash; -} - // We don't have std::bit_cast in C++17, so we define our own version. template typename std::enable_if, std::vector, - std::vector, std::vector, std::vector, - std::vector, std::vector, - std::vector, std::vector, - std::vector>; - -template -void fillShaderBufferFromLongVectorData(std::vector &ShaderBuffer, - const std::vector &TestData); - -template -void fillLongVectorDataFromShaderBuffer(const MappedData &ShaderBuffer, - std::vector &TestData, - size_t NumElements); - template constexpr bool isFloatingPointType() { return std::is_same_v || std::is_same_v || std::is_same_v; @@ -76,16 +35,6 @@ template constexpr bool is16BitType() { std::is_same_v; } -template std::string getHLSLTypeString(); - -enum SCALAR_INPUT_FLAGS : uint16_t { - // SCALAR_INPUT_FLAGS_OPERAND_1_IS_SCALAR is intentionally omitted. Input 1 is - // always a vector. - SCALAR_INPUT_FLAGS_NONE = 0x0, - SCALAR_INPUT_FLAGS_OPERAND_2_IS_SCALAR = 0x2, - SCALAR_INPUT_FLAGS_OPERAND_3_IS_SCALAR = 0x4, -}; - // Helpful metadata struct so we can define some common properties for a test in // a single place. Intrinsic and Operator are passed in with -D defines to // the compiler and expanded as macros in the HLSL code. 
For a better @@ -111,7 +60,7 @@ template struct OpTypeMetaData { T OpType; std::optional Intrinsic = std::nullopt; std::optional Operator = std::nullopt; - uint16_t ScalarInputFlags = static_cast(SCALAR_INPUT_FLAGS_NONE); + uint16_t ScalarInputFlags = 0; }; template @@ -124,13 +73,6 @@ enum ValidationType { ValidationType_Ulp, }; -enum BasicOpType { - BasicOpType_Unary, - BasicOpType_Binary, - BasicOpType_Ternary, - BasicOpType_EnumValueCount -}; - enum UnaryOpType { UnaryOpType_Initialize, UnaryOpType_EnumValueCount }; static const OpTypeMetaData unaryOpTypeStringToOpMetaData[] = { @@ -344,8 +286,6 @@ std::vector getInputValueSetByKey(const std::wstring &Key, return std::vector(TestData::Data.at(Key)); } -// The TAEF test class. -template class TestConfig; // Forward declaration. class OpTest { public: BEGIN_TEST_CLASS(OpTest) @@ -383,666 +323,11 @@ class OpTest { L"Table:LongVectorOpTable.xml#AsTypeOpTable") END_TEST_METHOD() - template - void dispatchTestByDataType(const OpTypeMetaData &OpTypeMD, - std::wstring DataType, - TableParameterHandler &Handler); - - void dispatchTrigonometricOpTestByDataType( - const OpTypeMetaData &OpTypeMD, - std::wstring DataType, TableParameterHandler &Handler); - - void dispatchUnaryMathOpTestByDataType( - const OpTypeMetaData &OpTypeMD, std::wstring DataType, - TableParameterHandler &Handler); - - template - void dispatchTestByVectorLength(const OpTypeMetaData &OpTypeMD, - TableParameterHandler &Handler); - - template - void testBaseMethod(std::unique_ptr> &TestConfig); - private: - dxc::SpecificDllLoader DxilDllLoader; bool Initialized = false; bool VerboseLogging = false; }; -template -bool doValuesMatch(T A, T B, float Tolerance, ValidationType); -bool doValuesMatch(HLSLBool_t A, HLSLBool_t B, float, ValidationType); -bool doValuesMatch(HLSLHalf_t A, HLSLHalf_t B, float Tolerance, - ValidationType ValidationType); -bool doValuesMatch(float A, float B, float Tolerance, - ValidationType ValidationType); -bool 
doValuesMatch(double A, double B, float Tolerance, - ValidationType ValidationType); - -template -bool doVectorsMatch(const std::vector &ActualValues, - const std::vector &ExpectedValues, float Tolerance, - ValidationType ValidationType, bool VerboseLogging = false); - -template -void logLongVector(const std::vector &Values, const std::wstring &Name); - -// The TestInputs struct is used to help simplify calls to -// TestConfig::computeExpectedValues and to help us re-infer information about -// the number and type of inputs for the intrinsic being tested. -// InputVector1 represnts the first argument to the intrinsic being tested. -// InputVector1 is always present as all intrinsics being tested take at least -// one input. -// The other std::optional InputVectorN members represent the argument N for the -// intrinsic. They are stored as std::optional as they may not be present. It -// depends on the intrinsic being tested. -// The length of the InputVector is used to infer if the argument is a scalar -// or vector. -template struct TestInputs { - std::vector InputVector1; - std::optional> InputVector2 = std::nullopt; - std::optional> InputVector3 = std::nullopt; -}; - -// Base interface for an ExpectedValueComputer. Derived classes implement -// based on their operation type. This class and its derived classes are -// responsible for computing the expected values for a given set of inputs. -// This class pattern allows us to abstract common logic for filling the -// expected vector based on operation type. -template class ExpectedValueComputerBase { -public: - virtual ~ExpectedValueComputerBase() = default; - virtual VariantVector computeExpectedValues(const TestInputs &Inputs) = 0; -}; - -// Default T2 to T1 as most intrinsics have a return type that matches the input -// type. For intrinsics that don't match that pattern the caller can specify -// the output type explicitly via an argument for T2. 
-template -class UnaryOpExpectedValueComputer : public ExpectedValueComputerBase { -public: - using ComputeFuncPtr = std::function; - UnaryOpExpectedValueComputer(ComputeFuncPtr func) : ComputeFunc(func) {} - VariantVector computeExpectedValues(const TestInputs &Inputs) override { - VariantVector ExpectedVector = generateExpectedVector( - Inputs.InputVector1.size(), - [&](size_t Index) { return ComputeFunc(Inputs.InputVector1[Index]); }); - return ExpectedVector; - } - -private: - const ComputeFuncPtr ComputeFunc; -}; - -// Default T2 to T1 as most intrinsics have a return type that matches the input -// type. For intrinsics that don't match that pattern the caller can specify -// the output type explicitly via an argument for T2. -template -class BinaryOpExpectedValueComputer : public ExpectedValueComputerBase { -public: - using ComputeFuncPtr = std::function; - - BinaryOpExpectedValueComputer(ComputeFuncPtr func) : ComputeFunc(func) {} - - VariantVector computeExpectedValues(const TestInputs &Inputs) override { - - const auto &Input1 = Inputs.InputVector1; - - DXASSERT_NOMSG(Inputs.InputVector2.has_value()); - const auto &Input2 = Inputs.InputVector2.value(); - - VariantVector ExpectedVector = - generateExpectedVector(Input1.size(), [&](size_t Index) { - const T1 &B = (Input2.size() == 1 ? Input2[0] : Input2[Index]); - - return ComputeFunc(Input1[Index], B); - }); - - return ExpectedVector; - } - -private: - const ComputeFuncPtr ComputeFunc; -}; - -// Default T2 to T1 as most intrinsics have a return type that matches the input -// type. For intrinsics that don't match that pattern the caller can specify -// the output type explicitly via an argument for T2. 
-template -class TernaryOpExpectedValueComputer : public ExpectedValueComputerBase { - -public: - using ComputeFuncPtr = std::function; - - TernaryOpExpectedValueComputer(ComputeFuncPtr func) : ComputeFunc(func) {} - - VariantVector computeExpectedValues(const TestInputs &Inputs) override { - - const auto &Input1 = Inputs.InputVector1; - - DXASSERT_NOMSG(Inputs.InputVector2.has_value()); - const auto &Input2 = Inputs.InputVector2.value(); - - DXASSERT_NOMSG(Inputs.InputVector3.has_value()); - const auto &Input3 = Inputs.InputVector3.value(); - - VariantVector ExpectedVector = - generateExpectedVector(Input1.size(), [&](size_t Index) { - const T1 &B = (Input2.size() == 1 ? Input2[0] : Input2[Index]); - const T1 &C = (Input3.size() == 1 ? Input3[0] : Input3[Index]); - - return ComputeFunc(Input1[Index], B, C); - }); - - return ExpectedVector; - } - -private: - const ComputeFuncPtr ComputeFunc; -}; - -// Helps handle the test configuration for LongVector operations. -// It is particularly useful helping manage logic of computing expected values -// and verifying the output. Especially helpful due to templating on the -// different data types and giving us a relatively clean way to leverage -// different logic paths for different HLSL instrinsics while keeping the main -// test code pretty generic. -// For each *OpType enum defined a *TestConfig specialization should be -// implemented to handle the specific logic for that operation. Generally that -// includes some basic setup like setting the BasicOpType as well as logic to -// compute expected values. See ExpectedValueComputer* use and definitions for -// more details on computing expected values. -template class TestConfig { -public: - virtual ~TestConfig() = default; - - void fillInputs(TestInputs &Inputs) const; - - // Derived classes don't need to implement this function if they configure and - // set a ExpectedValueComputer member. 
- virtual void computeExpectedValues(const TestInputs &Inputs); - - void setInputValueSetKey(const std::wstring &InputValueSetName, - size_t Index) { - VERIFY_IS_TRUE(Index < (InputValueSetKeys.size()), - L"Index out of bounds for InputValueSetKeys"); - InputValueSetKeys[Index] = InputValueSetName; - } - - void setLengthToTest(size_t LengthToTest) { - this->LengthToTest = LengthToTest; - } - void setVerboseLogging(bool VerboseLogging) { - this->VerboseLogging = VerboseLogging; - } - - std::string getCompilerOptionsString() const; - - bool verifyOutput(const std::shared_ptr &TestResult); - - size_t getNumOperands() const; - std::string getBasicOpTypeHexString() const; - -private: - std::vector getInputValueSet(size_t ValueSetIndex) const; - - // Helpers to get the hlsl type as a string for a given C++ type. - std::string getHLSLInputTypeString() const { return getHLSLTypeString(); } - std::string getHLSLOutputTypeString() const; - - template - bool verifyOutput(const std::shared_ptr &TestResult, - const std::vector &ExpectedVector); - - // The input value sets are used to fill the shader buffer. - std::array InputValueSetKeys = {L"DefaultInputValueSet1", - L"DefaultInputValueSet2", - L"DefaultInputValueSet3"}; - -protected: - // Prevent instances of TestConfig from being created directly. Want to force - // a derived class to be used for creation. - template - TestConfig(const OpTypeMetaData &OpTypeMd) - : OpTypeName(OpTypeMd.OpTypeString), Intrinsic(OpTypeMd.Intrinsic), - Operator(OpTypeMd.Operator), - ScalarInputFlags(OpTypeMd.ScalarInputFlags) {} - - // Helper to initialize a unary value computer. - template - void InitUnaryOpValueComputer(std::function ComputeFunc) { - DXASSERT_NOMSG(BasicOpType == BasicOpType_Unary); - DXASSERT_NOMSG(ExpectedValueComputer == nullptr); - ExpectedValueComputer = - std::make_unique>(ComputeFunc); - } - - // Helper to initialize a binary value computer. 
- template - void InitBinaryOpValueComputer( - std::function ComputeFunc) { - DXASSERT_NOMSG(BasicOpType == BasicOpType_Binary); - DXASSERT_NOMSG(ExpectedValueComputer == nullptr); - ExpectedValueComputer = - std::make_unique>(ComputeFunc); - } - - // Helper to initialize a ternary value computer. - template - void InitTernaryOpValueComputer( - std::function ComputeFunc) { - DXASSERT_NOMSG(BasicOpType == BasicOpType_Ternary); - DXASSERT_NOMSG(ExpectedValueComputer == nullptr); - ExpectedValueComputer = - std::make_unique>(ComputeFunc); - } - - std::unique_ptr> ExpectedValueComputer = nullptr; - - // To be used for the value of -DOPERATOR - std::optional Operator; - // To be used for the value of -DFUNC - std::optional Intrinsic; - // Used to add special -D defines to the compiler options. - std::optional SpecialDefines = std::nullopt; - BasicOpType BasicOpType = BasicOpType_EnumValueCount; - float Tolerance = 0.0; - ValidationType ValidationType = ValidationType::ValidationType_Epsilon; - // Default the TypedOutputVector to use T, Ops that don't have a - // matching output type will override this. - VariantVector ExpectedVector = std::vector{}; - size_t LengthToTest = 0; - - // Just used for logging purposes. - std::wstring OpTypeName = L"UnknownOpType"; - bool VerboseLogging = false; - - const uint16_t ScalarInputFlags; -}; // class TestConfig - -// Specialized TestConfig for AsType operations. Implements logic for computing -// expected values. Also includes overrides of individual functions for -// computing expected values to provide runtime errors if an unsupported data -// type is used for a given intrinsic. See the individual overrides of functions -// for more details. -template class AsTypeOpTestConfig : public TestConfig { -public: - AsTypeOpTestConfig(const OpTypeMetaData &OpTypeMd); - - // Override the base class method so we can handle split double as it has two - // out parameters. 
- void computeExpectedValues(const TestInputs &Inputs) override; - -private: - void computeExpectedValues_SplitDouble(const std::vector &InputVector); - - template - HLSLHalf_t asFloat16([[maybe_unused]] const T &A) const { - LOG_ERROR_FMT_THROW(L"Programmer Error: Invalid AsFloat16 T: %s", - typeid(T).name()); - return HLSLHalf_t(); - } - - HLSLHalf_t asFloat16(const HLSLHalf_t &A) const { return HLSLHalf_t(A.Val); } - - HLSLHalf_t asFloat16(const int16_t &A) const { - return HLSLHalf_t(bit_cast(A)); - } - - HLSLHalf_t asFloat16(const uint16_t &A) const { - return HLSLHalf_t(bit_cast(A)); - } - - template float asFloat(const T &) const { - LOG_ERROR_FMT_THROW(L"Programmer Error: Invalid AsFloat T: %S", - typeid(T).name()); - return 0.0f; - } - - float asFloat(const float &A) const { return float(A); } - float asFloat(const int32_t &A) const { return bit_cast(A); } - float asFloat(const uint32_t &A) const { return bit_cast(A); } - - template int32_t asInt([[maybe_unused]] const T &A) const { - // This path is unexpected outside of an issue when brining up new tests. So - // throwing an exception is appropriate. - LOG_ERROR_FMT_THROW(L"Programmer Error: Invalid AsInt T: %S", - typeid(T).name()); - return 0; - } - - int32_t asInt(const float &A) const { return bit_cast(A); } - int32_t asInt(const int32_t &A) const { return A; } - int32_t asInt(const uint32_t &A) const { return bit_cast(A); } - - template int16_t asInt16([[maybe_unused]] const T &A) const { - // This path is unexpected outside of an issue when brining up new tests. So - // throwing an exception is appropriate. 
- LOG_ERROR_FMT_THROW(L"Programmer Error: Invalid AsInt16 T: %S", - typeid(T).name()); - return 0; - } - - int16_t asInt16(const HLSLHalf_t &A) const { - return bit_cast(A.Val); - } - int16_t asInt16(const int16_t &A) const { return A; } - int16_t asInt16(const uint16_t &A) const { return bit_cast(A); } - - template uint16_t asUint16([[maybe_unused]] const T &A) const { - // This path is unexpected outside of an issue when brining up new tests. So - // throwing an exception is appropriate. - LOG_ERROR_FMT_THROW(L"Programmer Error: Invalid AsUint16 T: %S", - typeid(T).name()); - return 0; - } - - uint16_t asUint16(const HLSLHalf_t &A) const { - return bit_cast(A.Val); - } - uint16_t asUint16(const uint16_t &A) const { return A; } - uint16_t asUint16(const int16_t &A) const { return bit_cast(A); } - - template unsigned int asUint([[maybe_unused]] const T &A) const { - // This path is unexpected outside of an issue when brining up new tests. So - // throwing an exception is appropriate. - LOG_ERROR_FMT_THROW(L"Programmer Error: Invalid AsUint T: %S", - typeid(T).name()); - return 0; - } - - unsigned int asUint(const unsigned int &A) const { return A; } - unsigned int asUint(const float &A) const { - return bit_cast(A); - } - unsigned int asUint(const int &A) const { return bit_cast(A); } - - template - void splitDouble([[maybe_unused]] const T &A, - [[maybe_unused]] uint32_t &LowBits, - [[maybe_unused]] uint32_t &HighBits) const { - // This path is unexpected outside of an issue when brining up new tests. So - // throwing an exception is appropriate. - LOG_ERROR_FMT_THROW(L"Programmer Error: splitDouble only accepts a double " - L"as input. 
Have DataTypeInT: %s", - typeid(T).name()); - } - - void splitDouble(const double &A, uint32_t &LowBits, - uint32_t &HighBits) const { - uint64_t Bits = 0; - std::memcpy(&Bits, &A, sizeof(Bits)); - LowBits = static_cast(Bits & 0xFFFFFFFF); - HighBits = static_cast(Bits >> 32); - } - - template - double asDouble([[maybe_unused]] const T &LowBits, - [[maybe_unused]] const T &HighBits) const { - // This path is unexpected outside of an issue when brining up new tests. So - // throwing an exception is appropriate. - LOG_ERROR_FMT_THROW(L"Programmer Error: asDouble only accepts two uint32_t " - L"inputs. Have T : %S", - typeid(T).name()); - return 0.0; - } - - double asDouble(const uint32_t &LowBits, const uint32_t &HighBits) const { - uint64_t Bits = (static_cast(HighBits) << 32) | LowBits; - double Result; - std::memcpy(&Result, &Bits, sizeof(Result)); - return Result; - } - - AsTypeOpType OpType = AsTypeOpType_EnumValueCount; -}; - -template class TrigonometricOpTestConfig : public TestConfig { -public: - TrigonometricOpTestConfig( - const OpTypeMetaData &OpTypeMd); - - T computeExpectedValue(const T &A) const; - -private: - TrigonometricOpType OpType = TrigonometricOpType_EnumValueCount; -}; - -template class UnaryOpTestConfig : public TestConfig { -public: - UnaryOpTestConfig(const OpTypeMetaData &OpTypeMd); - - T computeExpectedValue(const T &A) const; - -private: - UnaryOpType OpType = UnaryOpType_EnumValueCount; -}; - -template class UnaryMathOpTestConfig : public TestConfig { -public: - UnaryMathOpTestConfig(const OpTypeMetaData &OpTypeMd); - - // Override the base class method so we can handle frexp as it has - // an out parameter. 
- void computeExpectedValues(const TestInputs &Inputs) override; - T computeExpectedValue(const T &A) const; - -private: - void computeExpectedValues_Frexp(const std::vector &InputVector); - - UnaryMathOpType OpType = UnaryMathOpType_EnumValueCount; - - // The majority of HLSL intrinsics return a DataType matching the - // input DataType. However, Sign always returns an int32_t. - template int32_t sign(const T &A) const { - // Return 1 for positive, -1 for negative, 0 for zero. - // Wrap comparison operands in DataTypeInT constructor to make sure - // we are comparing the same type. - return A > T(0) ? 1 : A < T(0) ? -1 : 0; - } - - template T abs(const T &A) const { - if constexpr (std::is_unsigned::value) - return T(A); - else - // Cast to T for the int16_t case because std::abs implicitly - // converts to and returns an int. Without the cast back to an int the - // compiler will complain that the implicit conversion back to int16_t has - // lost precision. - return static_cast((std::abs)(A)); - } - - template - typename std::enable_if::value), float>::type - frexp([[maybe_unused]] const T &A, [[maybe_unused]] float *Exponent) const { - LOG_ERROR_FMT_THROW(L"Programmer Error: frexp only accepts floats. " - L"Have DataTypeT: %s", - typeid(T).name()); - return 0.0f; - } - - template - typename std::enable_if<(std::is_same::value), T>::type - frexp(const T &A, T *Exponent) const { - int IntExp = 0; - - // std::frexp returns a signed mantissa. But the HLSL implmentation returns - // an unsigned mantissa. - T mantissa = std::abs(std::frexp(A, &IntExp)); - - // std::frexp returns the exponent as an int, but HLSL stores it as a float. - // However, the HLSL exponents fractional component is always 0. So it can - // conversion between float and int is safe. 
- *Exponent = static_cast(IntExp); - return mantissa; - } -}; - -template class BinaryMathOpTestConfig : public TestConfig { -public: - BinaryMathOpTestConfig(const OpTypeMetaData &OpTypeMd); - - T computeExpectedValue(const T &A, const T &B) const; - -private: - BinaryMathOpType OpType = BinaryMathOpType_EnumValueCount; - - // Helpers so we do the right thing for float types. HLSLHalf_t is handled in - // an operator overload. - template T mod(const T &A, const T &B) const { - return A % B; - } - - template <> float mod(const float &A, const float &B) const { - return std::fmod(A, B); - } - - template <> double mod(const double &A, const double &B) const { - return std::fmod(A, B); - } - - template - typename std::enable_if::value || - std::is_same::value), - T>::type - ldexp([[maybe_unused]] const T &A, [[maybe_unused]] const T &Exponent) const { - LOG_ERROR_FMT_THROW(L"Programmer Error: ldexp only accepts floatlikes. " - L"Have T: %s", - typeid(T).name()); - return T(); - } - - template - typename std::enable_if<(std::is_same::value || - std::is_same::value), - T>::type - ldexp(const T &A, const T &Exponent) const { - return A * T(std::pow(2.0f, Exponent)); - } -}; - -template class TernaryMathOpTestConfig : public TestConfig { -public: - TernaryMathOpTestConfig(const OpTypeMetaData &OpTypeMd); - - T computeExpectedValue(const T &A, const T &B, const T &C) const { - switch (OpType) { - case TernaryMathOpType_Fma: - return fma(A, B, C); - case TernaryMathOpType_Mad: - return mad(A, B, C); - case TernaryMathOpType_SmoothStep: - return smoothStep(A, B, C); - default: - LOG_ERROR_FMT_THROW(L"Programmer Error: Invalid TernaryMathOpType: %d", - OpType); - return T(); - } - } - -private: - TernaryMathOpType OpType = TernaryMathOpType_EnumValueCount; - - template - T fma([[maybe_unused]] const T &A, [[maybe_unused]] const T &B, - const T &C) const { - LOG_ERROR_FMT_THROW(L"Programmer Error: fma only accepts doubles. 
Have " - L"T: %s", - typeid(T).name()); - return T(); - } - - // fma only accepts doubles - template <> - double fma(const double &A, const double &B, const double &C) const { - return A * B + C; - } - - // Mad is only enabled for numeric types. Capture that by having an fallback - // that errors out if bool is used. - template - typename std::enable_if::value, T>::type - mad([[maybe_unused]] const T &A, [[maybe_unused]] const T &B, - const T &C) const { - LOG_ERROR_FMT_THROW(L"Programmer Error: mad does not support HLSLBool_t"); - return T(); - } - - template - typename std::enable_if::value, T>::type - mad(const T &A, const T &B, const T &C) const { - return A * B + C; - } - - // Smoothstep Fallback: only enabled when T is NOT a floatlike - template - typename std::enable_if::value || - std::is_same::value || - std::is_same::value), - T>::type - smoothStep([[maybe_unused]] const T &Min, [[maybe_unused]] const T &Max, - [[maybe_unused]] const T &X) const { - LOG_ERROR_FMT_THROW(L"Programmer Error: smoothStep only accepts " - L"floatlikes. 
Have T: %s", - typeid(T).name()); - return T(); - } - - // Smoothstep is only enabled for floatlikes - template - typename std::enable_if::value || - std::is_same::value || - std::is_same::value, - T>::type - smoothStep(const T &Min, const T &Max, const T &X) const { - DXASSERT_NOMSG(Min < Max); - - if (X <= Min) - return T(0); - if (X >= Max) - return T(1); - - T NormalizedX = (X - Min) / (Max - Min); - NormalizedX = std::clamp(NormalizedX, T(0), T(1)); - return NormalizedX * NormalizedX * (T(3) - T(2) * NormalizedX); - } -}; - -template -std::unique_ptr> -makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { - return std::make_unique>(OpTypeMetaData); -} - -template -std::unique_ptr> -makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { - return std::make_unique>(OpTypeMetaData); -} - -template -std::unique_ptr> -makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { - return std::make_unique>(OpTypeMetaData); -} - -template -std::unique_ptr> -makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { - return std::make_unique>(OpTypeMetaData); -} - -template -std::unique_ptr> -makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { - return std::make_unique>(OpTypeMetaData); -} - -template -std::unique_ptr> -makeTestConfig(const OpTypeMetaData &OpTypeMetaData) { - return std::make_unique>(OpTypeMetaData); -} -}; // namespace LongVector +} // namespace LongVector #endif // LONGVECTORS_H