diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index d4a3e34a43c53..8a41388bd2244 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5072,6 +5072,12 @@ def HLSLWaveGetLaneCount : LangBuiltin<"HLSL_LANG"> { let Prototype = "unsigned int()"; } +def HLSLWavePrefixSum : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_wave_prefix_sum"]; + let Attributes = [NoThrow, Const]; + let Prototype = "void(...)"; +} + def HLSLClamp : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_elementwise_clamp"]; let Attributes = [NoThrow, Const, CustomTypeChecking]; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index b6928ce7d9c44..f79433e755f94 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -240,61 +240,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) { return RT.getFirstBitUHighIntrinsic(); } -// Return wave active sum that corresponds to the QT scalar type -static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { - switch (Arch) { - case llvm::Triple::spirv: - return Intrinsic::spv_wave_reduce_sum; - case llvm::Triple::dxil: { - if (QT->isUnsignedIntegerType()) - return Intrinsic::dx_wave_reduce_usum; - return Intrinsic::dx_wave_reduce_sum; - } - default: - llvm_unreachable("Intrinsic WaveActiveSum" - " not supported by target architecture"); - } -} - -// Return wave active max that corresponds to the QT scalar type -static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { - switch (Arch) { - case llvm::Triple::spirv: - if (QT->isUnsignedIntegerType()) - return Intrinsic::spv_wave_reduce_umax; - return Intrinsic::spv_wave_reduce_max; - case llvm::Triple::dxil: { - if (QT->isUnsignedIntegerType()) - return Intrinsic::dx_wave_reduce_umax; - return Intrinsic::dx_wave_reduce_max; - } - default: - llvm_unreachable("Intrinsic WaveActiveMax" - " not supported by target architecture"); - } -} - -// Return wave active min that corresponds to the QT scalar type -static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { - switch (Arch) { - case llvm::Triple::spirv: - if (QT->isUnsignedIntegerType()) - return Intrinsic::spv_wave_reduce_umin; - return Intrinsic::spv_wave_reduce_min; - case llvm::Triple::dxil: { - if (QT->isUnsignedIntegerType()) - return Intrinsic::dx_wave_reduce_umin; - return Intrinsic::dx_wave_reduce_min; - } - default: - llvm_unreachable("Intrinsic WaveActiveMin" - " not supported by target architecture"); - } -} - // Returns the mangled name for a builtin function that the SPIR-V backend // will expand into a spec Constant. static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType, @@ -794,33 +739,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, ArrayRef{OpExpr}); } case Builtin::BI__builtin_hlsl_wave_active_sum: { - // Due to the use of variadic arguments, explicitly retreive argument + // Due to the use of variadic arguments, explicitly retrieve argument Value *OpExpr = EmitScalarExpr(E->getArg(0)); - Intrinsic::ID IID = getWaveActiveSumIntrinsic( - getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), - E->getArg(0)->getType()); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveSumIntrinsic( + QT->isUnsignedIntegerType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {OpExpr->getType()}), ArrayRef{OpExpr}, "hlsl.wave.active.sum"); } case Builtin::BI__builtin_hlsl_wave_active_max: { - // Due to the use of variadic arguments, explicitly retreive argument + // Due to the use of variadic arguments, explicitly retrieve argument Value *OpExpr = EmitScalarExpr(E->getArg(0)); - Intrinsic::ID IID = getWaveActiveMaxIntrinsic( - getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), - E->getArg(0)->getType()); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMaxIntrinsic( + QT->isUnsignedIntegerType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {OpExpr->getType()}), ArrayRef{OpExpr}, "hlsl.wave.active.max"); } case Builtin::BI__builtin_hlsl_wave_active_min: { - // Due to the use of variadic arguments, explicitly retreive argument + // Due to the use of variadic arguments, explicitly retrieve argument Value *OpExpr = EmitScalarExpr(E->getArg(0)); - Intrinsic::ID IID = getWaveActiveMinIntrinsic( - getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), - E->getArg(0)->getType()); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMinIntrinsic( + QT->isUnsignedIntegerType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {OpExpr->getType()}), @@ -864,6 +809,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, {OpExpr->getType()}), ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane"); } + case Builtin::BI__builtin_hlsl_wave_prefix_sum: { + Value *OpExpr = EmitScalarExpr(E->getArg(0)); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic( + QT->isUnsignedIntegerType()); + return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( + &CGM.getModule(), IID, {OpExpr->getType()}), + ArrayRef{OpExpr}, "hlsl.wave.prefix.sum"); + } case Builtin::BI__builtin_hlsl_elementwise_sign: { auto *Arg0 = E->getArg(0); Value *Op0 = EmitScalarExpr(Arg0); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 4bdba9b3da502..f9b1928ac7c45 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -276,6 +276,29 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() { return CGM.getTarget().getTriple().getArch(); } +llvm::Intrinsic::ID +CGHLSLRuntime::getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID) { + switch (IID) { + // DXIL intrinsics + case Intrinsic::dx_wave_reduce_sum: + return Intrinsic::dx_wave_reduce_usum; + case Intrinsic::dx_wave_reduce_max: + return Intrinsic::dx_wave_reduce_umax; + case Intrinsic::dx_wave_reduce_min: + return Intrinsic::dx_wave_reduce_umin; + case Intrinsic::dx_wave_prefix_sum: + return Intrinsic::dx_wave_prefix_usum; + + // SPIR-V intrinsics + case Intrinsic::spv_wave_reduce_max: + return Intrinsic::spv_wave_reduce_umax; + case Intrinsic::spv_wave_reduce_min: + return Intrinsic::spv_wave_reduce_umin; + default: + return IID; + } +} + // Emits constant global variables for buffer constants declarations // and creates metadata linking the constant globals with the buffer global. void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl, diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index 488a322ca7569..671c7d434edbb 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -49,6 +49,33 @@ } \ } +// A function generator macro for picking the right intrinsic for the target +// backend given IsUnsigned boolean condition. If IsUnsigned == true, it calls +// getUnsignedIntrinsicVariant(IID) to retrieve the unsigned variant of the +// intrinsic else the regular intrinsic is returned. (NOTE: +// getUnsignedIntrinsicVariant returns IID itself if there is no unsigned +// variant). +#define GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(FunctionName, \ + IntrinsicPostfix) \ + llvm::Intrinsic::ID get##FunctionName##Intrinsic(bool IsUnsigned) { \ + llvm::Triple::ArchType Arch = getArch(); \ + switch (Arch) { \ + case llvm::Triple::dxil: { \ + static constexpr llvm::Intrinsic::ID IID = \ + llvm::Intrinsic::dx_##IntrinsicPostfix; \ + return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID; \ + } \ + case llvm::Triple::spirv: { \ + static constexpr llvm::Intrinsic::ID IID = \ + llvm::Intrinsic::spv_##IntrinsicPostfix; \ + return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID; \ + } \ + default: \ + llvm_unreachable("Intrinsic " #IntrinsicPostfix \ + " not supported by target architecture"); \ + } \ + } + using ResourceClass = llvm::dxil::ResourceClass; namespace llvm { @@ -141,9 +168,17 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveSum, + wave_reduce_sum) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMax, + wave_reduce_max) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMin, + wave_reduce_min) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WavePrefixSum, + wave_prefix_sum) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow) @@ -246,6 +281,10 @@ class CGHLSLRuntime { llvm::Triple::ArchType getArch(); + // Returns the unsigned variant of the given intrinsic ID if possible, + // otherwise, the original intrinsic ID is returned. + llvm::Intrinsic::ID getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID); + llvm::DenseMap LayoutTypes; unsigned SPIRVLastAssignedInputSemanticLocation = 0; }; diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index 2e2703de18cb1..d3b9af9695016 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -2802,6 +2802,105 @@ __attribute__((convergent)) double3 WaveActiveSum(double3); _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum) __attribute__((convergent)) double4 WaveActiveSum(double4); +//===----------------------------------------------------------------------===// +// WavePrefixSum builtins +//===----------------------------------------------------------------------===// + +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half WavePrefixSum(half); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half2 WavePrefixSum(half2); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half3 WavePrefixSum(half3); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half4 WavePrefixSum(half4); + +#ifdef __HLSL_ENABLE_16_BIT +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t WavePrefixSum(int16_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t2 WavePrefixSum(int16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t3 WavePrefixSum(int16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t4 WavePrefixSum(int16_t4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t WavePrefixSum(uint16_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t2 WavePrefixSum(uint16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t3 WavePrefixSum(uint16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t4 WavePrefixSum(uint16_t4); +#endif + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int WavePrefixSum(int); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int2 WavePrefixSum(int2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int3 WavePrefixSum(int3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int4 WavePrefixSum(int4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint WavePrefixSum(uint); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint2 WavePrefixSum(uint2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint3 WavePrefixSum(uint3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint4 WavePrefixSum(uint4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t WavePrefixSum(int64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t2 WavePrefixSum(int64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t3 WavePrefixSum(int64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t4 WavePrefixSum(int64_t4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t WavePrefixSum(uint64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t2 WavePrefixSum(uint64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t3 WavePrefixSum(uint64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t4 WavePrefixSum(uint64_t4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) float WavePrefixSum(float); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) float2 WavePrefixSum(float2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) float3 WavePrefixSum(float3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) float4 WavePrefixSum(float4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double WavePrefixSum(double); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double2 WavePrefixSum(double2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double3 WavePrefixSum(double3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double4 WavePrefixSum(double4); + //===----------------------------------------------------------------------===// // sign builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index e95fe16e6cb6c..d8abdb9b75d18 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2892,10 +2892,13 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, return false; } -static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { +// Check that the argument is not a bool or vector +// Returns true on error +static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall, + unsigned ArgIndex) { QualType BoolType = S->getASTContext().BoolTy; - assert(TheCall->getNumArgs() >= 1); - QualType ArgType = TheCall->getArg(0)->getType(); + assert(ArgIndex < TheCall->getNumArgs()); + QualType ArgType = TheCall->getArg(ArgIndex)->getType(); auto *VTy = ArgType->getAs(); // is the bool or vector if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) || @@ -2909,6 +2912,18 @@ static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { return false; } +static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { + if (CheckNotBoolScalarOrVector(S, TheCall, 0)) + return true; + return false; +} + +static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) { + if (CheckNotBoolScalarOrVector(S, TheCall, 0)) + return true; + return false; +} + static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { assert(TheCall->getNumArgs() == 3); Expr *Arg1 = TheCall->getArg(1); @@ -3371,6 +3386,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } + case Builtin::BI__builtin_hlsl_wave_prefix_sum: { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + + // Ensure input expr type is a scalar/vector and the same as the return type + if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) + return true; + if (CheckWavePrefix(&SemaRef, TheCall)) + return true; + ExprResult Expr = TheCall->getArg(0); + QualType ArgTyExpr = Expr.get()->getType(); + TheCall->setType(ArgTyExpr); + break; + } case Builtin::BI__builtin_hlsl_elementwise_splitdouble: { if (SemaRef.checkArgCount(TheCall, 3)) return true; diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl new file mode 100644 index 0000000000000..f22aa69ba45d5 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl @@ -0,0 +1,46 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. + +// CHECK-LABEL: test_int +int test_int(int expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i32([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WavePrefixSum(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i32([[TY]]) #[[#attr:]] + +// CHECK-LABEL: test_uint64_t +uint64_t test_uint64_t(uint64_t expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i64([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.usum.i64([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WavePrefixSum(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.usum.i64([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i64([[TY]]) #[[#attr:]] + +// Test basic lowering to runtime function call with array and float value. + +// CHECK-LABEL: test_floatv4 +float4 test_floatv4(float4 expr) { + // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.prefix.sum.v4f32([[TY1]] %[[#]] + // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.prefix.sum.v4f32([[TY1]] %[[#]]) + // CHECK: ret [[TY1]] %[[RET1]] + return WavePrefixSum(expr); +} + +// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.prefix.sum.v4f32([[TY1]]) #[[#attr]] +// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.prefix.sum.v4f32([[TY1]]) #[[#attr]] + +// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} + diff --git a/clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl new file mode 100644 index 0000000000000..1e575c94e67a5 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl @@ -0,0 +1,28 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +int test_too_few_arg() { + return __builtin_hlsl_wave_prefix_sum(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +float2 test_too_many_arg(float2 p0) { + return __builtin_hlsl_wave_prefix_sum(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +bool test_expr_bool_type_check(bool p0) { + return __builtin_hlsl_wave_prefix_sum(p0); + // expected-error@-1 {{invalid operand of type 'bool'}} +} + +bool2 test_expr_bool_vec_type_check(bool2 p0) { + return __builtin_hlsl_wave_prefix_sum(p0); + // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector')}} +} + +struct S { float f; }; + +S test_expr_struct_type_check(S p0) { + return __builtin_hlsl_wave_prefix_sum(p0); + // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index d7db935ee07f1..b613d96275d03 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -164,6 +164,8 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_get_lane_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>; +def int_dx_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; +def int_dx_wave_prefix_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index f39c6cda2c579..9c05cc54809b4 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -129,6 +129,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_get_lane_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>; + def int_spv_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>; def int_spv_group_memory_barrier_with_group_sync diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 67437f6969b27..6075c9ba23931 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -1079,6 +1079,30 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> { let attributes = [Attributes]; } +def WavePrefixOp : DXILOp<121, wavePrefixOp> { + let Doc = "returns partial result of the computation in the corresponding lane"; + let intrinsics = [ + IntrinSelect, IntrinArgI8, + IntrinArgI8 + ]>, + IntrinSelect, IntrinArgI8, + IntrinArgI8 + ]>, + ]; + + let arguments = [OverloadTy, Int8Ty, Int8Ty]; + let result = OverloadTy; + let overloads = [ + Overloads + ]; + let stages = [Stages]; + let attributes = [Attributes]; +} + def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> { let Doc = "returns the float16 stored in the low-half of the uint converted " "to a float"; diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index ce6e8121b9d94..8a2ed48b61557 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -96,6 +96,8 @@ static bool checkWaveOps(Intrinsic::ID IID) { case Intrinsic::dx_wave_reduce_umax: case Intrinsic::dx_wave_reduce_min: case Intrinsic::dx_wave_reduce_umin: + // Wave Prefix Op Variants + case Intrinsic::dx_wave_prefix_sum: return true; } } diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index 6cacbf6564db2..7ae814648cadc 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -59,9 +59,11 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( case Intrinsic::dx_wave_reduce_max: case Intrinsic::dx_wave_reduce_min: case Intrinsic::dx_wave_reduce_sum: + case Intrinsic::dx_wave_prefix_sum: case Intrinsic::dx_wave_reduce_umax: case Intrinsic::dx_wave_reduce_umin: case Intrinsic::dx_wave_reduce_usum: + case Intrinsic::dx_wave_prefix_usum: case Intrinsic::dx_imad: case Intrinsic::dx_umad: return true; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index fc87288a4a212..68e185cdd301f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -225,8 +225,8 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectWaveReduceMin(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsUnsigned) const; - bool selectWaveReduceSum(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const; + bool selectWaveSumWithGroupOp(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, int64_t GroupOp) const; bool selectConst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -2452,7 +2452,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, report_fatal_error("Input Type could not be determined."); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - // Retreive the operation to use based on input type + // Retrieve the operation to use based on input type bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat); auto IntegerOpcodeType = IsUnsigned ? SPIRV::OpGroupNonUniformUMax : SPIRV::OpGroupNonUniformSMax; @@ -2481,7 +2481,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMin(Register ResVReg, report_fatal_error("Input Type could not be determined."); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - // Retreive the operation to use based on input type + // Retrieve the operation to use based on input type bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat); auto IntegerOpcodeType = IsUnsigned ? SPIRV::OpGroupNonUniformUMin : SPIRV::OpGroupNonUniformSMin; @@ -2496,9 +2496,9 @@ bool SPIRVInstructionSelector::selectWaveReduceMin(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } -bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I) const { +bool SPIRVInstructionSelector::selectWaveSumWithGroupOp( + Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + int64_t GroupOp) const { assert(I.getNumOperands() == 3); assert(I.getOperand(2).isReg()); MachineBasicBlock &BB = *I.getParent(); @@ -2509,7 +2509,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, report_fatal_error("Input Type could not be determined."); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - // Retreive the operation to use based on input type + // Retrieve the operation to use based on input type bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat); auto Opcode = IsFloatTy ? SPIRV::OpGroupNonUniformFAdd : SPIRV::OpGroupNonUniformIAdd; @@ -2518,7 +2518,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, .addUse(GR.getSPIRVTypeID(ResType)) .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, !STI.isShader())) - .addImm(SPIRV::GroupOperation::Reduce) + .addImm(GroupOp) .addUse(I.getOperand(2).getReg()); } @@ -3485,10 +3485,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_wave_reduce_min: return selectWaveReduceMin(ResVReg, ResType, I, /*IsUnsigned*/ false); case Intrinsic::spv_wave_reduce_sum: - return selectWaveReduceSum(ResVReg, ResType, I); + return selectWaveSumWithGroupOp(ResVReg, ResType, I, + SPIRV::GroupOperation::Reduce); case Intrinsic::spv_wave_readlane: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformShuffle); + case Intrinsic::spv_wave_prefix_sum: + return selectWaveSumWithGroupOp(ResVReg, ResType, I, + SPIRV::GroupOperation::ExclusiveScan); case Intrinsic::spv_step: return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step); case Intrinsic::spv_radians: diff --git a/llvm/test/CodeGen/DirectX/WavePrefixSum.ll b/llvm/test/CodeGen/DirectX/WavePrefixSum.ll new file mode 100644 index 0000000000000..ed8c2b2b85465 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WavePrefixSum.ll @@ -0,0 +1,143 @@ +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s + +; Test that for scalar values, WavePrefixSum maps down to the DirectX op + +define noundef half @wave_prefix_sum_half(half noundef %expr) { +entry: +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr, i8 0, i8 0) + %ret = call half @llvm.dx.wave.prefix.sum.f16(half %expr) + ret half %ret +} + +define noundef float @wave_prefix_sum_float(float noundef %expr) { +entry: +; CHECK: call float @dx.op.wavePrefixOp.f32(i32 121, float %expr, i8 0, i8 0) + %ret = call float @llvm.dx.wave.prefix.sum.f32(float %expr) + ret float %ret +} + +define noundef double @wave_prefix_sum_double(double noundef %expr) { +entry: +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr, i8 0, i8 0) + %ret = call double @llvm.dx.wave.prefix.sum.f64(double %expr) + ret double %ret +} + +define noundef i16 @wave_prefix_sum_i16(i16 noundef %expr) { +entry: +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr, i8 0, i8 0) + %ret = call i16 @llvm.dx.wave.prefix.sum.i16(i16 %expr) + ret i16 %ret +} + +define noundef i32 @wave_prefix_sum_i32(i32 noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr, i8 0, i8 0) + %ret = call i32 @llvm.dx.wave.prefix.sum.i32(i32 %expr) + ret i32 %ret +} + +define noundef i64 @wave_prefix_sum_i64(i64 noundef %expr) { +entry: +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr, i8 0, i8 0) + %ret = call i64 @llvm.dx.wave.prefix.sum.i64(i64 %expr) + ret i64 %ret +} + +define noundef i16 @wave_prefix_usum_i16(i16 noundef %expr) { +entry: +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr, i8 0, i8 1) + %ret = call i16 @llvm.dx.wave.prefix.usum.i16(i16 %expr) + ret i16 %ret +} + +define noundef i32 @wave_prefix_usum_i32(i32 noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr, i8 0, i8 1) + %ret = call i32 @llvm.dx.wave.prefix.usum.i32(i32 %expr) + ret i32 %ret +} + +define noundef i64 @wave_prefix_usum_i64(i64 noundef %expr) { +entry: +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr, i8 0, i8 1) + %ret = call i64 @llvm.dx.wave.prefix.usum.i64(i64 %expr) + ret i64 %ret +} + +declare half @llvm.dx.wave.prefix.sum.f16(half) +declare float @llvm.dx.wave.prefix.sum.f32(float) +declare double @llvm.dx.wave.prefix.sum.f64(double) + +declare i16 @llvm.dx.wave.prefix.sum.i16(i16) +declare i32 @llvm.dx.wave.prefix.sum.i32(i32) +declare i64 @llvm.dx.wave.prefix.sum.i64(i64) + +declare i16 @llvm.dx.wave.prefix.usum.i16(i16) +declare i32 @llvm.dx.wave.prefix.usum.i32(i32) +declare i64 @llvm.dx.wave.prefix.usum.i64(i64) + +; Test that for vector values, WavePrefixSum scalarizes and maps down to the +; DirectX op + +define noundef <2 x half> @wave_prefix_sum_v2half(<2 x half> noundef %expr) { +entry: +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr.i0, i8 0, i8 0) +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr.i1, i8 0, i8 0) + %ret = call <2 x half> @llvm.dx.wave.prefix.sum.v2f16(<2 x half> %expr) + ret <2 x half> %ret +} + +define noundef <3 x i32> @wave_prefix_sum_v3i32(<3 x i32> noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i0, i8 0, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i1, i8 0, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i2, i8 0, i8 0) + %ret = call <3 x i32> @llvm.dx.wave.prefix.sum.v3i32(<3 x i32> %expr) + ret <3 x i32> %ret +} + +define noundef <4 x double> @wave_prefix_sum_v4f64(<4 x double> noundef %expr) { +entry: +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i0, i8 0, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i1, i8 0, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i2, i8 0, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i3, i8 0, i8 0) + %ret = call <4 x double> @llvm.dx.wave.prefix.sum.v464(<4 x double> %expr) + ret <4 x double> %ret +} + +declare <2 x half> @llvm.dx.wave.prefix.sum.v2f16(<2 x half>) +declare <3 x i32> @llvm.dx.wave.prefix.sum.v3i32(<3 x i32>) +declare <4 x double> @llvm.dx.wave.prefix.sum.v4f64(<4 x double>) + +define noundef <2 x i16> @wave_prefix_usum_v2i16(<2 x i16> noundef %expr) { +entry: +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr.i0, i8 0, i8 1) +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr.i1, i8 0, i8 1) + %ret = call <2 x i16> @llvm.dx.wave.prefix.usum.v2f16(<2 x i16> %expr) + ret <2 x i16> %ret +} + +define noundef <3 x i32> @wave_prefix_usum_v3i32(<3 x i32> noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i0, i8 0, i8 1) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i1, i8 0, i8 1) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i2, i8 0, i8 1) + %ret = call <3 x i32> @llvm.dx.wave.prefix.usum.v3i32(<3 x i32> %expr) + ret <3 x i32> %ret +} + +define noundef <4 x i64> @wave_prefix_usum_v4f64(<4 x i64> noundef %expr) { +entry: +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i0, i8 0, i8 1) +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i1, i8 0, i8 1) +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i2, i8 0, i8 1) +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i3, i8 0, i8 1) + %ret = call <4 x i64> @llvm.dx.wave.prefix.usum.v464(<4 x i64> %expr) + ret <4 x i64> %ret +} + +declare <2 x i16> @llvm.dx.wave.prefix.usum.v2f16(<2 x i16>) +declare <3 x i32> @llvm.dx.wave.prefix.usum.v3i32(<3 x i32>) +declare <4 x i64> @llvm.dx.wave.prefix.usum.v4f64(<4 x i64>) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll new file mode 100644 index 0000000000000..5fb82fd9ebf19 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll @@ -0,0 +1,41 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %} + +; Test lowering to spir-v backend for various types and scalar/vector + +; CHECK-DAG: %[[#f16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v4_half:]] = OpTypeVector %[[#f16]] 4 +; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3 + +; CHECK-LABEL: Begin function test_float +; CHECK: %[[#fexpr:]] = OpFunctionParameter %[[#f32]] +define float @test_float(float %fexpr) { +entry: +; CHECK: %[[#fret:]] = OpGroupNonUniformFAdd %[[#f32]] %[[#scope]] ExclusiveScan %[[#fexpr]] + %0 = call float @llvm.spv.wave.prefix.sum.f32(float %fexpr) + ret float %0 +} + +; CHECK-LABEL: Begin function test_int +; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]] +define i32 @test_int(i32 %iexpr) { +entry: +; CHECK: %[[#iret:]] = OpGroupNonUniformIAdd %[[#uint]] %[[#scope]] ExclusiveScan %[[#iexpr]] + %0 = call i32 @llvm.spv.wave.prefix.sum.i32(i32 %iexpr) + ret i32 %0 +} + +; CHECK-LABEL: Begin function test_vhalf +; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]] +define <4 x half> @test_vhalf(<4 x half> %vbexpr) { +entry: +; CHECK: %[[#vhalfret:]] = OpGroupNonUniformFAdd %[[#v4_half]] %[[#scope]] ExclusiveScan %[[#vbexpr]] + %0 = call <4 x half> @llvm.spv.wave.prefix.sum.v4half(<4 x half> %vbexpr) + ret <4 x half> %0 +} + +declare float @llvm.spv.wave.prefix.sum.f32(float) +declare i32 @llvm.spv.wave.prefix.sum.i32(i32) +declare <4 x half> @llvm.spv.wave.prefix.sum.v4half(<4 x half>)