Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL][DXIL] implement sqrt intrinsic #86560

Merged
merged 1 commit into from
Mar 25, 2024

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Mar 25, 2024

completes #86187

  • fix hlsl_intrinsic to cover the correct cases
  • move to using __builtin_elementwise_sqrt
  • add lowering of Intrinsic::sqrt to dxilop 24.

completes llvm#86187
- fix hlsl_intrinsic to cover the correct cases
- move to using `__builtin_elementwise_sqrt`
- add lowering of `Intrinsic::sqrt` to dxilop 24.
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics backend:DirectX HLSL HLSL Language Support labels Mar 25, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 25, 2024

@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-x86

Author: Farzon Lotfi (farzonl)

Changes

completes #86187

  • fix hlsl_intrinsic to cover the correct cases
  • move to using __builtin_elementwise_sqrt
  • add lowering of Intrinsic::sqrt to dxilop 24.

Full diff: https://github.com/llvm/llvm-project/pull/86560.diff

5 Files Affected:

  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+20-8)
  • (modified) clang/test/CodeGenHLSL/builtins/sqrt.hlsl (+51-27)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+4)
  • (added) llvm/test/CodeGen/DirectX/sqrt.ll (+20)
  • (added) llvm/test/CodeGen/DirectX/sqrt_error.ll (+10)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index d9f27c0db57ce4..fcb64fde1b91f1 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1366,14 +1366,26 @@ float4 sin(float4);
 /// \param Val The input value.
 
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_sqrtf16)
-half sqrt(half In);
-
-_HLSL_BUILTIN_ALIAS(__builtin_sqrtf)
-float sqrt(float In);
-
-_HLSL_BUILTIN_ALIAS(__builtin_sqrt)
-double sqrt(double In);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half sqrt(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half2 sqrt(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half3 sqrt(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half4 sqrt(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float sqrt(float);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float2 sqrt(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float3 sqrt(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float4 sqrt(float4);
 
 //===----------------------------------------------------------------------===//
 // trunc builtins
diff --git a/clang/test/CodeGenHLSL/builtins/sqrt.hlsl b/clang/test/CodeGenHLSL/builtins/sqrt.hlsl
index 2c2a09617cf86a..adbbf69a8e0685 100644
--- a/clang/test/CodeGenHLSL/builtins/sqrt.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/sqrt.hlsl
@@ -1,29 +1,53 @@
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
-// RUN:   dxil-pc-shadermodel6.2-library %s -fnative-half-type \
-// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
-using hlsl::sqrt;
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %{{.*}} = call half @llvm.sqrt.f16(
+// NATIVE_HALF: ret half %{{.*}}
+// NO_HALF: define noundef float @"?test_sqrt_half@@YA$halff@$halff@@Z"(
+// NO_HALF: %{{.*}} = call float @llvm.sqrt.f32(
+// NO_HALF: ret float %{{.*}}
+half test_sqrt_half(half p0) { return sqrt(p0); }
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %{{.*}} = call <2 x half> @llvm.sqrt.v2f16
+// NATIVE_HALF: ret <2 x half> %{{.*}}
+// NO_HALF: define noundef <2 x float> @
+// NO_HALF: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32(
+// NO_HALF: ret <2 x float> %{{.*}}
+half2 test_sqrt_half2(half2 p0) { return sqrt(p0); }
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %{{.*}} = call <3 x half> @llvm.sqrt.v3f16
+// NATIVE_HALF: ret <3 x half> %{{.*}}
+// NO_HALF: define noundef <3 x float> @
+// NO_HALF: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32(
+// NO_HALF: ret <3 x float> %{{.*}}
+half3 test_sqrt_half3(half3 p0) { return sqrt(p0); }
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %{{.*}} = call <4 x half> @llvm.sqrt.v4f16
+// NATIVE_HALF: ret <4 x half> %{{.*}}
+// NO_HALF: define noundef <4 x float> @
+// NO_HALF: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32(
+// NO_HALF: ret <4 x float> %{{.*}}
+half4 test_sqrt_half4(half4 p0) { return sqrt(p0); }
 
-double sqrt_d(double x)
-{
-  return sqrt(x);
-}
-
-// CHECK: define noundef double @"?sqrt_d@@YANN@Z"(
-// CHECK: call double @llvm.sqrt.f64(double %0)
-
-float sqrt_f(float x)
-{
-  return sqrt(x);
-}
-
-// CHECK: define noundef float @"?sqrt_f@@YAMM@Z"(
-// CHECK: call float @llvm.sqrt.f32(float %0)
-
-half sqrt_h(half x)
-{
-  return sqrt(x);
-}
-
-// CHECK: define noundef half @"?sqrt_h@@YA$f16@$f16@@Z"(
-// CHECK: call half @llvm.sqrt.f16(half %0)
+// CHECK: define noundef float @
+// CHECK: %{{.*}} = call float @llvm.sqrt.f32(
+// CHECK: ret float %{{.*}}
+float test_sqrt_float(float p0) { return sqrt(p0); }
+// CHECK: define noundef <2 x float> @
+// CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32
+// CHECK: ret <2 x float> %{{.*}}
+float2 test_sqrt_float2(float2 p0) { return sqrt(p0); }
+// CHECK: define noundef <3 x float> @
+// CHECK: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32
+// CHECK: ret <3 x float> %{{.*}}
+float3 test_sqrt_float3(float3 p0) { return sqrt(p0); }
+// CHECK: define noundef <4 x float> @
+// CHECK: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32
+// CHECK: ret <4 x float> %{{.*}}
+float4 test_sqrt_float4(float4 p0) { return sqrt(p0); }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index f7e69ebae15b6c..572d3323ebe75f 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -274,6 +274,10 @@ def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.",
                          [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+def Sqrt : DXILOpMapping<24, unary, int_sqrt,
+                         "Returns the square root of the specified floating-point"
+                         "value, per component.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
                          "Returns the reciprocal of the square root of the specified value."
                          "rsqrt(x) = 1 / sqrt(x).",
diff --git a/llvm/test/CodeGen/DirectX/sqrt.ll b/llvm/test/CodeGen/DirectX/sqrt.ll
new file mode 100644
index 00000000000000..76a572efd20557
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sqrt.ll
@@ -0,0 +1,20 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for sqrt are generated for float and half.
+
+define noundef float @sqrt_float(float noundef %a) #0 {
+entry:
+; CHECK:call float @dx.op.unary.f32(i32 24, float %{{.*}})
+  %elt.sqrt = call float @llvm.sqrt.f32(float %a)
+  ret float %elt.sqrt
+}
+
+define noundef half @sqrt_half(half noundef %a) #0 {
+entry:
+; CHECK:call half @dx.op.unary.f16(i32 24, half %{{.*}})
+  %elt.sqrt = call half @llvm.sqrt.f16(half %a)
+  ret half %elt.sqrt
+}
+
+declare half @llvm.sqrt.f16(half)
+declare float @llvm.sqrt.f32(float)
diff --git a/llvm/test/CodeGen/DirectX/sqrt_error.ll b/llvm/test/CodeGen/DirectX/sqrt_error.ll
new file mode 100644
index 00000000000000..fffa2e19b80fa9
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sqrt_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation sqrt does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef double @sqrt_double(double noundef %a) {
+entry:
+  %elt.sqrt = call double @llvm.sqrt.f64(double %a)
+  ret double %elt.sqrt
+}

Copy link

✅ With the latest revision this PR passed the Python code formatter.

Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

@farzonl farzonl linked an issue Mar 25, 2024 that may be closed by this pull request
@farzonl farzonl merged commit 4cea2d0 into llvm:main Mar 25, 2024
10 checks passed
@farzonl farzonl deleted the hlsl-fix-sqrt-intrinsic branch March 25, 2024 22:02
farzonl added a commit to farzonl/llvm-project that referenced this pull request Mar 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[HLSL] implement sqrt intrinsic
4 participants