Skip to content

Commit

Permalink
[DXIL] exp, any, lerp, & rcp Intrinsic Lowering (#84526)
Browse files Browse the repository at this point in the history
This change implements lowering for #70076, #70100, #70072, & #70102 
`CGBuiltin.cpp` - - simplify `lerp` intrinsic
`IntrinsicsDirectX.td` - simplify `lerp` intrinsic
`SemaChecking.cpp` - remove unnecessary check
`DXILIntrinsicExpansion.*` - add intrinsic to instruction expansion
cases
`DXILOpLowering.cpp` - make sure `DXILIntrinsicExpansion` happens first
`DirectX.h` - changes to support new pass
`DirectXTargetMachine.cpp` - changes to support new pass

Why `any`, and `lerp` as instruction expansion just for DXIL?
- SPIR-V there is an
[OpAny](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpAny)
- SPIR-V has a GLSL lerp extension via
[Fmix](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#FMix)

Why `exp` instruction expansion?
- We have an `exp2` opcode and `exp` reuses that opcode. So instruction
expansion is a convenient way to do preprocessing.
- Further SPIR-V has a GLSL exp extension via
[Exp](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp)
and
[Exp2](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp2)

Why `rcp` as instruction expansion?
This one is a bit of the odd man out and might have to move to
`cgbuiltins` when we better understand SPIRV requirements. However I
included it because it seems like [fast math mode has an AllowRecip
flag](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_fp_fast_math_mode)
which lets you compute the reciprocal without performing the division.
We don't have that in DXIL so thought to include it.
  • Loading branch information
farzonl committed Mar 15, 2024
1 parent 58ef9be commit de1a97d
Show file tree
Hide file tree
Showing 22 changed files with 610 additions and 90 deletions.
5 changes: 5 additions & 0 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
bool isFloatingType() const; // C99 6.2.5p11 (real floating + complex)
bool isHalfType() const; // OpenCL 6.1.1.1, NEON (IEEE 754-2008 half)
bool isFloat16Type() const; // C11 extension ISO/IEC TS 18661
bool isFloat32Type() const;
bool isBFloat16Type() const;
bool isFloat128Type() const;
bool isIbm128Type() const;
Expand Down Expand Up @@ -7452,6 +7453,10 @@ inline bool Type::isFloat16Type() const {
return isSpecificBuiltinType(BuiltinType::Float16);
}

inline bool Type::isFloat32Type() const {
return isSpecificBuiltinType(BuiltinType::Float);
}

inline bool Type::isBFloat16Type() const {
return isSpecificBuiltinType(BuiltinType::BFloat16);
}
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4598,7 +4598,7 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {

def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

Expand Down
35 changes: 4 additions & 31 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18021,38 +18021,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
Value *S = EmitScalarExpr(E->getArg(2));
llvm::Type *Xty = X->getType();
llvm::Type *Yty = Y->getType();
llvm::Type *Sty = S->getType();
if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
if (Xty->isFloatingPointTy()) {
auto V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
return Builder.CreateFAdd(X, V, "dx.lerp");
}
llvm_unreachable("Scalar Lerp is only supported on floats.");
}
// A VectorSplat should have happened
assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
"Lerp of vector and scalar is not supported.");

[[maybe_unused]] auto *XVecTy =
E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *YVecTy =
E->getArg(1)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *SVecTy =
E->getArg(2)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
XVecTy->getNumElements() == SVecTy->getNumElements() &&
"Lerp requires vectors to be of the same size.");
assert(XVecTy->getElementType()->isRealFloatingType() &&
XVecTy->getElementType() == YVecTy->getElementType() &&
XVecTy->getElementType() == SVecTy->getElementType() &&
"Lerp requires float vectors to be of the same type.");
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
llvm_unreachable("lerp operand must have a float representation");
return Builder.CreateIntrinsic(
/*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
nullptr, "dx.lerp");
/*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
}
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Expand Down
51 changes: 37 additions & 14 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5234,10 +5234,6 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
TheCall->getArg(1)->getEndLoc());
retValue = true;
}

if (!retValue)
TheCall->setType(VecTyA->getElementType());

return retValue;
}
}
Expand All @@ -5251,11 +5247,12 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}

bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
QualType ExpectedType = S->Context.FloatTy;
bool CheckArgsTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
QualType PassedType = TheCall->getArg(i)->getType();
if (!PassedType->hasFloatingRepresentation()) {
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
Expand All @@ -5268,6 +5265,26 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
return false;
}

bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkAllFloatTypes);
}

bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->getAs<clang::VectorType>()->getElementType()
: PassedType;
return !BaseType->isHalfType() && !BaseType->isFloat32Type();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkFloatorHalf);
}

void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
QualType ReturnType) {
auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
Expand Down Expand Up @@ -5295,21 +5312,27 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_isinf: {
if (checkArgCount(*this, TheCall, 1))
return true;
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
return true;
SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
case Builtin::BI__builtin_hlsl_elementwise_rcp:
case Builtin::BI__builtin_hlsl_elementwise_frac: {
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
if (CheckFloatOrHalfRepresentations(this, TheCall))
return true;
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_isinf: {
if (CheckFloatOrHalfRepresentations(this, TheCall))
return true;
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
break;
}
case Builtin::BI__builtin_hlsl_lerp: {
Expand All @@ -5319,7 +5342,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
if (SemaBuiltinElementwiseTernaryMath(TheCall))
return true;
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
if (CheckFloatOrHalfRepresentations(this, TheCall))
return true;
break;
}
Expand Down
22 changes: 0 additions & 22 deletions clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl
Original file line number Diff line number Diff line change
@@ -1,27 +1,5 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s



// CHECK-LABEL: builtin_lerp_half_scalar
// CHECK: %3 = fsub double %conv1, %conv
// CHECK: %4 = fmul double %conv2, %3
// CHECK: %dx.lerp = fadd double %conv, %4
// CHECK: %conv3 = fptrunc double %dx.lerp to half
// CHECK: ret half %conv3
half builtin_lerp_half_scalar (half p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}

// CHECK-LABEL: builtin_lerp_float_scalar
// CHECK: %3 = fsub double %conv1, %conv
// CHECK: %4 = fmul double %conv2, %3
// CHECK: %dx.lerp = fadd double %conv, %4
// CHECK: %conv3 = fptrunc double %dx.lerp to float
// CHECK: ret float %conv3
float builtin_lerp_float_scalar ( float p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}

// CHECK-LABEL: builtin_lerp_half_vector
// CHECK: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
// CHECK: ret <3 x half> %dx.lerp
Expand Down
27 changes: 11 additions & 16 deletions clang/test/CodeGenHLSL/builtins/lerp.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,46 @@
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF

// NATIVE_HALF: %3 = fsub half %1, %0
// NATIVE_HALF: %4 = fmul half %2, %3
// NATIVE_HALF: %dx.lerp = fadd half %0, %4

// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
// NATIVE_HALF: ret half %dx.lerp
// NO_HALF: %3 = fsub float %1, %0
// NO_HALF: %4 = fmul float %2, %3
// NO_HALF: %dx.lerp = fadd float %0, %4
// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
// NO_HALF: ret float %dx.lerp
half test_lerp_half(half p0) { return lerp(p0, p0, p0); }

// NATIVE_HALF: %dx.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
// NATIVE_HALF: ret <2 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
// NO_HALF: ret <2 x float> %dx.lerp
half2 test_lerp_half2(half2 p0, half2 p1) { return lerp(p0, p0, p0); }
half2 test_lerp_half2(half2 p0) { return lerp(p0, p0, p0); }

// NATIVE_HALF: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
// NATIVE_HALF: ret <3 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
// NO_HALF: ret <3 x float> %dx.lerp
half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
half3 test_lerp_half3(half3 p0) { return lerp(p0, p0, p0); }

// NATIVE_HALF: %dx.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
// NATIVE_HALF: ret <4 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
// NO_HALF: ret <4 x float> %dx.lerp
half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
half4 test_lerp_half4(half4 p0) { return lerp(p0, p0, p0); }

// CHECK: %3 = fsub float %1, %0
// CHECK: %4 = fmul float %2, %3
// CHECK: %dx.lerp = fadd float %0, %4
// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
// CHECK: ret float %dx.lerp
float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
float test_lerp_float(float p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
// CHECK: ret <2 x float> %dx.lerp
float2 test_lerp_float2(float2 p0, float2 p1) { return lerp(p0, p0, p0); }
float2 test_lerp_float2(float2 p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
// CHECK: ret <3 x float> %dx.lerp
float3 test_lerp_float3(float3 p0, float3 p1) { return lerp(p0, p0, p0); }
float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
// CHECK: ret <4 x float> %dx.lerp
float4 test_lerp_float4(float4 p0, float4 p1) { return lerp(p0, p0, p0); }
float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
// CHECK: ret <2 x float> %dx.lerp
Expand Down
12 changes: 12 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,15 @@ float2 builtin_frac_int2_to_float2_promotion(int2 p1) {
return __builtin_hlsl_elementwise_frac(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}

// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
half builtin_frac_half_scalar (half p0) {
return __builtin_hlsl_elementwise_frac (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_frac_float_scalar ( float p0) {
return __builtin_hlsl_elementwise_frac (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

11 changes: 11 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,14 @@ bool2 builtin_isinf_int2_to_float2_promotion(int2 p1) {
return __builtin_hlsl_elementwise_isinf(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}

// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
half builtin_isinf_half_scalar (half p0) {
return __builtin_hlsl_elementwise_isinf (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_isinf_float_scalar ( float p0) {
return __builtin_hlsl_elementwise_isinf (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}
17 changes: 15 additions & 2 deletions clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,18 @@ float builtin_lerp_int_to_float_promotion(float p0, int p1) {

float4 test_lerp_int4(int4 p0, int4 p1, int4 p2) {
return __builtin_hlsl_lerp(p0, p1, p2);
// expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
}
// expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
}

// note: DefaultVariadicArgumentPromotion --> DefaultArgumentPromotion has already promoted to double
// we don't know anymore that the input was half when __builtin_hlsl_lerp is called so we default to float
// for expected type
half builtin_lerp_half_scalar (half p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_lerp_float_scalar ( float p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}
11 changes: 11 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,14 @@ float2 builtin_rsqrt_int2_to_float2_promotion(int2 p1) {
return __builtin_hlsl_elementwise_rsqrt(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}

// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
half builtin_rsqrt_half_scalar (half p0) {
return __builtin_hlsl_elementwise_rsqrt (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_rsqrt_float_scalar ( float p0) {
return __builtin_hlsl_elementwise_rsqrt (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}
4 changes: 1 addition & 3 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def int_dx_isinf :
DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
[llvm_anyfloat_ty]>;

def int_dx_lerp :
Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;

def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
DirectXSubtarget.cpp
DirectXTargetMachine.cpp
DXContainerGlobals.cpp
DXILIntrinsicExpansion.cpp
DXILMetadata.cpp
DXILOpBuilder.cpp
DXILOpLowering.cpp
Expand Down

0 comments on commit de1a97d

Please sign in to comment.