diff --git a/tools/clang/unittests/HLSLExec/LongVectorOps.def b/tools/clang/unittests/HLSLExec/LongVectorOps.def index f3908ff055..9cf1784239 100644 --- a/tools/clang/unittests/HLSLExec/LongVectorOps.def +++ b/tools/clang/unittests/HLSLExec/LongVectorOps.def @@ -20,6 +20,7 @@ INPUT_SET(Bitwise) INPUT_SET(SelectCond) INPUT_SET(FloatSpecial) INPUT_SET(AllOnes) +INPUT_SET(WaveMultiPrefixBitwise) #undef INPUT_SET @@ -207,5 +208,10 @@ OP_DEFAULT_DEFINES(Wave, WaveReadLaneAt, 1, "TestWaveReadLaneAt", "", " -DFUNC_W OP_DEFAULT_DEFINES(Wave, WaveReadLaneFirst, 1, "TestWaveReadLaneFirst", "", " -DFUNC_WAVE_READ_LANE_FIRST=1") OP_DEFAULT_DEFINES(Wave, WavePrefixSum, 1, "TestWavePrefixSum", "", " -DFUNC_WAVE_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1") OP_DEFAULT_DEFINES(Wave, WavePrefixProduct, 1, "TestWavePrefixProduct", "", " -DFUNC_WAVE_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1") +OP(Wave, WaveMultiPrefixSum, 1, "TestWaveMultiPrefixSum", "", " -DFUNC_WAVE_MULTI_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", Default1, Default2, Default3) +OP(Wave, WaveMultiPrefixProduct, 1, "TestWaveMultiPrefixProduct", "", " -DFUNC_WAVE_MULTI_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", Default1, Default2, Default3) +OP(Wave, WaveMultiPrefixBitAnd, 1, "TestWaveMultiPrefixBitAnd", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_AND=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3) +OP(Wave, WaveMultiPrefixBitOr, 1, "TestWaveMultiPrefixBitOr", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_OR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3) +OP(Wave, WaveMultiPrefixBitXor, 1, "TestWaveMultiPrefixBitXor", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_XOR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3) #undef OP diff --git a/tools/clang/unittests/HLSLExec/LongVectorTestData.h b/tools/clang/unittests/HLSLExec/LongVectorTestData.h index 519f8a8b63..ce32b4c035 100644 --- 
a/tools/clang/unittests/HLSLExec/LongVectorTestData.h +++ b/tools/clang/unittests/HLSLExec/LongVectorTestData.h @@ -290,6 +290,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits::min(), -1, 0, 1, 3, std::numeric_limits::max()); INPUT_SET(InputSet::SelectCond, 0, 1); INPUT_SET(InputSet::AllOnes, 1); +INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF, + -1); END_INPUT_SETS() BEGIN_INPUT_SETS(int32_t) @@ -304,6 +306,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits::min(), -1, 0, 1, 3, std::numeric_limits::max()); INPUT_SET(InputSet::SelectCond, 0, 1); INPUT_SET(InputSet::AllOnes, 1); +INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF, + -1); END_INPUT_SETS() BEGIN_INPUT_SETS(int64_t) @@ -318,6 +322,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits::min(), -1, 0, 1, 3, std::numeric_limits::max()); INPUT_SET(InputSet::SelectCond, 0, 1); INPUT_SET(InputSet::AllOnes, 1); +INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF, + -1ll); END_INPUT_SETS() BEGIN_INPUT_SETS(uint16_t) @@ -329,6 +335,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x5555, 0xAAAA, 0x8000, 127, std::numeric_limits::max()); INPUT_SET(InputSet::SelectCond, 0, 1); INPUT_SET(InputSet::AllOnes, 1); +INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF, + std::numeric_limits::max()); END_INPUT_SETS() BEGIN_INPUT_SETS(uint32_t) @@ -340,6 +348,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x55555555, 0xAAAAAAAA, 0x80000000, 127, std::numeric_limits::max()); INPUT_SET(InputSet::SelectCond, 0, 1); INPUT_SET(InputSet::AllOnes, 1); +INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0xA, 0xC, 0xF, + std::numeric_limits::max()); END_INPUT_SETS() BEGIN_INPUT_SETS(uint64_t) @@ -352,6 +362,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x5555555555555555, std::numeric_limits::max()); INPUT_SET(InputSet::SelectCond, 0, 1); INPUT_SET(InputSet::AllOnes, 1); 
+INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0xA, 0xC, 0xF, + std::numeric_limits::max()); END_INPUT_SETS() BEGIN_INPUT_SETS(HLSLHalf_t) diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp index 5ab16b75a8..a52eb56581 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.cpp +++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp @@ -1349,7 +1349,7 @@ template T waveActiveBitAnd(T A, UINT) { WAVE_OP(OpType::WaveActiveBitAnd, (waveActiveBitAnd(A, WaveSize))); template T waveActiveBitOr(T A, UINT) { - // We set the LSB to 0 in one of the lanes. + // We set the LSB to 1 in one of the lanes. return static_cast(A | static_cast(1)); } @@ -1362,6 +1362,60 @@ template T waveActiveBitXor(T A, UINT) { WAVE_OP(OpType::WaveActiveBitXor, (waveActiveBitXor(A, WaveSize))); +WAVE_OP(OpType::WaveMultiPrefixBitAnd, waveMultiPrefixBitAnd(A, WaveSize)); + +template T waveMultiPrefixBitAnd(T A, UINT) { + // The contributing lanes (1 and 2) in the group mask keep only the second + // and third LSBs (0x6); the output lane (3) excludes its own value. + return static_cast(A & static_cast(0x6)); +} + +WAVE_OP(OpType::WaveMultiPrefixBitOr, waveMultiPrefixBitOr(A, WaveSize)); + +template T waveMultiPrefixBitOr(T A, UINT) { + // The contributing lanes in the group mask clear the second LSB (0x2). + return static_cast(A & ~static_cast(0x2)); +} + +template +struct Op : StrictValidation {}; + +template struct ExpectedBuilder { + static std::vector buildExpected(Op &, + const InputSets &Inputs, UINT) { + DXASSERT_NOMSG(Inputs.size() == 1); + + std::vector Expected; + const size_t VectorSize = Inputs[0].size(); + + // We get a little creative for MultiPrefixBitXor. The mask we use for the + // group in the shader is 0xE (0b1110), which includes lanes 1, 2, and 3. + // Prefix ops don't include the value of the current lane in their result. + // So, for this test we store the result of WaveMultiPrefixBitXor from lane + // 3.
This means only the values from lanes 1 and 2 contribute to the result + // at lane 3. + // + // In the shader: + // - Lane 0: Set to 0 (not in mask, shouldn't affect result) + // - Lane 1: Keeps original input values + // - Lane 2: Lower half + last element set to 0, upper half keeps input + // - Lane 3: Stores the prefix XOR result (lane 1 XOR lane 2) + // + // Expected result: Lower half matches input (lane 1 XOR 0), upper half is + // 0s (input XOR input), except the last element, which matches input. + for (size_t I = 0; I < VectorSize / 2; ++I) + Expected.push_back(Inputs[0][I]); + for (size_t I = VectorSize / 2; I < VectorSize - 1; ++I) + Expected.push_back(0); + + // Lane 2 also zeroes its last element in the shader, so the last element of + // the output vector matches the last element in the input vector. + Expected.push_back(Inputs[0][VectorSize - 1]); + + return Expected; + } +}; + template struct Op : StrictValidation {}; @@ -1420,8 +1474,14 @@ template struct ExpectedBuilder { WAVE_OP(OpType::WavePrefixSum, (wavePrefixSum(A, WaveSize))); template T wavePrefixSum(T A, UINT WaveSize) { - // We test the prefix sume in the 'middle' lane. This choice is arbitrary. - return static_cast(A * static_cast(WaveSize / 2)); + // We test the prefix sum in the 'middle' lane. This choice is arbitrary. + return A * static_cast(WaveSize / 2); +} + +WAVE_OP(OpType::WaveMultiPrefixSum, (waveMultiPrefixSum(A, WaveSize))); + +template T waveMultiPrefixSum(T A, UINT) { + return A * static_cast(2u); } WAVE_OP(OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize))); @@ -1429,7 +1489,14 @@ WAVE_OP(OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize))); template T wavePrefixProduct(T A, UINT) { // We test the the prefix product in the 3rd lane to avoid overflow issues. // So the result is A * A.
- return static_cast(A * A); + return A * A; +} + +WAVE_OP(OpType::WaveMultiPrefixProduct, (waveMultiPrefixProduct(A, WaveSize))); + +template T waveMultiPrefixProduct(T A, UINT) { + // The group mask has 3 lanes; only lanes 1 and 2 feed lane 3, so A * A. + return A * A; } #undef WAVE_OP @@ -2343,6 +2410,11 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, int16_t); HLK_WAVEOP_TEST(WavePrefixSum, int16_t); HLK_WAVEOP_TEST(WavePrefixProduct, int16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, int16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int16_t); HLK_WAVEOP_TEST(WaveActiveSum, int32_t); HLK_WAVEOP_TEST(WaveActiveMin, int32_t); HLK_WAVEOP_TEST(WaveActiveMax, int32_t); @@ -2351,7 +2423,12 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneAt, int32_t); HLK_WAVEOP_TEST(WaveReadLaneFirst, int32_t); HLK_WAVEOP_TEST(WavePrefixSum, int32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, int32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int32_t); HLK_WAVEOP_TEST(WavePrefixProduct, int32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int32_t); HLK_WAVEOP_TEST(WaveActiveSum, int64_t); HLK_WAVEOP_TEST(WaveActiveMin, int64_t); HLK_WAVEOP_TEST(WaveActiveMax, int64_t); @@ -2361,7 +2438,14 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, int64_t); HLK_WAVEOP_TEST(WavePrefixSum, int64_t); HLK_WAVEOP_TEST(WavePrefixProduct, int64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, int64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int64_t); + // Note: WaveActiveBit* ops don't support uint16_t in HLSL, + // but the WaveMultiPrefixBit* ops support all int and uint types.
HLK_WAVEOP_TEST(WaveActiveSum, uint16_t); HLK_WAVEOP_TEST(WaveActiveMin, uint16_t); HLK_WAVEOP_TEST(WaveActiveMax, uint16_t); @@ -2371,11 +2455,15 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, uint16_t); HLK_WAVEOP_TEST(WavePrefixSum, uint16_t); HLK_WAVEOP_TEST(WavePrefixProduct, uint16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint16_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint16_t); HLK_WAVEOP_TEST(WaveActiveSum, uint32_t); HLK_WAVEOP_TEST(WaveActiveMin, uint32_t); HLK_WAVEOP_TEST(WaveActiveMax, uint32_t); HLK_WAVEOP_TEST(WaveActiveProduct, uint32_t); - // Note: WaveActiveBit* ops don't support uint16_t in HLSL HLK_WAVEOP_TEST(WaveActiveBitAnd, uint32_t); HLK_WAVEOP_TEST(WaveActiveBitOr, uint32_t); HLK_WAVEOP_TEST(WaveActiveBitXor, uint32_t); @@ -2384,6 +2472,11 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, uint32_t); HLK_WAVEOP_TEST(WavePrefixSum, uint32_t); HLK_WAVEOP_TEST(WavePrefixProduct, uint32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint32_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint32_t); HLK_WAVEOP_TEST(WaveActiveSum, uint64_t); HLK_WAVEOP_TEST(WaveActiveMin, uint64_t); HLK_WAVEOP_TEST(WaveActiveMax, uint64_t); @@ -2396,6 +2489,11 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, uint64_t); HLK_WAVEOP_TEST(WavePrefixSum, uint64_t); HLK_WAVEOP_TEST(WavePrefixProduct, uint64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint64_t); + HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint64_t); HLK_WAVEOP_TEST(WaveActiveSum, 
HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveMin, HLSLHalf_t); @@ -2406,6 +2504,8 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLHalf_t); HLK_WAVEOP_TEST(WavePrefixSum, HLSLHalf_t); HLK_WAVEOP_TEST(WavePrefixProduct, HLSLHalf_t); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, HLSLHalf_t); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveSum, float); HLK_WAVEOP_TEST(WaveActiveMin, float); HLK_WAVEOP_TEST(WaveActiveMax, float); @@ -2415,6 +2515,8 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, float); HLK_WAVEOP_TEST(WavePrefixSum, float); HLK_WAVEOP_TEST(WavePrefixProduct, float); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, float); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, float); HLK_WAVEOP_TEST(WaveActiveSum, double); HLK_WAVEOP_TEST(WaveActiveMin, double); HLK_WAVEOP_TEST(WaveActiveMax, double); @@ -2424,6 +2526,8 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveReadLaneFirst, double); HLK_WAVEOP_TEST(WavePrefixSum, double); HLK_WAVEOP_TEST(WavePrefixProduct, double); + HLK_WAVEOP_TEST(WaveMultiPrefixSum, double); + HLK_WAVEOP_TEST(WaveMultiPrefixProduct, double); private: bool Initialized = false; diff --git a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml index 5bc9f7118e..fe238bdfed 100644 --- a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml +++ b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml @@ -4215,6 +4215,181 @@ void MSMain(uint GID : SV_GroupIndex, } #endif + #ifdef FUNC_WAVE_MULTI_PREFIX_SUM + void TestWaveMultiPrefixSum(vector Vector) + { + uint Key = (WaveGetLaneIndex() == 1 || WaveGetLaneIndex() == 2 || WaveGetLaneIndex() == 3) ? 1u : 0u; + + // Two groups. Lanes 1,2,3 in one group (Key=1), Lanes 0,(4..N) in + // the other (Key=0). + uint4 Mask = WaveMatch(Key); + + if(WaveGetLaneIndex() == 0) + { + // Lane 0 isn't in the mask. Shove in a value to make sure it + // doesn't contribute to the result.
+ Vector = 1; + } + + if(WaveGetLaneIndex() >= 3) + { + // Lane 3 is the last lane in the mask (lanes 4..N are in the other + // group). It must not contribute to its own result as this is a prefix op. + Vector = 10; + } + + Vector = WaveMultiPrefixSum(Vector, Mask); + if(WaveGetLaneIndex() == 3) + { + // Lane 3 is the last lane in the mask that we care about. Store the + // result from it. + g_OutputVector.Store< vector >(0, Vector); + } + } + #endif + + #ifdef FUNC_WAVE_MULTI_PREFIX_PRODUCT + void TestWaveMultiPrefixProduct(vector Vector) + { + uint Key = (WaveGetLaneIndex() == 1 || WaveGetLaneIndex() == 2 || WaveGetLaneIndex() == 3) ? 1u : 0u; + + // Two groups. Lanes 1,2,3 in one group (Key=1), Lanes 0,(4..N) in + // the other (Key=0). + uint4 Mask = WaveMatch(Key); + + if(WaveGetLaneIndex() == 0) + { + // Lane 0 isn't in the mask. Shove in a value to make sure it + // doesn't contribute to the result. + Vector = 4; + } + + if(WaveGetLaneIndex() == 3) + { + // Lane 3 is the last lane in the mask. We want to make sure + // it doesn't contribute to the result as this is a prefix op. + Vector = 10; + } + + Vector = WaveMultiPrefixProduct(Vector, Mask); + if(WaveGetLaneIndex() == 3) + { + // Lane 3 is the last lane in the mask. Store the result from it. + g_OutputVector.Store< vector >(0, Vector); + } + } + #endif + + #ifdef FUNC_WAVE_MULTI_PREFIX_BIT_AND + void TestWaveMultiPrefixBitAnd(vector Vector) + { + uint Key = (WaveGetLaneIndex() == 1 || WaveGetLaneIndex() == 2 || WaveGetLaneIndex() == 3) ? 1u : 0u; + + // Two groups. Lanes 1,2,3 in one group (Key=1), Lanes 0,(4..N) in + // the other (Key=0). + uint4 Mask = WaveMatch(Key); + + if(WaveGetLaneIndex() == 0 || WaveGetLaneIndex() == 3) + { + // Clear LSB on lane 0 and lane 3. Lane 0 isn't in the mask so + // shouldn't participate. Lane 3 is the output lane for this prefix + // op, so it gets a different bit pattern to verify it doesn't affect its own result.
+ Vector = Vector & ~((OUT_TYPE)0x1); + } + else // Lanes 1,2 (active contributors to the prefix operation) and lanes 4..N (other group) + { + // Keep only bits 1 and 2 (0x6 = 0b0110) to create predictable AND patterns + Vector = (Vector & ((OUT_TYPE)0x6)); + } + + Vector = WaveMultiPrefixBitAnd(Vector, Mask); + if(WaveGetLaneIndex() == 3) + { + // Lane 3 is the last lane in the mask. Store the result from it. + g_OutputVector.Store< vector >(0, Vector); + } + } + #endif + + #ifdef FUNC_WAVE_MULTI_PREFIX_BIT_OR + void TestWaveMultiPrefixBitOr(vector Vector) + { + uint Key = (WaveGetLaneIndex() == 1 || WaveGetLaneIndex() == 2 || WaveGetLaneIndex() == 3) ? 1u : 0u; + + // Two groups. Lanes 1,2,3 in one group (Key=1), Lanes 0,(4..N) in + // the other (Key=0). + uint4 Mask = WaveMatch(Key); + + if(WaveGetLaneIndex() == 1 || WaveGetLaneIndex() == 2 || WaveGetLaneIndex() == 3) + { + // Lanes 1,2,3 (inside the mask): Clear bit 1 (0x2) to create + // predictable OR patterns + Vector = Vector & ~((OUT_TYPE)0x2); + } + else + { + // Lanes 0 and 4..N (outside the mask): Set bit 1 to verify these + // lanes don't contribute to the result + Vector = Vector | ((OUT_TYPE)0x2); + } + + if(WaveGetLaneIndex() == 3) + { + // Lane 3 is the output lane: Set all bits to verify it doesn't + // affect its own prefix result (since prefix excludes current lane) + Vector = Vector | ~((OUT_TYPE)0x0); + } + + Vector = WaveMultiPrefixBitOr(Vector, Mask); + if(WaveGetLaneIndex() == 3) + { + // Lane 3 is the last lane in the mask. Store the result from it. + g_OutputVector.Store< vector >(0, Vector); + } + } + #endif + + #ifdef FUNC_WAVE_MULTI_PREFIX_BIT_XOR + void TestWaveMultiPrefixBitXor(vector Vector) + { + uint Key = (WaveGetLaneIndex() == 1 || WaveGetLaneIndex() == 2 || WaveGetLaneIndex() == 3) ? 1u : 0u; + + // Two groups. Lanes 1,2,3 in one group (Key=1), Lanes 0,(4..N) in + // the other (Key=0).
+ uint4 Mask = WaveMatch(Key); + + if(WaveGetLaneIndex() == 0) + { + // Lane 0 is not in the mask, so its values should have no effect on the + // prefix result. NOTE(review): 0 is the XOR identity, so a leak from lane 0 would go undetected. + Vector = 0; + } + + if(WaveGetLaneIndex() == 2) + { + // Lane 2: Create a specific pattern for XOR testing. + // Zero the lower half of the vector to create predictable XOR results. + [unroll] + for(uint I = 0; I < NUM/2; ++I) + { + Vector[I] = 0; + } + + // Also zero the last element so the last output element equals the input + Vector[NUM - 1] = 0; + } + // Lane 1 and 3: Keep original input values + // Lane 3 will store the result (lane 1 XOR lane 2 prefix) + + Vector = WaveMultiPrefixBitXor(Vector, Mask); + if(WaveGetLaneIndex() == 3) + { + // Store result from lane 3 (last lane in mask) + g_OutputVector.Store< vector >(0, Vector); + } + } + #endif + + #ifdef FUNC_TEST_SELECT vector TestSelect(vector Vector1, vector Vector2,