Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5072,6 +5072,12 @@ def HLSLWaveGetLaneCount : LangBuiltin<"HLSL_LANG"> {
let Prototype = "unsigned int()";
}

def HLSLWavePrefixSum : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_prefix_sum"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,23 @@ static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
}
}

// Return wave prefix sum that corresponds to the QT scalar type
static Intrinsic::ID getWavePrefixSumIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
switch (Arch) {
case llvm::Triple::spirv:
return Intrinsic::spv_wave_prefix_sum;
case llvm::Triple::dxil: {
if (QT->isUnsignedIntegerType())
return Intrinsic::dx_wave_prefix_usum;
return Intrinsic::dx_wave_prefix_sum;
}
default:
llvm_unreachable("Intrinsic WavePrefixSum"
" not supported by target architecture");
}
}

// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec Constant.
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
Expand Down Expand Up @@ -864,6 +881,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
{OpExpr->getType()}),
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
}
case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID IID = getWavePrefixSumIntrinsic(
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
E->getArg(0)->getType());
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
&CGM.getModule(), IID, {OpExpr->getType()}),
ArrayRef{OpExpr}, "hlsl.wave.prefix.sum");
}
case Builtin::BI__builtin_hlsl_elementwise_sign: {
auto *Arg0 = E->getArg(0);
Value *Op0 = EmitScalarExpr(Arg0);
Expand Down
99 changes: 99 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2802,6 +2802,105 @@ __attribute__((convergent)) double3 WaveActiveSum(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) double4 WaveActiveSum(double4);

//===----------------------------------------------------------------------===//
// WavePrefixSum builtins
//===----------------------------------------------------------------------===//

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half WavePrefixSum(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half2 WavePrefixSum(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half3 WavePrefixSum(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half4 WavePrefixSum(half4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t WavePrefixSum(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t2 WavePrefixSum(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t3 WavePrefixSum(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t4 WavePrefixSum(int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t WavePrefixSum(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t2 WavePrefixSum(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t3 WavePrefixSum(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t4 WavePrefixSum(uint16_t4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int WavePrefixSum(int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int2 WavePrefixSum(int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int3 WavePrefixSum(int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int4 WavePrefixSum(int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint WavePrefixSum(uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint2 WavePrefixSum(uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint3 WavePrefixSum(uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint4 WavePrefixSum(uint4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t WavePrefixSum(int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t2 WavePrefixSum(int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t3 WavePrefixSum(int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t4 WavePrefixSum(int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t WavePrefixSum(uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t2 WavePrefixSum(uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t3 WavePrefixSum(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t4 WavePrefixSum(uint64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float WavePrefixSum(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float2 WavePrefixSum(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float3 WavePrefixSum(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float4 WavePrefixSum(float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double WavePrefixSum(double);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double2 WavePrefixSum(double2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double3 WavePrefixSum(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double4 WavePrefixSum(double4);

//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 32 additions & 3 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2892,10 +2892,13 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
return false;
}

static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
// Check that the argument is not a bool or vector<bool>
// Returns true on error
static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
unsigned ArgIndex) {
QualType BoolType = S->getASTContext().BoolTy;
assert(TheCall->getNumArgs() >= 1);
QualType ArgType = TheCall->getArg(0)->getType();
assert(ArgIndex < TheCall->getNumArgs());
QualType ArgType = TheCall->getArg(ArgIndex)->getType();
auto *VTy = ArgType->getAs<VectorType>();
// is the bool or vector<bool>
if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
Expand All @@ -2909,6 +2912,18 @@ static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
return false;
}

static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
if (CheckNotBoolScalarOrVector(S, TheCall, 0))
return true;
return false;
}

static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
if (CheckNotBoolScalarOrVector(S, TheCall, 0))
return true;
return false;
}

static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() == 3);
Expr *Arg1 = TheCall->getArg(1);
Expand Down Expand Up @@ -3371,6 +3386,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;

// Ensure input expr type is a scalar/vector and the same as the return type
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
return true;
if (CheckWavePrefix(&SemaRef, TheCall))
return true;
ExprResult Expr = TheCall->getArg(0);
QualType ArgTyExpr = Expr.get()->getType();
TheCall->setType(ArgTyExpr);
break;
}
case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
Expand Down
46 changes: 46 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// CHECK-LABEL: test_int
int test_int(int expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i32([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i32([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WavePrefixSum(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i32([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i32([[TY]]) #[[#attr:]]

// CHECK-LABEL: test_uint64_t
uint64_t test_uint64_t(uint64_t expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i64([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.usum.i64([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WavePrefixSum(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.usum.i64([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i64([[TY]]) #[[#attr:]]

// Test basic lowering to runtime function call with array and float value.

// CHECK-LABEL: test_floatv4
float4 test_floatv4(float4 expr) {
// CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.prefix.sum.v4f32([[TY1]] %[[#]]
// CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.prefix.sum.v4f32([[TY1]] %[[#]])
// CHECK: ret [[TY1]] %[[RET1]]
return WavePrefixSum(expr);
}

// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.prefix.sum.v4f32([[TY1]]) #[[#attr]]
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.prefix.sum.v4f32([[TY1]]) #[[#attr]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

28 changes: 28 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

int test_too_few_arg() {
return __builtin_hlsl_wave_prefix_sum();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

float2 test_too_many_arg(float2 p0) {
return __builtin_hlsl_wave_prefix_sum(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

bool test_expr_bool_type_check(bool p0) {
return __builtin_hlsl_wave_prefix_sum(p0);
// expected-error@-1 {{invalid operand of type 'bool'}}
}

bool2 test_expr_bool_vec_type_check(bool2 p0) {
return __builtin_hlsl_wave_prefix_sum(p0);
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
}

struct S { float f; };

S test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_prefix_sum(p0);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_get_lane_count
: DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>;
def int_dx_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_prefix_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_get_lane_count
: DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>;
def int_spv_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_spv_group_memory_barrier_with_group_sync
Expand Down
24 changes: 24 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,30 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}

def WavePrefixOp : DXILOp<121, wavePrefixOp> {
let Doc = "returns partial result of the computation in the corresponding lane";
let intrinsics = [
IntrinSelect<int_dx_wave_prefix_sum,
[
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>,
IntrinArgI8<SignedOpKind_Signed>
]>,
IntrinSelect<int_dx_wave_prefix_usum,
[
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>,
IntrinArgI8<SignedOpKind_Unsigned>
]>,
];

let arguments = [OverloadTy, Int8Ty, Int8Ty];
let result = OverloadTy;
let overloads = [
Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int16Ty, Int32Ty, Int64Ty]>
];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, []>];
}

def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
let Doc = "returns the float16 stored in the low-half of the uint converted "
"to a float";
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ static bool checkWaveOps(Intrinsic::ID IID) {
case Intrinsic::dx_wave_reduce_umax:
case Intrinsic::dx_wave_reduce_min:
case Intrinsic::dx_wave_reduce_umin:
// Wave Prefix Op Variants
case Intrinsic::dx_wave_prefix_sum:
return true;
}
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_min:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_prefix_sum:
case Intrinsic::dx_wave_reduce_umax:
case Intrinsic::dx_wave_reduce_umin:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_prefix_usum:
case Intrinsic::dx_imad:
case Intrinsic::dx_umad:
return true;
Expand Down
Loading