Skip to content

Commit

Permalink
[HLSL] Implementation of dot intrinsic (#81190)
Browse files Browse the repository at this point in the history
This change implements #70073

HLSL has a dot intrinsic defined here:

https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as a HLSL_LANG LangBuiltin in
Builtins.td.
This is used to associate all the dot product typdef defined
hlsl_intrinsics.h
with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp.

In IntrinsicsDirectX.td we define the llvmIR for the dot product.
A few goals were in mind for this IR. First it should operate on only
vectors. Second the return type should be the vector element type. Third
the second parameter vector should be of the same size as the first
parameter. Finally `a dot b` should be the same as `b dot a`.

In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via
EmitBuiltinExpr. Dot
product though is language specific intrinsic and so is guarded behind
getLangOpts().HLSL.
The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExp

EmitHLSLBuiltinExp dot product intrinsics makes a destinction
between vectors and scalars. This is because HLSL supports dot product
on scalars which simplifies down to multiply.

Sema.h & SemaChecking.cpp saw the addition of
CheckHLSLBuiltinFunctionCall, a language specific semantic validation
that can be expanded for other hlsl specific intrinsics.

Fixes #70073
  • Loading branch information
farzonl committed Feb 26, 2024
1 parent b2ebd8b commit 82acec1
Show file tree
Hide file tree
Showing 11 changed files with 703 additions and 8 deletions.
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void*(unsigned char)";
}

def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14063,6 +14063,7 @@ class Sema final {
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
Expand Down Expand Up @@ -14128,6 +14129,8 @@ class Sema final {

bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);

bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
Expand Down
51 changes: 51 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IntrinsicsBPF.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsHexagon.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsPowerPC.h"
Expand Down Expand Up @@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
}

// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
return RValue::get(V);

if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
return EmitHipStdParUnsupportedBuiltin(this, FD);

Expand Down Expand Up @@ -17959,6 +17964,52 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
if (!getLangOpts().HLSL)
return nullptr;

switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Value *Op1 = EmitScalarExpr(E->getArg(1));
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy())
return Builder.CreateFMul(Op0, Op1, "dx.dot");

if (T0->isIntegerTy())
return Builder.CreateMul(Op0, Op1, "dx.dot");

// Bools should have been promoted
llvm_unreachable(
"Scalar dot product is only supported on ints and floats.");
}
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");

// A vector sext or sitofp should have happened
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");

[[maybe_unused]] auto *VecTy0 =
E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");

return Builder.CreateIntrinsic(
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
}
return nullptr;
}

Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
const CallExpr *E);
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
Expand Down
98 changes: 98 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,104 @@ double3 cos(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
double4 cos(double4);

//===----------------------------------------------------------------------===//
// dot product builtins
//===----------------------------------------------------------------------===//

/// \fn K dot(T X, T Y)
/// \brief Return the dot product (a scalar value) of \a X and \a Y.
/// \param X The X input value.
/// \param Y The Y input value.

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half, half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half2, half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half4, half4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t, int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t2, int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t3, int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t4, int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t, uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t2, uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t3, uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t4, uint16_t4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float, float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float2, float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float4, float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
double dot(double, double);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int, int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int2, int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int3, int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int4, int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint, uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint2, uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint3, uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint4, uint4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t, int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t2, int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t3, int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t4, int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t, uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t2, uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t4, uint64_t4);

//===----------------------------------------------------------------------===//
// floor builtins
//===----------------------------------------------------------------------===//
Expand Down
103 changes: 95 additions & 8 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
// not a valid type, emit an error message and return true. Otherwise return
// false.
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
QualType Ty) {
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
QualType ArgTy, int ArgIndex) {
if (!ArgTy->getAs<VectorType>() &&
!ConstantMatrixType::isValidElementType(ArgTy)) {
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
}

return false;
Expand Down Expand Up @@ -2961,6 +2962,9 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
}
}

if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall))
return ExprError();

// Since the target specific builtins for each arch overlap, only check those
// of the arch we are compiling for.
if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
Expand Down Expand Up @@ -5161,6 +5165,70 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}

// Helper function for CheckHLSLBuiltinFunctionCall
bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
QualType ArgTyA = A.get()->getType();
QualType ArgTyB = B.get()->getType();
auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
if (VecTyA == nullptr && VecTyB == nullptr)
return false;

if (VecTyA && VecTyB) {
bool retValue = false;
if (VecTyA->getElementType() != VecTyB->getElementType()) {
// Note: type promotion is intended to be handeled via the intrinsics
// and not the builtin itself.
S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
retValue = true;
}
if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
// if we get here a HLSLVectorTruncation is needed.
S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
retValue = true;
}

if (retValue)
TheCall->setType(VecTyA->getElementType());

return retValue;
}

// Note: if we get here one of the args is a scalar which
// requires a VectorSplat on Arg0 or Arg1
S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
return true;
}

// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
if (checkArgCount(*this, TheCall, 2))
return true;
if (CheckVectorElementCallArgs(this, TheCall))
return true;
if (SemaBuiltinVectorToScalarMath(TheCall))
return true;
break;
}
}
return false;
}

bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
// position of memory order and scope arguments in the builtin
Expand Down Expand Up @@ -19594,23 +19662,43 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
TheCall->setArg(0, A.get());
QualType TyA = A.get()->getType();

if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;

TheCall->setType(TyA);
return false;
}

bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
QualType Res;
if (SemaBuiltinVectorMath(TheCall, Res))
return true;
TheCall->setType(Res);
return false;
}

bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
QualType Res;
if (SemaBuiltinVectorMath(TheCall, Res))
return true;

if (auto *VecTy0 = Res->getAs<VectorType>())
TheCall->setType(VecTy0->getElementType());
else
TheCall->setType(Res);

return false;
}

bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
if (checkArgCount(*this, TheCall, 2))
return true;

ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
// Do standard promotions between the two arguments, returning their common
// type.
QualType Res =
UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
if (A.isInvalid() || B.isInvalid())
return true;

Expand All @@ -19622,12 +19710,11 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
diag::err_typecheck_call_different_arg_types)
<< TyA << TyB;

if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;

TheCall->setArg(0, A.get());
TheCall->setArg(1, B.get());
TheCall->setType(Res);
return false;
}

Expand Down

0 comments on commit 82acec1

Please sign in to comment.