Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -13264,6 +13264,9 @@ def err_builtin_invalid_arg_type: Error<
"%plural{0:|: }3"
"%plural{[0,3]:type|:types}1 (was %4)">;

def err_builtin_requires_double_type: Error<
"%ordinal0 argument must be a scalar, vector, or matrix of double type (was %1)">;

def err_bswapg_invalid_bit_width : Error<
"_BitInt type %0 (%1 bits) must be a multiple of 16 bits for byte swapping">;

Expand Down
54 changes: 54 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,60 @@ float3 floor(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_floor)
float4 floor(float4);

//===----------------------------------------------------------------------===//
// fused multiply-add builtins
//===----------------------------------------------------------------------===//

/// \fn double fma(double a, double b, double c)
/// \brief Returns the double-precision fused multiply-addition of a * b + c.
/// \param a The first value in the fused multiply-addition.
/// \param b The second value in the fused multiply-addition.
/// \param c The third value in the fused multiply-addition.

// double scalars and vectors
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double fma(double, double, double);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double2 fma(double2, double2, double2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double3 fma(double3, double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double4 fma(double4, double4, double4);

// double matrices
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double1x1 fma(double1x1, double1x1, double1x1);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double1x2 fma(double1x2, double1x2, double1x2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double1x3 fma(double1x3, double1x3, double1x3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double1x4 fma(double1x4, double1x4, double1x4);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double2x1 fma(double2x1, double2x1, double2x1);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double2x2 fma(double2x2, double2x2, double2x2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double2x3 fma(double2x3, double2x3, double2x3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double2x4 fma(double2x4, double2x4, double2x4);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double3x1 fma(double3x1, double3x1, double3x1);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double3x2 fma(double3x2, double3x2, double3x2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double3x3 fma(double3x3, double3x3, double3x3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double3x4 fma(double3x4, double3x4, double3x4);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double4x1 fma(double4x1, double4x1, double4x1);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double4x2 fma(double4x2, double4x2, double4x2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double4x3 fma(double4x3, double4x3, double4x3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
double4x4 fma(double4x4, double4x4, double4x4);

//===----------------------------------------------------------------------===//
// frac builtins
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 5 additions & 3 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2178,9 +2178,10 @@ static bool
checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr,
int ArgOrdinal) {
QualType EltTy = ArgTy;
if (auto *VecTy = EltTy->getAs<VectorType>())
EltTy = VecTy->getElementType();
clang::QualType EltTy =
ArgTy->isVectorType() ? ArgTy->getAs<VectorType>()->getElementType()
: ArgTy->isMatrixType() ? ArgTy->getAs<MatrixType>()->getElementType()
: ArgTy;

switch (ArgTyRestr) {
case Sema::EltwiseBuiltinArgTyRestriction::None:
Expand All @@ -2192,6 +2193,7 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
break;
case Sema::EltwiseBuiltinArgTyRestriction::FloatTy:
if (!EltTy->isRealFloatingType()) {
// FIXME: make diagnostic's wording correct for matrices
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar or vector */ 5 << /* no int */ 0
<< /* floating-point */ 1 << ArgTy;
Expand Down
35 changes: 35 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3106,6 +3106,25 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
return false;
}

static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->castAs<clang::VectorType>()->getElementType()
: PassedType->isMatrixType()
? PassedType->castAs<clang::MatrixType>()->getElementType()
: PassedType;
if (!BaseType->isDoubleType()) {
// FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
// this custom error.
return S->Diag(Loc, diag::err_builtin_requires_double_type)
<< ArgOrdinal << PassedType;
}

return false;
}

static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
unsigned ArgIndex) {
auto *Arg = TheCall->getArg(ArgIndex);
Expand Down Expand Up @@ -4042,6 +4061,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyA);
break;
}
case Builtin::BI__builtin_elementwise_fma: {
if (SemaRef.checkArgCount(TheCall, 3) ||
CheckAllArgsHaveSameType(&SemaRef, TheCall)) {
return true;
}

if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
CheckAnyDoubleRepresentation))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
// return type is the same as input type
TheCall->setType(ArgTyA);
break;
}
case Builtin::BI__builtin_hlsl_transpose: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
Expand Down
138 changes: 138 additions & 0 deletions clang/test/CodeGenHLSL/builtins/fma.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm \
// RUN: -disable-llvm-passes -o - | FileCheck %s
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm \
// RUN: -disable-llvm-passes -o - | FileCheck %s

// CHECK-LABEL: define {{.*}} double @{{.*}}fma_double{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.fma.f64(double
// CHECK: ret double
double fma_double(double a, double b, double c) { return fma(a, b, c); }

// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
// CHECK: ret <2 x double>
double2 fma_double2(double2 a, double2 b, double2 c) { return fma(a, b, c); }

// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
// CHECK: ret <3 x double>
double3 fma_double3(double3 a, double3 b, double3 c) { return fma(a, b, c); }

// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
// CHECK: ret <4 x double>
double4 fma_double4(double4 a, double4 b, double4 c) { return fma(a, b, c); }

// CHECK-LABEL: define {{.*}} <1 x double> @{{.*}}fma_double1x1{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <1 x double> @llvm.fma.v1f64(<1 x double>
// CHECK: ret <1 x double>
double1x1 fma_double1x1(double1x1 a, double1x1 b, double1x1 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double1x2{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
// CHECK: ret <2 x double>
double1x2 fma_double1x2(double1x2 a, double1x2 b, double1x2 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double1x3{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
// CHECK: ret <3 x double>
double1x3 fma_double1x3(double1x3 a, double1x3 b, double1x3 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double1x4{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
// CHECK: ret <4 x double>
double1x4 fma_double1x4(double1x4 a, double1x4 b, double1x4 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2x1{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
// CHECK: ret <2 x double>
double2x1 fma_double2x1(double2x1 a, double2x1 b, double2x1 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double2x2{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
// CHECK: ret <4 x double>
double2x2 fma_double2x2(double2x2 a, double2x2 b, double2x2 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double2x3{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double>
// CHECK: ret <6 x double>
double2x3 fma_double2x3(double2x3 a, double2x3 b, double2x3 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double2x4{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double>
// CHECK: ret <8 x double>
double2x4 fma_double2x4(double2x4 a, double2x4 b, double2x4 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3x1{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
// CHECK: ret <3 x double>
double3x1 fma_double3x1(double3x1 a, double3x1 b, double3x1 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double3x2{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double>
// CHECK: ret <6 x double>
double3x2 fma_double3x2(double3x2 a, double3x2 b, double3x2 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <9 x double> @{{.*}}fma_double3x3{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <9 x double> @llvm.fma.v9f64(<9 x double>
// CHECK: ret <9 x double>
double3x3 fma_double3x3(double3x3 a, double3x3 b, double3x3 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double3x4{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double>
// CHECK: ret <12 x double>
double3x4 fma_double3x4(double3x4 a, double3x4 b, double3x4 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4x1{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
// CHECK: ret <4 x double>
double4x1 fma_double4x1(double4x1 a, double4x1 b, double4x1 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double4x2{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double>
// CHECK: ret <8 x double>
double4x2 fma_double4x2(double4x2 a, double4x2 b, double4x2 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double4x3{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double>
// CHECK: ret <12 x double>
double4x3 fma_double4x3(double4x3 a, double4x3 b, double4x3 c) {
return fma(a, b, c);
}

// CHECK-LABEL: define {{.*}} <16 x double> @{{.*}}fma_double4x4{{.*}}(
// CHECK: call reassoc nnan ninf nsz arcp afn <16 x double> @llvm.fma.v16f64(<16 x double>
// CHECK: ret <16 x double>
double4x4 fma_double4x4(double4x4 a, double4x4 b, double4x4 c) {
return fma(a, b, c);
}
113 changes: 113 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \
// RUN: -triple dxil-pc-shadermodel6.6-library %s \
// RUN: -emit-llvm-only -disable-llvm-passes -verify \
// RUN: -verify-ignore-unexpected=note
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \
// RUN: -triple spirv-unknown-vulkan-compute %s \
// RUN: -emit-llvm-only -disable-llvm-passes -verify \
// RUN: -verify-ignore-unexpected=note

float bad_float(float a, float b, float c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float')}}
}

float2 bad_float2(float2 a, float2 b, float2 c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}}
}

float2x2 bad_float2x2(float2x2 a, float2x2 b, float2x2 c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}}
}

half bad_half(half a, half b, half c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half')}}
}

half2 bad_half2(half2 a, half2 b, half2 c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2' (aka 'vector<half, 2>'))}}
}

half2x2 bad_half2x2(half2x2 a, half2x2 b, half2x2 c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2x2' (aka 'matrix<half, 2, 2>'))}}
}

double mixed_bad_second(double a, float b, double c) {
return fma(a, b, c);
// expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
}

double mixed_bad_third(double a, double b, half c) {
return fma(a, b, c);
// expected-error@-1 {{arguments are of different types ('double' vs 'half')}}
}

double2 mixed_bad_second_vec(double2 a, float2 b, double2 c) {
return fma(a, b, c);
// expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}}
}

double2 mixed_bad_third_vec(double2 a, double2 b, float2 c) {
return fma(a, b, c);
// expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}}
}

double2x2 mixed_bad_second_mat(double2x2 a, float2x2 b, double2x2 c) {
return fma(a, b, c);
// expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<float, [2 * ...]>')}}
}

double2x2 mixed_bad_third_mat(double2x2 a, double2x2 b, half2x2 c) {
return fma(a, b, c);
// expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<half, [2 * ...]>')}}
}

double shape_mismatch_second(double a, double2 b, double c) {
return fma(a, b, c);
// expected-error@-1 {{call to 'fma' is ambiguous}}
}

double2 shape_mismatch_third(double2 a, double2 b, double c) {
return fma(a, b, c);
// expected-error@-1 {{call to 'fma' is ambiguous}}
}

double2x2 shape_mismatch_scalar_mat(double2x2 a, double b, double2x2 c) {
return fma(a, b, c);
// expected-error@-1 {{call to 'fma' is ambiguous}}
}

double2x2 shape_mismatch_vec_mat(double2x2 a, double2 b, double2x2 c) {
return fma(a, b, c);
// expected-error@-1 {{arguments are of different types ('double2x2' (aka 'matrix<double, 2, 2>') vs 'double2' (aka 'vector<double, 2>'))}}
}

int bad_int(int a, int b, int c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int')}}
}

int2 bad_int2(int2 a, int2 b, int2 c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int2' (aka 'vector<int, 2>'))}}
}

bool bad_bool(bool a, bool b, bool c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool')}}
}

bool2 bad_bool2(bool2 a, bool2 b, bool2 c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2' (aka 'vector<bool, 2>'))}}
}

bool2x2 bad_bool2x2(bool2x2 a, bool2x2 b, bool2x2 c) {
return fma(a, b, c);
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2x2' (aka 'matrix<bool, 2, 2>'))}}
}
Loading
Loading