-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[HLSL] Implement rsqrt
intrinsic
#84820
Conversation
@llvm/pr-subscribers-backend-directx @llvm/pr-subscribers-clang Author: Farzon Lotfi (farzonl) ChangesThis change implements #70074
Full diff: https://github.com/llvm/llvm-project/pull/84820.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 9c703377ca8d3e..de0cfb4e46b8bd 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4590,6 +4590,12 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..d2c83a5e405f42 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18077,6 +18077,14 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/Op0->getType(), Intrinsic::dx_rcp,
ArrayRef<Value *>{Op0}, nullptr, "dx.rcp");
}
+ case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {
+ Value *Op0 = EmitScalarExpr(E->getArg(0));
+ if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+ llvm_unreachable("rsqrt operand must have a float representation");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt,
+ ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
+ }
}
return nullptr;
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 45f8544392584e..f88aa4a5d5c644 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1153,6 +1153,38 @@ double3 rcp(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rcp)
double4 rcp(double4);
+//===----------------------------------------------------------------------===//
+// rsqrt builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T rsqrt(T x)
+/// \brief RReturns the reciprocal of the square root of the specified value \a x.
+/// \param x The specified input value.
+///
+/// This function uses the following formula: 1 / sqrt(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half rsqrt(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half2 rsqrt(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half3 rsqrt(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half4 rsqrt(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float rsqrt(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float2 rsqrt(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float3 rsqrt(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float4 rsqrt(float4);
+
//===----------------------------------------------------------------------===//
// round builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a5f42b630c3fa2..0dafff47ab4040 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5285,6 +5285,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
case Builtin::BI__builtin_hlsl_elementwise_rcp:
case Builtin::BI__builtin_hlsl_elementwise_frac: {
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
diff --git a/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
new file mode 100644
index 00000000000000..c87a8c404b08e1
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %dx.rsqrt = call half @llvm.dx.rsqrt.f16(
+// NATIVE_HALF: ret half %dx.rsqrt
+// NO_HALF: define noundef float @"?test_rsqrt_half@@YA$halff@$halff@@Z"(
+// NO_HALF: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
+// NO_HALF: ret float %dx.rsqrt
+half test_rsqrt_half(half p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <2 x half> @llvm.dx.rsqrt.v2f16
+// NATIVE_HALF: ret <2 x half> %dx.rsqrt
+// NO_HALF: define noundef <2 x float> @
+// NO_HALF: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32(
+// NO_HALF: ret <2 x float> %dx.rsqrt
+half2 test_rsqrt_half2(half2 p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <3 x half> @llvm.dx.rsqrt.v3f16
+// NATIVE_HALF: ret <3 x half> %dx.rsqrt
+// NO_HALF: define noundef <3 x float> @
+// NO_HALF: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32(
+// NO_HALF: ret <3 x float> %dx.rsqrt
+half3 test_rsqrt_half3(half3 p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <4 x half> @llvm.dx.rsqrt.v4f16
+// NATIVE_HALF: ret <4 x half> %dx.rsqrt
+// NO_HALF: define noundef <4 x float> @
+// NO_HALF: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32(
+// NO_HALF: ret <4 x float> %dx.rsqrt
+half4 test_rsqrt_half4(half4 p0) { return rsqrt(p0); }
+
+// CHECK: define noundef float @
+// CHECK: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
+// CHECK: ret float %dx.rsqrt
+float test_rsqrt_float(float p0) { return rsqrt(p0); }
+// CHECK: define noundef <2 x float> @
+// CHECK: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32
+// CHECK: ret <2 x float> %dx.rsqrt
+float2 test_rsqrt_float2(float2 p0) { return rsqrt(p0); }
+// CHECK: define noundef <3 x float> @
+// CHECK: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32
+// CHECK: ret <3 x float> %dx.rsqrt
+float3 test_rsqrt_float3(float3 p0) { return rsqrt(p0); }
+// CHECK: define noundef <4 x float> @
+// CHECK: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32
+// CHECK: ret <4 x float> %dx.rsqrt
+float4 test_rsqrt_float4(float4 p0) { return rsqrt(p0); }
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-warning.ll b/clang/test/SemaHLSL/BuiltIns/dot-warning.ll
new file mode 100644
index 00000000000000..5ecdde4d70e51f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot-warning.ll
@@ -0,0 +1,49 @@
+; ModuleID = 'D:\projects\llvm-project\clang\test\SemaHLSL\BuiltIns\dot-warning.hlsl'
+source_filename = "D:\\projects\\llvm-project\\clang\\test\\SemaHLSL\\BuiltIns\\dot-warning.hlsl"
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.3-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?test_dot_builtin_vector_elem_size_reduction@@YAMT?$__vector@J$01@__clang@@M@Z"(<2 x i64> noundef %p0, float noundef %p1) #0 {
+entry:
+ %p1.addr = alloca float, align 4
+ %p0.addr = alloca <2 x i64>, align 16
+ store float %p1, ptr %p1.addr, align 4
+ store <2 x i64> %p0, ptr %p0.addr, align 16
+ %0 = load <2 x i64>, ptr %p0.addr, align 16
+ %conv = sitofp <2 x i64> %0 to <2 x float>
+ %1 = load float, ptr %p1.addr, align 4
+ %splat.splatinsert = insertelement <2 x float> poison, float %1, i64 0
+ %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+ %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
+ ret float %dx.dot
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?test_dot_builtin_int_vector_elem_size_reduction@@YAMT?$__vector@H$01@__clang@@M@Z"(<2 x i32> noundef %p0, float noundef %p1) #0 {
+entry:
+ %p1.addr = alloca float, align 4
+ %p0.addr = alloca <2 x i32>, align 8
+ store float %p1, ptr %p1.addr, align 4
+ store <2 x i32> %p0, ptr %p0.addr, align 8
+ %0 = load <2 x i32>, ptr %p0.addr, align 8
+ %conv = sitofp <2 x i32> %0 to <2 x float>
+ %1 = load float, ptr %p1.addr, align 4
+ %splat.splatinsert = insertelement <2 x float> poison, float %1, i64 0
+ %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+ %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
+ ret float %dx.dot
+}
+
+attributes #0 = { noinline nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (https://github.com/farzonl/llvm-project.git f40562c7b4224e00da2ff2e13d175abfaac68532)"}
diff --git a/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
new file mode 100644
index 00000000000000..c74c502bd7a26f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
@@ -0,0 +1,27 @@
+
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+float test_too_few_arg() {
+ return __builtin_hlsl_elementwise_rsqrt();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_elementwise_rsqrt(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+float builtin_bool_to_float_type_promotion(bool p1) {
+ return __builtin_hlsl_elementwise_rsqrt(p1);
+ // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
+}
+
+float builtin_rsqrt_int_to_float_promotion(int p1) {
+ return __builtin_hlsl_elementwise_rsqrt(p1);
+ // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+float2 builtin_rsqrt_int2_to_float2_promotion(int2 p1) {
+ return __builtin_hlsl_elementwise_rsqrt(p1);
+ // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 7229292e377a83..366dedda2b3f73 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -37,4 +37,5 @@ def int_dx_lerp :
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9536a01e125bb3..942715f6ad80d3 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -224,6 +224,9 @@ def Exp2 : DXILOpMapping<21, unary, int_exp2,
def Frac : DXILOpMapping<22, unary, int_dx_frac,
"Returns a fraction from 0 to 1 that represents the "
"decimal part of the input.">;
+def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
+ "Returns the reciprocal of the square root of the specified value."
+ "rsqrt(x) = 1 / sqrt(x).">;
def Round : DXILOpMapping<26, unary, int_round,
"Returns the input rounded to the nearest integer"
"within a floating-point type.">;
diff --git a/llvm/test/CodeGen/DirectX/rsqrt.ll b/llvm/test/CodeGen/DirectX/rsqrt.ll
new file mode 100644
index 00000000000000..818b5985422173
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/rsqrt.ll
@@ -0,0 +1,31 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for rsqrt are generated for float and half.
+; CHECK:call float @dx.op.unary.f32(i32 25, float %{{.*}})
+; CHECK:call half @dx.op.unary.f16(i32 25, half %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @rsqrt_float(float noundef %a) #0 {
+entry:
+ %a.addr = alloca float, align 4
+ store float %a, ptr %a.addr, align 4
+ %0 = load float, ptr %a.addr, align 4
+ %dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %0)
+ ret float %dx.rsqrt
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare float @llvm.dx.rsqrt.f32(float) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @rsqrt_half(half noundef %a) #0 {
+entry:
+ %a.addr = alloca half, align 2
+ store half %a, ptr %a.addr, align 2
+ %0 = load half, ptr %a.addr, align 2
+ %dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %0)
+ ret half %dx.rsqrt
+}
|
@llvm/pr-subscribers-clang-codegen Author: Farzon Lotfi (farzonl) ChangesThis change implements #70074
Full diff: https://github.com/llvm/llvm-project/pull/84820.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 9c703377ca8d3e..de0cfb4e46b8bd 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4590,6 +4590,12 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..d2c83a5e405f42 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18077,6 +18077,14 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/Op0->getType(), Intrinsic::dx_rcp,
ArrayRef<Value *>{Op0}, nullptr, "dx.rcp");
}
+ case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {
+ Value *Op0 = EmitScalarExpr(E->getArg(0));
+ if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+ llvm_unreachable("rsqrt operand must have a float representation");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt,
+ ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
+ }
}
return nullptr;
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 45f8544392584e..f88aa4a5d5c644 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1153,6 +1153,38 @@ double3 rcp(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rcp)
double4 rcp(double4);
+//===----------------------------------------------------------------------===//
+// rsqrt builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T rsqrt(T x)
+/// \brief RReturns the reciprocal of the square root of the specified value \a x.
+/// \param x The specified input value.
+///
+/// This function uses the following formula: 1 / sqrt(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half rsqrt(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half2 rsqrt(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half3 rsqrt(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half4 rsqrt(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float rsqrt(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float2 rsqrt(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float3 rsqrt(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float4 rsqrt(float4);
+
//===----------------------------------------------------------------------===//
// round builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a5f42b630c3fa2..0dafff47ab4040 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5285,6 +5285,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
case Builtin::BI__builtin_hlsl_elementwise_rcp:
case Builtin::BI__builtin_hlsl_elementwise_frac: {
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
diff --git a/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
new file mode 100644
index 00000000000000..c87a8c404b08e1
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %dx.rsqrt = call half @llvm.dx.rsqrt.f16(
+// NATIVE_HALF: ret half %dx.rsqrt
+// NO_HALF: define noundef float @"?test_rsqrt_half@@YA$halff@$halff@@Z"(
+// NO_HALF: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
+// NO_HALF: ret float %dx.rsqrt
+half test_rsqrt_half(half p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <2 x half> @llvm.dx.rsqrt.v2f16
+// NATIVE_HALF: ret <2 x half> %dx.rsqrt
+// NO_HALF: define noundef <2 x float> @
+// NO_HALF: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32(
+// NO_HALF: ret <2 x float> %dx.rsqrt
+half2 test_rsqrt_half2(half2 p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <3 x half> @llvm.dx.rsqrt.v3f16
+// NATIVE_HALF: ret <3 x half> %dx.rsqrt
+// NO_HALF: define noundef <3 x float> @
+// NO_HALF: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32(
+// NO_HALF: ret <3 x float> %dx.rsqrt
+half3 test_rsqrt_half3(half3 p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <4 x half> @llvm.dx.rsqrt.v4f16
+// NATIVE_HALF: ret <4 x half> %dx.rsqrt
+// NO_HALF: define noundef <4 x float> @
+// NO_HALF: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32(
+// NO_HALF: ret <4 x float> %dx.rsqrt
+half4 test_rsqrt_half4(half4 p0) { return rsqrt(p0); }
+
+// CHECK: define noundef float @
+// CHECK: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
+// CHECK: ret float %dx.rsqrt
+float test_rsqrt_float(float p0) { return rsqrt(p0); }
+// CHECK: define noundef <2 x float> @
+// CHECK: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32
+// CHECK: ret <2 x float> %dx.rsqrt
+float2 test_rsqrt_float2(float2 p0) { return rsqrt(p0); }
+// CHECK: define noundef <3 x float> @
+// CHECK: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32
+// CHECK: ret <3 x float> %dx.rsqrt
+float3 test_rsqrt_float3(float3 p0) { return rsqrt(p0); }
+// CHECK: define noundef <4 x float> @
+// CHECK: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32
+// CHECK: ret <4 x float> %dx.rsqrt
+float4 test_rsqrt_float4(float4 p0) { return rsqrt(p0); }
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-warning.ll b/clang/test/SemaHLSL/BuiltIns/dot-warning.ll
new file mode 100644
index 00000000000000..5ecdde4d70e51f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot-warning.ll
@@ -0,0 +1,49 @@
+; ModuleID = 'D:\projects\llvm-project\clang\test\SemaHLSL\BuiltIns\dot-warning.hlsl'
+source_filename = "D:\\projects\\llvm-project\\clang\\test\\SemaHLSL\\BuiltIns\\dot-warning.hlsl"
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.3-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?test_dot_builtin_vector_elem_size_reduction@@YAMT?$__vector@J$01@__clang@@M@Z"(<2 x i64> noundef %p0, float noundef %p1) #0 {
+entry:
+ %p1.addr = alloca float, align 4
+ %p0.addr = alloca <2 x i64>, align 16
+ store float %p1, ptr %p1.addr, align 4
+ store <2 x i64> %p0, ptr %p0.addr, align 16
+ %0 = load <2 x i64>, ptr %p0.addr, align 16
+ %conv = sitofp <2 x i64> %0 to <2 x float>
+ %1 = load float, ptr %p1.addr, align 4
+ %splat.splatinsert = insertelement <2 x float> poison, float %1, i64 0
+ %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+ %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
+ ret float %dx.dot
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?test_dot_builtin_int_vector_elem_size_reduction@@YAMT?$__vector@H$01@__clang@@M@Z"(<2 x i32> noundef %p0, float noundef %p1) #0 {
+entry:
+ %p1.addr = alloca float, align 4
+ %p0.addr = alloca <2 x i32>, align 8
+ store float %p1, ptr %p1.addr, align 4
+ store <2 x i32> %p0, ptr %p0.addr, align 8
+ %0 = load <2 x i32>, ptr %p0.addr, align 8
+ %conv = sitofp <2 x i32> %0 to <2 x float>
+ %1 = load float, ptr %p1.addr, align 4
+ %splat.splatinsert = insertelement <2 x float> poison, float %1, i64 0
+ %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+ %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
+ ret float %dx.dot
+}
+
+attributes #0 = { noinline nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (https://github.com/farzonl/llvm-project.git f40562c7b4224e00da2ff2e13d175abfaac68532)"}
diff --git a/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
new file mode 100644
index 00000000000000..c74c502bd7a26f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
@@ -0,0 +1,27 @@
+
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+float test_too_few_arg() {
+ return __builtin_hlsl_elementwise_rsqrt();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_elementwise_rsqrt(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+float builtin_bool_to_float_type_promotion(bool p1) {
+ return __builtin_hlsl_elementwise_rsqrt(p1);
+ // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
+}
+
+float builtin_rsqrt_int_to_float_promotion(int p1) {
+ return __builtin_hlsl_elementwise_rsqrt(p1);
+ // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+float2 builtin_rsqrt_int2_to_float2_promotion(int2 p1) {
+ return __builtin_hlsl_elementwise_rsqrt(p1);
+ // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 7229292e377a83..366dedda2b3f73 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -37,4 +37,5 @@ def int_dx_lerp :
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9536a01e125bb3..942715f6ad80d3 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -224,6 +224,9 @@ def Exp2 : DXILOpMapping<21, unary, int_exp2,
def Frac : DXILOpMapping<22, unary, int_dx_frac,
"Returns a fraction from 0 to 1 that represents the "
"decimal part of the input.">;
+def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
+ "Returns the reciprocal of the square root of the specified value."
+ "rsqrt(x) = 1 / sqrt(x).">;
def Round : DXILOpMapping<26, unary, int_round,
"Returns the input rounded to the nearest integer"
"within a floating-point type.">;
diff --git a/llvm/test/CodeGen/DirectX/rsqrt.ll b/llvm/test/CodeGen/DirectX/rsqrt.ll
new file mode 100644
index 00000000000000..818b5985422173
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/rsqrt.ll
@@ -0,0 +1,31 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for rsqrt are generated for float and half.
+; CHECK:call float @dx.op.unary.f32(i32 25, float %{{.*}})
+; CHECK:call half @dx.op.unary.f16(i32 25, half %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @rsqrt_float(float noundef %a) #0 {
+entry:
+ %a.addr = alloca float, align 4
+ store float %a, ptr %a.addr, align 4
+ %0 = load float, ptr %a.addr, align 4
+ %dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %0)
+ ret float %dx.rsqrt
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare float @llvm.dx.rsqrt.f32(float) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @rsqrt_half(half noundef %a) #0 {
+entry:
+ %a.addr = alloca half, align 2
+ store half %a, ptr %a.addr, align 2
+ %0 = load half, ptr %a.addr, align 2
+ %dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %0)
+ ret half %dx.rsqrt
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
52077bb
to
a46ecde
Compare
4076546
to
1e8d135
Compare
d7de6c5
to
e8a619b
Compare
This change implements #70074
hlsl_intrinsics.h
- add thersqrt
apiDXIL.td
add the llvm intrinsic to DXIL op lowering map.Builtins.td
- add an hlsl builtin for rsqrt.CGBuiltin.cpp
add the ir generation for the rsqrt intrinsic.SemaChecking.cpp
- reuse the one arg float only checks.IntrinsicsDirectX.td
-add anrsqrt
intrinsic.