Skip to content

Commit

Permalink
start of lerp intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
farzonl committed Feb 28, 2024
1 parent 777ac46 commit 56240f0
Show file tree
Hide file tree
Showing 10 changed files with 381 additions and 41 deletions.
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4536,6 +4536,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

// Builtin record for `__builtin_hlsl_lerp`, registered as a LangBuiltin for
// HLSL_LANG. The prototype is the variadic "void(...)", so this record does
// not describe the argument or result types; the lerp case in
// Sema::CheckHLSLBuiltinFunctionCall requires exactly three arguments.
def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_lerp"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -10266,6 +10266,11 @@ def err_block_on_vm : Error<
def err_sizeless_nonlocal : Error<
"non-local variable with sizeless type %0">;

// Diagnostics for builtins whose arguments must all be vectors of one type.
// Both are emitted by CheckVectorElementCallArgs in SemaChecking.cpp, which
// streams the call's direct callee into %0.
def err_vec_builtin_non_vector_all : Error<
"all arguments to %0 must be vectors">;
def err_vec_builtin_incompatible_vector_all : Error<
"all arguments to %0 must have vectors of the same type">;

def err_vec_builtin_non_vector : Error<
"first two arguments to %0 must be vectors">;
def err_vec_builtin_incompatible_vector : Error<
Expand Down
40 changes: 40 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18007,6 +18007,46 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
  // Lowers lerp(x, y, s), i.e. x + s * (y - x).
  //  - All-scalar operands are expanded inline into sub/mul/add (floating
  //    point or integer flavour, chosen from the type of x).
  //  - All-vector operands are forwarded to the dx.lerp intrinsic.
  // Mixed scalar/vector operands are expected to have been splatted before
  // reaching codegen, so they are only asserted against here.
  Value *X = EmitScalarExpr(E->getArg(0));
  Value *Y = EmitScalarExpr(E->getArg(1));
  Value *S = EmitScalarExpr(E->getArg(2));
  llvm::Type *Xty = X->getType();
  llvm::Type *Yty = Y->getType();
  llvm::Type *Sty = S->getType();
  if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
    if (Xty->isFloatingPointTy()) {
      Value *V = Builder.CreateFSub(Y, X);
      V = Builder.CreateFMul(S, V);
      return Builder.CreateFAdd(X, V, "dx.lerp");
    }
    // DXC does this via casting to float should we do the same thing?
    if (Xty->isIntegerTy()) {
      Value *V = Builder.CreateSub(Y, X);
      V = Builder.CreateMul(S, V);
      return Builder.CreateAdd(X, V, "dx.lerp");
    }
    // Bools should have been promoted
    llvm_unreachable("Scalar Lerp is only supported on ints and floats.");
  }
  // A VectorSplat should have happened
  assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
         "Lerp of vector and scalar is not supported.");

  // Only consulted by the assert below, hence [[maybe_unused]] for builds
  // where asserts are compiled out.
  [[maybe_unused]] auto *XVecTy =
      E->getArg(0)->getType()->getAs<VectorType>();
  [[maybe_unused]] auto *YVecTy =
      E->getArg(1)->getType()->getAs<VectorType>();
  [[maybe_unused]] auto *SVecTy =
      E->getArg(2)->getType()->getAs<VectorType>();
  // A HLSLVectorTruncation should have happened.
  // Compare S against X explicitly: a bare `SVecTy->getNumElements()` would
  // only test that S has a non-zero element count.
  assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
         XVecTy->getNumElements() == SVecTy->getNumElements() &&
         "Lerp requires vectors to be of the same size.");
  return Builder.CreateIntrinsic(
      /*ReturnType*/ Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
      nullptr, "dx.lerp");
}
}
return nullptr;
}
Expand Down
36 changes: 36 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,42 @@ double3 floor(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_floor)
double4 floor(double4);

//===----------------------------------------------------------------------===//
// lerp builtins
//===----------------------------------------------------------------------===//

/// \fn T lerp(T x, T y, T s)
/// \brief Returns the linear interpolation of x to y by s.
/// \param x [in] The first floating-point value.
/// \param y [in] The second floating-point value.
/// \param s [in] A value that linearly interpolates between the x parameter and
/// the y parameter.
///
/// Linear interpolation is based on the following formula: x*(1-s) + y*s which
/// can equivalently be written as x + s*(y-x).
///
/// Overloads are declared below for half and float, each as a scalar and as
/// 2-, 3- and 4-element vectors. Every overload is declared through
/// _HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp); the half overloads additionally
/// carry _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2).

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
half lerp(half, half, half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
half2 lerp(half2, half2, half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
half3 lerp(half3, half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
half4 lerp(half4, half4, half4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
float lerp(float, float, float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
float2 lerp(float2, float2, float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
float3 lerp(float3, float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
float4 lerp(float4, float4, float4);

//===----------------------------------------------------------------------===//
// log builtins
//===----------------------------------------------------------------------===//
Expand Down
69 changes: 42 additions & 27 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5191,43 +5191,49 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);

QualType ArgTyA = A.get()->getType();
QualType ArgTyB = B.get()->getType();

auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
if (VecTyA == nullptr && VecTyB == nullptr)
return false;

if (VecTyA && VecTyB) {
bool retValue = false;
if (VecTyA->getElementType() != VecTyB->getElementType()) {
// Note: type promotion is intended to be handled via the intrinsics
// and not the builtin itself.
S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
retValue = true;
}
if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
// if we get here a HLSLVectorTruncation is needed.
S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
retValue = true;
}
for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
ExprResult B = TheCall->getArg(i);
QualType ArgTyB = B.get()->getType();
auto *VecTyB = ArgTyB->getAs<VectorType>();
if (VecTyA == nullptr && VecTyB == nullptr)
return false;

if (VecTyA && VecTyB) {
bool retValue = false;
if (VecTyA->getElementType() != VecTyB->getElementType()) {
// Note: type promotion is intended to be handled via the intrinsics
// and not the builtin itself.
S->Diag(TheCall->getBeginLoc(),
diag::err_vec_builtin_incompatible_vector_all)
<< TheCall->getDirectCallee()
<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
retValue = true;
}
if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
// if we get here a HLSLVectorTruncation is needed.
S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector_all)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
retValue = true;
}

if (retValue)
TheCall->setType(VecTyA->getElementType());
if (!retValue)
TheCall->setType(VecTyA->getElementType());

return retValue;
return retValue;
}
}

// Note: if we get here one of the args is a scalar which
// requires a VectorSplat on Arg0 or Arg1
S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector_all)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
Expand All @@ -5247,6 +5253,15 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_lerp: {
  // Validate a lerp call: exactly three arguments, then the shared vector
  // argument check, then the elementwise ternary math check. The checks run
  // in this order and evaluation stops at the first one that reports failure
  // (returns true).
  const bool Invalid = checkArgCount(*this, TheCall, 3) ||
                       CheckVectorElementCallArgs(this, TheCall) ||
                       SemaBuiltinElementwiseTernaryMath(TheCall);
  if (Invalid)
    return true;
  break;
}
}
return false;
}
Expand Down
37 changes: 37 additions & 0 deletions clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s



// Calls the builtin directly with three half scalars. The expected IR does
// the fsub/fmul/fadd expansion on double values and truncates the result
// back to half.
// NOTE(review): the double arithmetic presumably comes from default argument
// promotion through the builtin's variadic prototype -- confirm.
// CHECK-LABEL: builtin_lerp_half_scalar
// CHECK: %3 = fsub double %conv1, %conv
// CHECK: %4 = fmul double %conv2, %3
// CHECK: %dx.lerp = fadd double %conv, %4
// CHECK: %conv3 = fptrunc double %dx.lerp to half
// CHECK: ret half %conv3
half builtin_lerp_half_scalar (half p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}

// Calls the builtin directly with three float scalars. As in the half case,
// the expected IR computes the expansion on double values and truncates the
// result back to float.
// NOTE(review): presumably default argument promotion again -- confirm.
// CHECK-LABEL: builtin_lerp_float_scalar
// CHECK: %3 = fsub double %conv1, %conv
// CHECK: %4 = fmul double %conv2, %3
// CHECK: %dx.lerp = fadd double %conv, %4
// CHECK: %conv3 = fptrunc double %dx.lerp to float
// CHECK: ret float %conv3
float builtin_lerp_float_scalar ( float p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}

// Calls the builtin directly with half3 operands; the expected IR is a single
// call to the llvm.dx.lerp intrinsic overloaded on <3 x half>, with no
// inline arithmetic expansion.
// CHECK-LABEL: builtin_lerp_half_vector
// CHECK: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
// CHECK: ret <3 x half> %dx.lerp
half3 builtin_lerp_half_vector (half3 p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}

// Calls the builtin directly with float2 operands; the expected IR is a
// single call to the llvm.dx.lerp intrinsic overloaded on <2 x float>.
// NOTE(review): "floar" looks like a typo for "float"; a rename has to change
// the label directive and the function name together.
// CHECK-LABEL: builtin_lerp_floar_vector
// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
// CHECK: ret <2 x float> %dx.lerp
float2 builtin_lerp_floar_vector ( float2 p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}
105 changes: 105 additions & 0 deletions clang/test/CodeGenHLSL/builtins/lerp.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF

// Scalar half lerp through the hlsl_intrinsics.h overload. The expected IR is
// the inline fsub/fmul/fadd expansion: on half in the -fnative-half-type run,
// and on float in the run without that flag.
// NATIVE_HALF: %3 = fsub half %1, %0
// NATIVE_HALF: %4 = fmul half %2, %3
// NATIVE_HALF: %dx.lerp = fadd half %0, %4
// NATIVE_HALF: ret half %dx.lerp
// NO_HALF: %3 = fsub float %1, %0
// NO_HALF: %4 = fmul float %2, %3
// NO_HALF: %dx.lerp = fadd float %0, %4
// NO_HALF: ret float %dx.lerp
half test_lerp_half ( half p0) {
return lerp ( p0, p0, p0 );
}

// half2 lerp: expects one llvm.dx.lerp call, on <2 x half> in the native-half
// run and on <2 x float> otherwise.
// NOTE(review): p1 is never used; all three operands are p0.
// NATIVE_HALF: %dx.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
// NATIVE_HALF: ret <2 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
// NO_HALF: ret <2 x float> %dx.lerp
half2 test_lerp_half2 ( half2 p0, half2 p1 ) {
return lerp ( p0, p0, p0 );
}

// half3 lerp: expects one llvm.dx.lerp call, on <3 x half> in the native-half
// run and on <3 x float> otherwise.
// NOTE(review): p1 is never used; all three operands are p0.
// NATIVE_HALF: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
// NATIVE_HALF: ret <3 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
// NO_HALF: ret <3 x float> %dx.lerp
half3 test_lerp_half3 ( half3 p0, half3 p1 ) {
return lerp ( p0, p0, p0 );
}

// half4 lerp: expects one llvm.dx.lerp call, on <4 x half> in the native-half
// run and on <4 x float> otherwise.
// NOTE(review): p1 is never used; all three operands are p0.
// NATIVE_HALF: %dx.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
// NATIVE_HALF: ret <4 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
// NO_HALF: ret <4 x float> %dx.lerp
half4 test_lerp_half4 ( half4 p0, half4 p1 ) {
return lerp ( p0, p0, p0 );
}

// Scalar float lerp: both runs expect the inline fsub/fmul/fadd expansion on
// float rather than an intrinsic call.
// NOTE(review): p1 is never used; all three operands are p0.
// CHECK: %3 = fsub float %1, %0
// CHECK: %4 = fmul float %2, %3
// CHECK: %dx.lerp = fadd float %0, %4
// CHECK: ret float %dx.lerp
float test_lerp_float ( float p0, float p1 ) {
return lerp ( p0, p0, p0 );
}

// float2 lerp: expects a single llvm.dx.lerp call on <2 x float>.
// NOTE(review): p1 is never used; all three operands are p0.
// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
// CHECK: ret <2 x float> %dx.lerp
float2 test_lerp_float2 ( float2 p0, float2 p1 ) {
return lerp ( p0, p0, p0 );
}

// float3 lerp: expects a single llvm.dx.lerp call on <3 x float>.
// NOTE(review): p1 is never used; all three operands are p0.
// CHECK: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
// CHECK: ret <3 x float> %dx.lerp
float3 test_lerp_float3 ( float3 p0, float3 p1 ) {
return lerp ( p0, p0, p0 );
}

// float4 lerp: expects a single llvm.dx.lerp call on <4 x float>.
// NOTE(review): p1 is never used; all three operands are p0.
// CHECK: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
// CHECK: ret <4 x float> %dx.lerp
float4 test_lerp_float4 ( float4 p0, float4 p1) {
return lerp ( p0, p0, p0 );
}

// Mixed scalar/vector call: the scalar first operand is expected to reach the
// intrinsic as a splatted <2 x float> (%splat.splat), alongside the two
// vector operands.
// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
// CHECK: ret <2 x float> %dx.lerp
float2 test_lerp_float2_splat ( float p0, float2 p1 ) {
return lerp( p0, p1, p1 );
}

// Mixed scalar/vector call: the scalar first operand is expected to reach the
// intrinsic as a splatted <3 x float> (%splat.splat).
// CHECK: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
// CHECK: ret <3 x float> %dx.lerp
float3 test_lerp_float3_splat ( float p0, float3 p1 ) {
return lerp( p0, p1, p1 );
}

// Mixed scalar/vector call: the scalar first operand is expected to reach the
// intrinsic as a splatted <4 x float> (%splat.splat).
// CHECK: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
// CHECK: ret <4 x float> %dx.lerp
float4 test_lerp_float4_splat ( float p0, float4 p1 ) {
return lerp( p0, p1, p1 );
}

// int third operand with float2 vectors: the expected IR converts the int to
// float (sitofp), splats it to <2 x float> via insertelement + shufflevector,
// and passes the splat as the interpolation operand.
// CHECK: %conv = sitofp i32 %2 to float
// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
// CHECK: ret <2 x float> %dx.lerp
float2 test_lerp_float2_int_splat ( float2 p0, int p1 ) {
return lerp ( p0, p0, p1 );
}

// int third operand with float3 vectors: the expected IR converts the int to
// float (sitofp), splats it to <3 x float> via insertelement + shufflevector,
// and passes the splat as the interpolation operand.
// CHECK: %conv = sitofp i32 %2 to float
// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
// CHECK: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
// CHECK: ret <3 x float> %dx.lerp
float3 test_lerp_float3_int_splat ( float3 p0, int p1 ) {
return lerp ( p0, p0, p1 );
}

0 comments on commit 56240f0

Please sign in to comment.