Skip to content

Commit

Permalink
[HLSL] implement mad intrinsic (#83826)
Browse files Browse the repository at this point in the history
This change implements #83736
The dot product lowering needs a ternary multiply-add operation. DXIL
has three mad opcodes: `fmad`(46), `imad`(48), and `umad`(49). Dot
product in DXIL only uses `imad`/`umad`, but for completeness, and
because the HLSL `mad` intrinsic requires it, `fmad` was also included.
Two new intrinsics needed to be created to complete this change;
the `fmad` case is already supported by LLVM via the `fmuladd` intrinsic.

- `hlsl_intrinsics.h` - exposed the `mad` API call.
- `Builtins.td` - exposed a `mad` builtin.
- `Sema.h` - made the float-type check for ternary calls optional.
- `CGBuiltin.cpp` - pick the intrinsic for signed/unsigned and float; also
reuse `int_fmuladd`.
- `SemaChecking.cpp` - type checks for `__builtin_hlsl_mad`.
- `IntrinsicsDirectX.td` - create the two new intrinsics for
`imad`/`umad`.
- `DXIL.td` - create the LLVM intrinsic to `DXIL` opcode mapping.

---------

Co-authored-by: Farzon Lotfi <farzon@farzon.com>
  • Loading branch information
farzonl and Farzon Lotfi committed Mar 5, 2024
1 parent a730ed7 commit 643b31d
Show file tree
Hide file tree
Showing 12 changed files with 638 additions and 6 deletions.
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4572,6 +4572,12 @@ def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

// Builtin record spelled `__builtin_hlsl_mad`, declared through
// LangBuiltin<"HLSL_LANG">. The prototype is the untyped variadic
// `void(...)`, so operand and result types are not described by this record
// (presumably they are validated in Sema -- confirm against SemaChecking.cpp).
def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_mad"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14146,7 +14146,8 @@ class Sema final {
bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
// Semantic check for a three-argument elementwise math builtin call.
// CheckForFloatArgs defaults to true, so callers that pass only TheCall keep
// the floating-point argument check; passing false selects the non-FP-only
// element type check instead.
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
bool CheckForFloatArgs = true);
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);

Expand Down
19 changes: 19 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18044,6 +18044,25 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType*/ Op0->getType(), Intrinsic::dx_frac,
ArrayRef<Value *>{Op0}, nullptr, "dx.frac");
}
case Builtin::BI__builtin_hlsl_mad: {
  // mad(M, A, B): emit the three operands in argument order, then choose the
  // intrinsic from the representation of the first argument's type.
  Value *Mul = EmitScalarExpr(E->getArg(0));
  Value *Factor = EmitScalarExpr(E->getArg(1));
  Value *Addend = EmitScalarExpr(E->getArg(2));
  const QualType FirstArgTy = E->getArg(0)->getType();

  Intrinsic::ID MadID;
  const char *ResultName;
  if (FirstArgTy->hasFloatingRepresentation()) {
    // Floating-point operands reuse the generic fmuladd intrinsic.
    MadID = Intrinsic::fmuladd;
    ResultName = "dx.fmad";
  } else if (FirstArgTy->hasSignedIntegerRepresentation()) {
    MadID = Intrinsic::dx_imad;
    ResultName = "dx.imad";
  } else {
    // Anything that is neither floating point nor signed must be unsigned.
    assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
    MadID = Intrinsic::dx_umad;
    ResultName = "dx.umad";
  }

  // The result has the same type as the first operand.
  return Builder.CreateIntrinsic(
      /*ReturnType*/ Mul->getType(), MadID,
      ArrayRef<Value *>{Mul, Factor, Addend}, nullptr, ResultName);
}
}
return nullptr;
}
Expand Down
105 changes: 105 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,111 @@ double3 log2(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_log2)
double4 log2(double4);

//===----------------------------------------------------------------------===//
// mad builtins
//===----------------------------------------------------------------------===//

/// \fn T mad(T M, T A, T B)
/// \brief The result of \a M * \a A + \a B.
/// \param M The first multiplication operand.
/// \param A The second multiplication operand.
/// \param B The value added to the product of \a M and \a A.

// half overloads: scalar plus 2-, 3- and 4-element vectors. Each is an alias
// of __builtin_hlsl_mad and carries the 16-bit availability annotation.
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half mad(half, half, half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half2 mad(half2, half2, half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half3 mad(half3, half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
half4 mad(half4, half4, half4);

// int16_t / uint16_t overloads (scalar and 2-, 3-, 4-element vectors).
// These declarations exist only when __HLSL_ENABLE_16_BIT is defined.
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t mad(int16_t, int16_t, int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t2 mad(int16_t2, int16_t2, int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t3 mad(int16_t3, int16_t3, int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int16_t4 mad(int16_t4, int16_t4, int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t mad(uint16_t, uint16_t, uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t2 mad(uint16_t2, uint16_t2, uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t3 mad(uint16_t3, uint16_t3, uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint16_t4 mad(uint16_t4, uint16_t4, uint16_t4);
#endif

// The remaining overloads carry no availability annotation; each one aliases
// __builtin_hlsl_mad and takes three operands of the same type as its result.

// int overloads.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int mad(int, int, int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int2 mad(int2, int2, int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int3 mad(int3, int3, int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int4 mad(int4, int4, int4);

// uint overloads.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint mad(uint, uint, uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint2 mad(uint2, uint2, uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint3 mad(uint3, uint3, uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint4 mad(uint4, uint4, uint4);

// int64_t overloads.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t mad(int64_t, int64_t, int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t2 mad(int64_t2, int64_t2, int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t3 mad(int64_t3, int64_t3, int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
int64_t4 mad(int64_t4, int64_t4, int64_t4);

// uint64_t overloads.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t mad(uint64_t, uint64_t, uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t2 mad(uint64_t2, uint64_t2, uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t3 mad(uint64_t3, uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
uint64_t4 mad(uint64_t4, uint64_t4, uint64_t4);

// float overloads.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float mad(float, float, float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float2 mad(float2, float2, float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float3 mad(float3, float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
float4 mad(float4, float4, float4);

// double overloads.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double mad(double, double, double);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double2 mad(double2, double2, double2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double3 mad(double3, double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
double4 mad(double4, double4, double4);

//===----------------------------------------------------------------------===//
// max builtins
//===----------------------------------------------------------------------===//
Expand Down
28 changes: 23 additions & 5 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5298,6 +5298,14 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_mad: {
  // mad takes exactly three operands; reject any other count before the
  // per-argument checks below look at the arguments.
  if (checkArgCount(*this, TheCall, 3))
    return true;
  if (CheckVectorElementCallArgs(this, TheCall))
    return true;
  // mad is also declared for integer operands, so the floating-point-only
  // argument check is turned off here.
  if (SemaBuiltinElementwiseTernaryMath(TheCall, /*CheckForFloatArgs*/ false))
    return true;
  // Terminate the case explicitly, like the preceding case does, so that a
  // case appended after this one cannot be reached by fallthrough.
  break;
}
}
return false;
}
Expand Down Expand Up @@ -19798,7 +19806,8 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
return false;
}

bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
bool CheckForFloatArgs) {
if (checkArgCount(*this, TheCall, 3))
return true;

Expand All @@ -19810,11 +19819,20 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
Args[I] = Converted.get();
}

int ArgOrdinal = 1;
for (Expr *Arg : Args) {
if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
if (CheckForFloatArgs) {
int ArgOrdinal = 1;
for (Expr *Arg : Args) {
if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
Arg->getType(), ArgOrdinal++))
return true;
}
} else {
int ArgOrdinal = 1;
for (Expr *Arg : Args) {
if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
ArgOrdinal++))
return true;
return true;
}
}

for (int I = 1; I < 3; ++I) {
Expand Down

0 comments on commit 643b31d

Please sign in to comment.