Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5011,6 +5011,12 @@ def HLSLWaveActiveMax : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void (...)";
}

def HLSLWaveActiveMin : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_min"];
let Attributes = [NoThrow, Const];
let Prototype = "void (...)";
}

def HLSLWaveActiveSum : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_sum"];
let Attributes = [NoThrow, Const];
Expand Down
32 changes: 31 additions & 1 deletion clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
}
}

// Return wave active sum that corresponds to the QT scalar type
// Return wave active max that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
switch (Arch) {
Expand All @@ -215,6 +215,25 @@ static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
}
}

// Return wave active min that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
switch (Arch) {
case llvm::Triple::spirv:
if (QT->isUnsignedIntegerType())
return Intrinsic::spv_wave_reduce_umin;
return Intrinsic::spv_wave_reduce_min;
case llvm::Triple::dxil: {
if (QT->isUnsignedIntegerType())
return Intrinsic::dx_wave_reduce_umin;
return Intrinsic::dx_wave_reduce_min;
}
default:
llvm_unreachable("Intrinsic WaveActiveMin"
" not supported by target architecture");
}
}

// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec Constant.
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
Expand Down Expand Up @@ -719,6 +738,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
&CGM.getModule(), IID, {OpExpr->getType()}),
ArrayRef{OpExpr}, "hlsl.wave.active.max");
}
case Builtin::BI__builtin_hlsl_wave_active_min: {
// Due to the use of variadic arguments, explicitly retreive argument
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID IID = getWaveActiveMinIntrinsic(
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
E->getArg(0)->getType());

return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
&CGM.getModule(), IID, {OpExpr->getType()}),
ArrayRef{OpExpr}, "hlsl.wave.active.min");
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
// defined in SPIRVBuiltins.td. So instead we manually get the matching name
Expand Down
123 changes: 123 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2597,6 +2597,129 @@ __attribute__((convergent)) double3 WaveActiveMax(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_max)
__attribute__((convergent)) double4 WaveActiveMax(double4);

//===----------------------------------------------------------------------===//
// WaveActiveMin builtins
//===----------------------------------------------------------------------===//

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) half WaveActiveMin(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) half2 WaveActiveMin(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) half3 WaveActiveMin(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) half4 WaveActiveMin(half4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int16_t WaveActiveMin(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int16_t2 WaveActiveMin(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int16_t3 WaveActiveMin(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int16_t4 WaveActiveMin(int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint16_t WaveActiveMin(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint16_t2 WaveActiveMin(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint16_t3 WaveActiveMin(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint16_t4 WaveActiveMin(uint16_t4);
#endif

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int WaveActiveMin(int);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int2 WaveActiveMin(int2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int3 WaveActiveMin(int3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int4 WaveActiveMin(int4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint WaveActiveMin(uint);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint2 WaveActiveMin(uint2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint3 WaveActiveMin(uint3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint4 WaveActiveMin(uint4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int64_t WaveActiveMin(int64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int64_t2 WaveActiveMin(int64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int64_t3 WaveActiveMin(int64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) int64_t4 WaveActiveMin(int64_t4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint64_t WaveActiveMin(uint64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint64_t2 WaveActiveMin(uint64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint64_t3 WaveActiveMin(uint64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) uint64_t4 WaveActiveMin(uint64_t4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) float WaveActiveMin(float);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) float2 WaveActiveMin(float2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) float3 WaveActiveMin(float3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) float4 WaveActiveMin(float4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) double WaveActiveMin(double);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) double2 WaveActiveMin(double2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) double3 WaveActiveMin(double3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_min)
__attribute__((convergent)) double4 WaveActiveMin(double4);

//===----------------------------------------------------------------------===//
// WaveActiveSum builtins
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3197,6 +3197,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
break;
}
case Builtin::BI__builtin_hlsl_wave_active_max:
case Builtin::BI__builtin_hlsl_wave_active_min:
case Builtin::BI__builtin_hlsl_wave_active_sum: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
Expand Down
46 changes: 46 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveMin.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// CHECK-LABEL: test_int
int test_int(int expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.min.i32([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.min.i32([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveMin(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.min.i32([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.min.i32([[TY]]) #[[#attr:]]

// CHECK-LABEL: test_uint64_t
uint64_t test_uint64_t(uint64_t expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.umin.i64([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.umin.i64([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveMin(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.umin.i64([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.umin.i64([[TY]]) #[[#attr:]]

// Test basic lowering to runtime function call with array and float value.

// CHECK-LABEL: test_floatv4
float4 test_floatv4(float4 expr) {
// CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.reduce.min.v4f32([[TY1]] %[[#]]
// CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.reduce.min.v4f32([[TY1]] %[[#]])
// CHECK: ret [[TY1]] %[[RET1]]
return WaveActiveMin(expr);
}

// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.min.v4f32([[TY1]]) #[[#attr]]
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.min.v4f32([[TY1]]) #[[#attr]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

29 changes: 29 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveMin.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

int test_too_few_arg() {
return __builtin_hlsl_wave_active_min();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

float2 test_too_many_arg(float2 p0) {
return __builtin_hlsl_wave_active_min(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

bool test_expr_bool_type_check(bool p0) {
return __builtin_hlsl_wave_active_min(p0);
// expected-error@-1 {{invalid operand of type 'bool'}}
}

bool2 test_expr_bool_vec_type_check(bool2 p0) {
return __builtin_hlsl_wave_active_min(p0);
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
}

struct S { float f; };

S test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_active_min(p0);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}

2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrCon
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umin : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umin : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
Expand All @@ -136,7 +138,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_sclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_spv_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;

// Create resource handle given the binding information. Returns a
// Create resource handle given the binding information. Returns a
// type appropriate for the kind of resource given the set id, binding id,
// array size of the binding, as well as an index and an indicator
// whether that index may be non-uniform.
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,16 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Max>,
IntrinArgI8<SignedOpKind_Unsigned>
]>,
IntrinSelect<int_dx_wave_reduce_min,
[
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Min>,
IntrinArgI8<SignedOpKind_Signed>
]>,
IntrinSelect<int_dx_wave_reduce_umin,
[
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Min>,
IntrinArgI8<SignedOpKind_Unsigned>
]>,
];

let arguments = [OverloadTy, Int8Ty, Int8Ty];
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ static bool checkWaveOps(Intrinsic::ID IID) {
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_umax:
case Intrinsic::dx_wave_reduce_min:
case Intrinsic::dx_wave_reduce_umin:
return true;
}
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_splitdouble:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_min:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_umax:
case Intrinsic::dx_wave_reduce_umin:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_imad:
case Intrinsic::dx_umad:
Expand Down
Loading