diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index df74026c5d2d5..e3432f7925ba1 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> { let Prototype = "void*(unsigned char)"; } +def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_dot"]; + let Attributes = [NoThrow, Const]; + let Prototype = "void(...)"; +} + // Builtins for XRay. def XRayCustomEvent : Builtin { let Spellings = ["__xray_customevent"]; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index c966a99c51968..ef4b93fac95ce 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -14063,6 +14063,7 @@ class Sema final { bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, CallExpr *TheCall); bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall); + bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall); bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum); bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, CallExpr *TheCall); @@ -14128,6 +14129,8 @@ class Sema final { bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc); + bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res); + bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall); bool SemaBuiltinElementwiseMath(CallExpr *TheCall); bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall); bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall); diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 734eb5a035ca4..54d7451a9d622 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -44,6 +44,7 @@ #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/IntrinsicsBPF.h" +#include "llvm/IR/IntrinsicsDirectX.h" 
#include "llvm/IR/IntrinsicsHexagon.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/IntrinsicsPowerPC.h" @@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr"); } + // EmitHLSLBuiltinExpr will check getLangOpts().HLSL + if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E)) + return RValue::get(V); + if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice) return EmitHipStdParUnsupportedBuiltin(this, FD); @@ -17959,6 +17964,52 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments, return Arg; } +Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, + const CallExpr *E) { + if (!getLangOpts().HLSL) + return nullptr; + + switch (BuiltinID) { + case Builtin::BI__builtin_hlsl_dot: { + Value *Op0 = EmitScalarExpr(E->getArg(0)); + Value *Op1 = EmitScalarExpr(E->getArg(1)); + llvm::Type *T0 = Op0->getType(); + llvm::Type *T1 = Op1->getType(); + if (!T0->isVectorTy() && !T1->isVectorTy()) { + if (T0->isFloatingPointTy()) + return Builder.CreateFMul(Op0, Op1, "dx.dot"); + + if (T0->isIntegerTy()) + return Builder.CreateMul(Op0, Op1, "dx.dot"); + + // Bools should have been promoted + llvm_unreachable( + "Scalar dot product is only supported on ints and floats."); + } + // A VectorSplat should have happened + assert(T0->isVectorTy() && T1->isVectorTy() && + "Dot product of vector and scalar is not supported."); + + // A vector sext or sitofp should have happened + assert(T0->getScalarType() == T1->getScalarType() && + "Dot product of vectors need the same element types."); + + [[maybe_unused]] auto *VecTy0 = + E->getArg(0)->getType()->getAs(); + [[maybe_unused]] auto *VecTy1 = + E->getArg(1)->getType()->getAs(); + // A HLSLVectorTruncation should have happened + assert(VecTy0->getNumElements() == VecTy1->getNumElements() && + "Dot product requires vectors to be of the same size."); + + return Builder.CreateIntrinsic( 
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot, + ArrayRef{Op0, Op1}, nullptr, "dx.dot"); + } break; + } + return nullptr; +} + Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E) { llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent; diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 92ce0edeaf9e9..b2800f699ff4b 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache { llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E); + llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx, const CallExpr *E); llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E); diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index f87ac97799796..08e5d981a4a4c 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -179,6 +179,104 @@ double3 cos(double3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos) double4 cos(double4); +//===----------------------------------------------------------------------===// +// dot product builtins +//===----------------------------------------------------------------------===// + +/// \fn K dot(T X, T Y) +/// \brief Return the dot product (a scalar value) of \a X and \a Y. +/// \param X The X input value. +/// \param Y The Y input value. 
+ +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half, half); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half2, half2); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half3, half3); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half4, half4); + +#ifdef __HLSL_ENABLE_16_BIT +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t, int16_t); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t2, int16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t3, int16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t4, int16_t4); + +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t, uint16_t); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t2, uint16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t3, uint16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t4, uint16_t4); +#endif + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float, float); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float2, float2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float3, float3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float4, float4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +double dot(double, double); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int, int); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int2, int2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int3, int3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int4, int4); + 
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint, uint); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint2, uint2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint3, uint3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint4, uint4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t, int64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t2, int64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t3, int64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t4, int64_t4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t, uint64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t2, uint64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t3, uint64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t4, uint64_t4); + //===----------------------------------------------------------------------===// // floor builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 7fa295ebd9404..984088e345c80 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, // not a valid type, emit an error message and return true. Otherwise return // false. 
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc, - QualType Ty) { - if (!Ty->getAs() && !ConstantMatrixType::isValidElementType(Ty)) { + QualType ArgTy, int ArgIndex) { + if (!ArgTy->getAs() && + !ConstantMatrixType::isValidElementType(ArgTy)) { return S.Diag(Loc, diag::err_builtin_invalid_arg_type) - << 1 << /* vector, integer or float ty*/ 0 << Ty; + << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy; } return false; @@ -2961,6 +2962,9 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID, } } + if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) + return ExprError(); + // Since the target specific builtins for each arch overlap, only check those // of the arch we are compiling for. if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) { @@ -5161,6 +5165,70 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) { return false; } +// Helper function for CheckHLSLBuiltinFunctionCall +bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { + assert(TheCall->getNumArgs() > 1); + ExprResult A = TheCall->getArg(0); + ExprResult B = TheCall->getArg(1); + QualType ArgTyA = A.get()->getType(); + QualType ArgTyB = B.get()->getType(); + auto *VecTyA = ArgTyA->getAs(); + auto *VecTyB = ArgTyB->getAs(); + SourceLocation BuiltinLoc = TheCall->getBeginLoc(); + if (VecTyA == nullptr && VecTyB == nullptr) + return false; + + if (VecTyA && VecTyB) { + bool retValue = false; + if (VecTyA->getElementType() != VecTyB->getElementType()) { + // Note: type promotion is intended to be handled via the intrinsics + // and not the builtin itself. + S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector) + << TheCall->getDirectCallee() + << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc()); + retValue = true; + } + if (VecTyA->getNumElements() != VecTyB->getNumElements()) { + // if we get here a HLSLVectorTruncation is needed. 
+ S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector) + << TheCall->getDirectCallee() + << SourceRange(TheCall->getArg(0)->getBeginLoc(), + TheCall->getArg(1)->getEndLoc()); + retValue = true; + } + + if (retValue) + TheCall->setType(VecTyA->getElementType()); + + return retValue; + } + + // Note: if we get here one of the args is a scalar which + // requires a VectorSplat on Arg0 or Arg1 + S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector) + << TheCall->getDirectCallee() + << SourceRange(TheCall->getArg(0)->getBeginLoc(), + TheCall->getArg(1)->getEndLoc()); + return true; +} + +// Note: returning true in this case results in CheckBuiltinFunctionCall +// returning an ExprError +bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { + switch (BuiltinID) { + case Builtin::BI__builtin_hlsl_dot: { + if (checkArgCount(*this, TheCall, 2)) + return true; + if (CheckVectorElementCallArgs(this, TheCall)) + return true; + if (SemaBuiltinVectorToScalarMath(TheCall)) + return true; + break; + } + } + return false; +} + bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { // position of memory order and scope arguments in the builtin @@ -19594,7 +19662,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) { TheCall->setArg(0, A.get()); QualType TyA = A.get()->getType(); - if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA)) + if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1)) return true; TheCall->setType(TyA); @@ -19602,6 +19670,27 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) { } bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { + QualType Res; + if (SemaBuiltinVectorMath(TheCall, Res)) + return true; + TheCall->setType(Res); + return false; +} + +bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) { + QualType Res; + if (SemaBuiltinVectorMath(TheCall, Res)) + return true; + + if (auto *VecTy0 = 
Res->getAs()) + TheCall->setType(VecTy0->getElementType()); + else + TheCall->setType(Res); + + return false; +} + +bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) { if (checkArgCount(*this, TheCall, 2)) return true; @@ -19609,8 +19698,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { ExprResult B = TheCall->getArg(1); // Do standard promotions between the two arguments, returning their common // type. - QualType Res = - UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison); + Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison); if (A.isInvalid() || B.isInvalid()) return true; @@ -19622,12 +19710,11 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { diag::err_typecheck_call_different_arg_types) << TyA << TyB; - if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA)) + if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1)) return true; TheCall->setArg(0, A.get()); TheCall->setArg(1, B.get()); - TheCall->setType(Res); return false; } diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl new file mode 100644 index 0000000000000..9881dabc3a110 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl @@ -0,0 +1,30 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +// CHECK-LABEL: builtin_bool_to_float_type_promotion +// CHECK: %conv1 = uitofp i1 %tobool to double +// CHECK: %dx.dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: ret float %conv2 +float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion +// CHECK: %conv = uitofp i1 %tobool to double +// CHECK: %conv1 = fpext float %1 to double +// CHECK: %dx.dot 
= fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: ret float %conv2 +float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK-LABEL: builtin_dot_int_to_float_promotion +// CHECK: %conv = fpext float %0 to double +// CHECK: %conv1 = sitofp i32 %1 to double +// CHECK: dx.dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: ret float %conv2 +float builtin_dot_int_to_float_promotion ( float p0, int p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl new file mode 100644 index 0000000000000..b2c1bae31d13b --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -0,0 +1,265 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ +// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ +// RUN: --check-prefixes=CHECK,NATIVE_HALF +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF + +#ifdef __HLSL_ENABLE_16_BIT +// NATIVE_HALF: %dx.dot = mul i16 %0, %1 +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short ( int16_t p0, int16_t p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short2 ( int16_t2 p0, int16_t2 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short3 ( int16_t3 p0, int16_t3 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short4 
( int16_t4 p0, int16_t4 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = mul i16 %0, %1 +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort ( uint16_t p0, uint16_t p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort2 ( uint16_t2 p0, uint16_t2 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) { + return dot ( p0, p1 ); +} +#endif + +// CHECK: %dx.dot = mul i32 %0, %1 +// CHECK: ret i32 %dx.dot +int test_dot_int ( int p0, int p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: ret i32 %dx.dot +int test_dot_int2 ( int2 p0, int2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: ret i32 %dx.dot +int test_dot_int3 ( int3 p0, int3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: ret i32 %dx.dot +int test_dot_int4 ( int4 p0, int4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = mul i32 %0, %1 +// CHECK: ret i32 %dx.dot +uint test_dot_uint ( uint p0, uint p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: ret i32 %dx.dot +uint test_dot_uint2 ( uint2 p0, uint2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: ret i32 %dx.dot +uint test_dot_uint3 ( uint3 p0, uint3 p1 ) { + return dot ( p0, p1 ); 
+} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: ret i32 %dx.dot +uint test_dot_uint4 ( uint4 p0, uint4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = mul i64 %0, %1 +// CHECK: ret i64 %dx.dot +int64_t test_dot_long ( int64_t p0, int64_t p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: ret i64 %dx.dot +int64_t test_dot_long2 ( int64_t2 p0, int64_t2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: ret i64 %dx.dot +int64_t test_dot_long3 ( int64_t3 p0, int64_t3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: ret i64 %dx.dot +int64_t test_dot_long4 ( int64_t4 p0, int64_t4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = mul i64 %0, %1 +// CHECK: ret i64 %dx.dot +uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: ret i64 %dx.dot +uint64_t test_dot_ulong2 ( uint64_t2 p0, uint64_t2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: ret i64 %dx.dot +uint64_t test_dot_ulong3 ( uint64_t3 p0, uint64_t3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: ret i64 %dx.dot +uint64_t test_dot_ulong4 ( uint64_t4 p0, uint64_t4 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = fmul half %0, %1 +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = fmul float %0, %1 +// NO_HALF: ret float %dx.dot +half test_dot_half ( half p0, half p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1) +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = call 
float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// NO_HALF: ret float %dx.dot +half test_dot_half2 ( half2 p0, half2 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1) +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// NO_HALF: ret float %dx.dot +half test_dot_half3 ( half3 p0, half3 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1) +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1) +// NO_HALF: ret float %dx.dot +half test_dot_half4 ( half4 p0, half4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = fmul float %0, %1 +// CHECK: ret float %dx.dot +float test_dot_float ( float p0, float p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float2 ( float2 p0, float2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float3 ( float3 p0, float3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float4 ( float4 p0, float4 p1) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float2_splat ( float p0, float2 p1 ) { + return dot( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float3_splat ( float p0, float3 p1 ) { + return dot( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 
x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float4_splat ( float p0, float4 p1 ) { + return dot( p0, p1 ); +} + +// CHECK: %conv = sitofp i32 %1 to float +// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0 +// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat) +// CHECK: ret float %dx.dot +float test_builtin_dot_float2_int_splat ( float2 p0, int p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %conv = sitofp i32 %1 to float +// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0 +// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat) +// CHECK: ret float %dx.dot +float test_builtin_dot_float3_int_splat ( float3 p0, int p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = fmul double %0, %1 +// CHECK: ret double %dx.dot +double test_dot_double ( double p0, double p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %conv = zext i1 %tobool to i32 +// CHECK: %dx.dot = mul i32 %conv, %1 +// CHECK: ret i32 %dx.dot +int test_dot_bool_scalar_arg0_type_promotion ( bool p0, int p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %conv = zext i1 %tobool to i32 +// CHECK: %dx.dot = mul i32 %0, %conv +// CHECK: ret i32 %dx.dot +int test_dot_bool_scalar_arg1_type_promotion ( int p0, bool p1 ) { + return dot ( p0, p1 ); +} diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl new file mode 100644 index 0000000000000..54d093aa7ce3a --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl @@ -0,0 +1,109 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes 
-verify -verify-ignore-unexpected + +float test_no_second_arg ( float2 p0) { + return __builtin_hlsl_dot ( p0 ); + // expected-error@-1 {{too few arguments to function call, expected 2, have 1}} +} + +float test_too_many_arg ( float2 p0) { + return __builtin_hlsl_dot ( p0, p0, p0 ); + // expected-error@-1 {{too many arguments to function call, expected 2, have 3}} +} + +float test_dot_no_second_arg ( float2 p0) { + return dot ( p0 ); + // expected-error@-1 {{no matching function for call to 'dot'}} +} + +float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) { + return dot ( p0, p1 ); + // expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}} +} + +float test_dot_builtin_vector_size_mismatch ( float3 p0, float2 p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}} +} + +float test_dot_scalar_mismatch ( float p0, int p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{call to 'dot' is ambiguous}} +} + +float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{call to 'dot' is ambiguous}} +} + +//NOTE: for all the *_promotion we are intentionally not handling type promotion in builtins +float test_builtin_dot_vec_int_to_float_promotion ( int2 p0, float2 p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}} +} + +int64_t test_builtin_dot_vec_int_to_int64_promotion( int64_t2 p0, int2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}} +} + +float test_builtin_dot_vec_half_to_float_promotion( float2 p0, half2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same 
type}} +} + +#ifdef __HLSL_ENABLE_16_BIT +float test_builtin_dot_vec_int16_to_float_promotion( float2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}} +} + +half test_builtin_dot_vec_int16_to_half_promotion( half2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}} +} + +int test_builtin_dot_vec_int16_to_int_promotion( int2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}} +} + +int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}} +} +#endif + +float test_builtin_dot_float2_splat ( float p0, float2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}} +} + +float test_builtin_dot_float3_splat ( float p0, float3 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}} +} + +float test_builtin_dot_float4_splat ( float p0, float4 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}} +} + +float test_dot_float2_int_splat ( float2 p0, int p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}} +} + +float test_dot_float3_int_splat ( float3 p0, int p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}} +} + +float test_builtin_dot_int_vect_to_float_vec_promotion ( int2 p0, float p1 ) { + return __builtin_hlsl_dot ( p0, 
p1 ); + // expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}} +} + +int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); + // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}} +} diff --git a/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl b/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl index 135d6cf335c13..8464f1c1a7c2c 100644 --- a/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl +++ b/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl @@ -24,6 +24,46 @@ void Call4(int16_t H) { Fn4(H); } +int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) { + return dot ( p0, p1 ); +} + +float test_dot_scalar_mismatch ( float p0, int p1 ) { + return dot ( p0, p1 ); +} + +float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) { + return dot ( p0, p1 ); +} + +float test_builtin_dot_vec_int_to_float_promotion ( int2 p0, float2 p1 ) { + return dot ( p0, p1 ); +} + +int64_t test_builtin_dot_vec_int_to_int64_promotion( int64_t2 p0, int2 p1 ) { + return dot ( p0, p1 ); +} + +float test_builtin_dot_vec_half_to_float_promotion( float2 p0, half2 p1 ) { + return dot( p0, p1 ); +} + +float test_builtin_dot_vec_int16_to_float_promotion( float2 p0, int16_t2 p1 ) { + return dot( p0, p1 ); +} + +half test_builtin_dot_vec_int16_to_half_promotion( half2 p0, int16_t2 p1 ) { + return dot( p0, p1 ); +} + +int test_builtin_dot_vec_int16_to_int_promotion( int2 p0, int16_t2 p1 ) { + return dot( p0, p1 ); +} + +int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1 ) { + return dot( p0, p1 ); +} + // https://github.com/llvm/llvm-project/issues/81049 // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 2fe4fdfd5953b..c192d4b84417c 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ 
b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -19,4 +19,9 @@ def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMe def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">, Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>; + +def int_dx_dot : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; }