[HLSL] Implement rsqrt intrinsic #84820

Merged: 4 commits, Mar 14, 2024

Conversation

farzonl
Member

@farzonl farzonl commented Mar 11, 2024

This change implements #70074

  • hlsl_intrinsics.h - add the rsqrt API.
  • DXIL.td - add the LLVM intrinsic to the DXIL op lowering map.
  • Builtins.td - add an HLSL builtin for rsqrt.
  • CGBuiltin.cpp - add the IR generation for the rsqrt intrinsic.
  • SemaChecking.cpp - reuse the one-argument, float-only checks.
  • IntrinsicsDirectX.td - add an rsqrt intrinsic.
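
For readers unfamiliar with the intrinsic, the intended semantics (1 / sqrt(x), applied per element for vector types) can be sketched in plain C++. This is an illustrative reference model only, not code from this PR: `rsqrt_ref` is a hypothetical name, and `std::array` is assumed here as a stand-in for the HLSL vector types.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

// Hypothetical reference model (not part of the PR): scalar rsqrt computed
// with the formula quoted in the header comment, 1 / sqrt(x).
inline float rsqrt_ref(float x) { return 1.0f / std::sqrt(x); }

// Element-wise form, mirroring the float2/float3/float4 overloads.
// std::array is an assumption standing in for HLSL vectors.
template <std::size_t N>
std::array<float, N> rsqrt_ref(const std::array<float, N> &v) {
  std::array<float, N> out{};
  for (std::size_t i = 0; i < N; ++i)
    out[i] = rsqrt_ref(v[i]); // each lane is computed independently
  return out;
}
```

The actual builtin does not expand to a divide and a square root; per the list above it is emitted as a single intrinsic call and lowered by the DirectX backend.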

@farzonl farzonl self-assigned this Mar 11, 2024
@llvmbot llvmbot added the clang, backend:X86, clang:frontend, clang:headers, clang:codegen, backend:DirectX, HLSL, and llvm:ir labels Mar 11, 2024
@llvmbot
Collaborator

llvmbot commented Mar 11, 2024

@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

Full diff: https://github.com/llvm/llvm-project/pull/84820.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+32)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+1)
  • (added) clang/test/CodeGenHLSL/builtins/rsqrt.hlsl (+53)
  • (added) clang/test/SemaHLSL/BuiltIns/dot-warning.ll (+49)
  • (added) clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl (+27)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+3)
  • (added) llvm/test/CodeGen/DirectX/rsqrt.ll (+31)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 9c703377ca8d3e..de0cfb4e46b8bd 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4590,6 +4590,12 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..d2c83a5e405f42 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18077,6 +18077,14 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/Op0->getType(), Intrinsic::dx_rcp,
         ArrayRef<Value *>{Op0}, nullptr, "dx.rcp");
   }
+  case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+      llvm_unreachable("rsqrt operand must have a float representation");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt,
+        ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 45f8544392584e..f88aa4a5d5c644 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1153,6 +1153,38 @@ double3 rcp(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rcp)
 double4 rcp(double4);
 
+//===----------------------------------------------------------------------===//
+// rsqrt builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T rsqrt(T x)
+/// \brief Returns the reciprocal of the square root of the specified value \a x.
+/// \param x The specified input value.
+///
+/// This function uses the following formula: 1 / sqrt(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half rsqrt(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half2 rsqrt(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half3 rsqrt(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+half4 rsqrt(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float rsqrt(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float2 rsqrt(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float3 rsqrt(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
+float4 rsqrt(float4);
+
 //===----------------------------------------------------------------------===//
 // round builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a5f42b630c3fa2..0dafff47ab4040 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5285,6 +5285,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
   case Builtin::BI__builtin_hlsl_elementwise_rcp:
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
diff --git a/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
new file mode 100644
index 00000000000000..c87a8c404b08e1
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %dx.rsqrt = call half @llvm.dx.rsqrt.f16(
+// NATIVE_HALF: ret half %dx.rsqrt
+// NO_HALF: define noundef float @"?test_rsqrt_half@@YA$halff@$halff@@Z"(
+// NO_HALF: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
+// NO_HALF: ret float %dx.rsqrt
+half test_rsqrt_half(half p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <2 x half> @llvm.dx.rsqrt.v2f16
+// NATIVE_HALF: ret <2 x half> %dx.rsqrt
+// NO_HALF: define noundef <2 x float> @
+// NO_HALF: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32(
+// NO_HALF: ret <2 x float> %dx.rsqrt
+half2 test_rsqrt_half2(half2 p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <3 x half> @llvm.dx.rsqrt.v3f16
+// NATIVE_HALF: ret <3 x half> %dx.rsqrt
+// NO_HALF: define noundef <3 x float> @
+// NO_HALF: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32(
+// NO_HALF: ret <3 x float> %dx.rsqrt
+half3 test_rsqrt_half3(half3 p0) { return rsqrt(p0); }
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %dx.rsqrt = call <4 x half> @llvm.dx.rsqrt.v4f16
+// NATIVE_HALF: ret <4 x half> %dx.rsqrt
+// NO_HALF: define noundef <4 x float> @
+// NO_HALF: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32(
+// NO_HALF: ret <4 x float> %dx.rsqrt
+half4 test_rsqrt_half4(half4 p0) { return rsqrt(p0); }
+
+// CHECK: define noundef float @
+// CHECK: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
+// CHECK: ret float %dx.rsqrt
+float test_rsqrt_float(float p0) { return rsqrt(p0); }
+// CHECK: define noundef <2 x float> @
+// CHECK: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32
+// CHECK: ret <2 x float> %dx.rsqrt
+float2 test_rsqrt_float2(float2 p0) { return rsqrt(p0); }
+// CHECK: define noundef <3 x float> @
+// CHECK: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32
+// CHECK: ret <3 x float> %dx.rsqrt
+float3 test_rsqrt_float3(float3 p0) { return rsqrt(p0); }
+// CHECK: define noundef <4 x float> @
+// CHECK: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32
+// CHECK: ret <4 x float> %dx.rsqrt
+float4 test_rsqrt_float4(float4 p0) { return rsqrt(p0); }
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-warning.ll b/clang/test/SemaHLSL/BuiltIns/dot-warning.ll
new file mode 100644
index 00000000000000..5ecdde4d70e51f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot-warning.ll
@@ -0,0 +1,49 @@
+; ModuleID = 'D:\projects\llvm-project\clang\test\SemaHLSL\BuiltIns\dot-warning.hlsl'
+source_filename = "D:\\projects\\llvm-project\\clang\\test\\SemaHLSL\\BuiltIns\\dot-warning.hlsl"
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.3-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?test_dot_builtin_vector_elem_size_reduction@@YAMT?$__vector@J$01@__clang@@M@Z"(<2 x i64> noundef %p0, float noundef %p1) #0 {
+entry:
+  %p1.addr = alloca float, align 4
+  %p0.addr = alloca <2 x i64>, align 16
+  store float %p1, ptr %p1.addr, align 4
+  store <2 x i64> %p0, ptr %p0.addr, align 16
+  %0 = load <2 x i64>, ptr %p0.addr, align 16
+  %conv = sitofp <2 x i64> %0 to <2 x float>
+  %1 = load float, ptr %p1.addr, align 4
+  %splat.splatinsert = insertelement <2 x float> poison, float %1, i64 0
+  %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+  %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
+  ret float %dx.dot
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?test_dot_builtin_int_vector_elem_size_reduction@@YAMT?$__vector@H$01@__clang@@M@Z"(<2 x i32> noundef %p0, float noundef %p1) #0 {
+entry:
+  %p1.addr = alloca float, align 4
+  %p0.addr = alloca <2 x i32>, align 8
+  store float %p1, ptr %p1.addr, align 4
+  store <2 x i32> %p0, ptr %p0.addr, align 8
+  %0 = load <2 x i32>, ptr %p0.addr, align 8
+  %conv = sitofp <2 x i32> %0 to <2 x float>
+  %1 = load float, ptr %p1.addr, align 4
+  %splat.splatinsert = insertelement <2 x float> poison, float %1, i64 0
+  %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+  %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
+  ret float %dx.dot
+}
+
+attributes #0 = { noinline nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (https://github.com/farzonl/llvm-project.git f40562c7b4224e00da2ff2e13d175abfaac68532)"}
diff --git a/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
new file mode 100644
index 00000000000000..c74c502bd7a26f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
@@ -0,0 +1,27 @@
+
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+float test_too_few_arg() {
+  return __builtin_hlsl_elementwise_rsqrt();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_elementwise_rsqrt(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+float builtin_bool_to_float_type_promotion(bool p1) {
+  return __builtin_hlsl_elementwise_rsqrt(p1);
+  // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
+}
+
+float builtin_rsqrt_int_to_float_promotion(int p1) {
+  return __builtin_hlsl_elementwise_rsqrt(p1);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+float2 builtin_rsqrt_int2_to_float2_promotion(int2 p1) {
+  return __builtin_hlsl_elementwise_rsqrt(p1);
+  // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 7229292e377a83..366dedda2b3f73 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -37,4 +37,5 @@ def int_dx_lerp :
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9536a01e125bb3..942715f6ad80d3 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -224,6 +224,9 @@ def Exp2 : DXILOpMapping<21, unary, int_exp2,
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.">;
+def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
+                         "Returns the reciprocal of the square root of the specified value. "
+                         "rsqrt(x) = 1 / sqrt(x).">;
 def Round : DXILOpMapping<26, unary, int_round,
                          "Returns the input rounded to the nearest integer"
                          "within a floating-point type.">;
diff --git a/llvm/test/CodeGen/DirectX/rsqrt.ll b/llvm/test/CodeGen/DirectX/rsqrt.ll
new file mode 100644
index 00000000000000..818b5985422173
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/rsqrt.ll
@@ -0,0 +1,31 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for rsqrt are generated for float and half.
+; CHECK:call float @dx.op.unary.f32(i32 25, float %{{.*}})
+; CHECK:call half @dx.op.unary.f16(i32 25, half %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @rsqrt_float(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %0)
+  ret float %dx.rsqrt
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare float @llvm.dx.rsqrt.f32(float) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @rsqrt_half(half noundef %a) #0 {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  %dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %0)
+  ret half %dx.rsqrt
+}
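
The lowering that `rsqrt.ll` checks for can be pictured with a small table-driven sketch in C++. This is a toy model under stated assumptions, not the backend's implementation: `UnaryOpMapping`, `kUnaryOps`, and `lowerUnary` are hypothetical names, the helper builds the call as text rather than as IR, and only the four opcodes visible in the `DXIL.td` hunk above are included.

```cpp
#include <optional>
#include <string>
#include <string_view>

// Toy stand-in for a DXILOpMapping entry (hypothetical type).
struct UnaryOpMapping {
  std::string_view intrinsic; // LLVM intrinsic name without the type suffix
  int opcode;                 // DXIL opcode passed as the leading i32 argument
};

// Opcodes copied from the DXIL.td hunk in the diff: Exp2, Frac, RSqrt, Round.
constexpr UnaryOpMapping kUnaryOps[] = {
    {"llvm.exp2", 21},
    {"llvm.dx.frac", 22},
    {"llvm.dx.rsqrt", 25},
    {"llvm.round", 26},
};

// Hypothetical helper: returns the textual dx.op.unary call for a known
// intrinsic and overload ("f32" or "f16"), or std::nullopt otherwise.
inline std::optional<std::string> lowerUnary(std::string_view intrinsic,
                                             std::string_view overload,
                                             std::string_view operand) {
  std::string_view type;
  if (overload == "f32")
    type = "float";
  else if (overload == "f16")
    type = "half";
  else
    return std::nullopt; // only the overloads exercised by the test

  for (const UnaryOpMapping &op : kUnaryOps) {
    if (op.intrinsic != intrinsic)
      continue;
    std::string call = "call ";
    call += type;
    call += " @dx.op.unary.";
    call += overload;
    call += "(i32 " + std::to_string(op.opcode) + ", ";
    call += type;
    call += " ";
    call += operand;
    call += ")";
    return call;
  }
  return std::nullopt; // no mapping for this intrinsic
}
```

With these assumptions, `lowerUnary("llvm.dx.rsqrt", "f32", "%0")` produces a string of the same shape as the first CHECK line in the test, with opcode 25 as the first argument.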

@llvmbot
Collaborator

llvmbot commented Mar 11, 2024

@llvm/pr-subscribers-clang-codegen

Author: Farzon Lotfi (farzonl)


github-actions bot commented Mar 11, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@farzonl farzonl linked an issue Mar 12, 2024 that may be closed by this pull request
@farzonl farzonl merged commit 8f9ee39 into llvm:main Mar 14, 2024
3 of 4 checks passed
@farzonl farzonl deleted the hlsl-rsqrt-intrinsic branch March 14, 2024 20:49
Labels
backend:DirectX, backend:X86, clang:codegen, clang:frontend, clang:headers, clang, HLSL, llvm:ir
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[HLSL] implement rsqrt intrinsic
5 participants