diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index d9f27c0db57ce..fcb64fde1b91f 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -1366,14 +1366,26 @@ float4 sin(float4); /// \param Val The input value. _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) -_HLSL_BUILTIN_ALIAS(__builtin_sqrtf16) -half sqrt(half In); - -_HLSL_BUILTIN_ALIAS(__builtin_sqrtf) -float sqrt(float In); - -_HLSL_BUILTIN_ALIAS(__builtin_sqrt) -double sqrt(double In); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +half sqrt(half); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +half2 sqrt(half2); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +half3 sqrt(half3); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +half4 sqrt(half4); + +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +float sqrt(float); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +float2 sqrt(float2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +float3 sqrt(float3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt) +float4 sqrt(float4); //===----------------------------------------------------------------------===// // trunc builtins diff --git a/clang/test/CodeGenHLSL/builtins/sqrt.hlsl b/clang/test/CodeGenHLSL/builtins/sqrt.hlsl index 2c2a09617cf86..adbbf69a8e068 100644 --- a/clang/test/CodeGenHLSL/builtins/sqrt.hlsl +++ b/clang/test/CodeGenHLSL/builtins/sqrt.hlsl @@ -1,29 +1,53 @@ -// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ -// RUN: dxil-pc-shadermodel6.2-library %s -fnative-half-type \ -// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ +// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ +// RUN: --check-prefixes=CHECK,NATIVE_HALF +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF -using hlsl::sqrt; +// NATIVE_HALF: define noundef half @ +// NATIVE_HALF: %{{.*}} = call half @llvm.sqrt.f16( +// NATIVE_HALF: ret half %{{.*}} +// NO_HALF: define noundef float @"?test_sqrt_half@@YA$halff@$halff@@Z"( +// NO_HALF: %{{.*}} = call float @llvm.sqrt.f32( +// NO_HALF: ret float %{{.*}} +half test_sqrt_half(half p0) { return sqrt(p0); } +// NATIVE_HALF: define noundef <2 x half> @ +// NATIVE_HALF: %{{.*}} = call <2 x half> @llvm.sqrt.v2f16 +// NATIVE_HALF: ret <2 x half> %{{.*}} +// NO_HALF: define noundef <2 x float> @ +// NO_HALF: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32( +// NO_HALF: ret <2 x float> %{{.*}} +half2 test_sqrt_half2(half2 p0) { return sqrt(p0); } +// NATIVE_HALF: define noundef <3 x half> @ +// NATIVE_HALF: %{{.*}} = call <3 x half> @llvm.sqrt.v3f16 +// NATIVE_HALF: ret <3 x half> %{{.*}} +// NO_HALF: define noundef <3 x float> @ +// NO_HALF: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32( +// NO_HALF: ret <3 x float> %{{.*}} +half3 test_sqrt_half3(half3 p0) { return sqrt(p0); } +// NATIVE_HALF: define noundef <4 x half> @ +// NATIVE_HALF: %{{.*}} = call <4 x half> @llvm.sqrt.v4f16 +// NATIVE_HALF: ret <4 x half> %{{.*}} +// NO_HALF: define noundef <4 x float> @ +// NO_HALF: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32( +// NO_HALF: ret <4 x float> %{{.*}} +half4 test_sqrt_half4(half4 p0) { return sqrt(p0); } -double sqrt_d(double x) -{ - return sqrt(x); -} - -// CHECK: define noundef double @"?sqrt_d@@YANN@Z"( -// CHECK: call double @llvm.sqrt.f64(double %0) - -float sqrt_f(float x) -{ - return sqrt(x); -} - -// CHECK: define noundef float @"?sqrt_f@@YAMM@Z"( -// CHECK: call float @llvm.sqrt.f32(float %0) - -half sqrt_h(half x) -{ - return sqrt(x); -} - -// CHECK: define noundef half @"?sqrt_h@@YA$f16@$f16@@Z"( -// CHECK: call half @llvm.sqrt.f16(half %0) +// CHECK: define noundef float @ +// CHECK: %{{.*}} = call float @llvm.sqrt.f32( +// CHECK: ret float %{{.*}} +float test_sqrt_float(float p0) { return sqrt(p0); } +// CHECK: define noundef <2 x float> @ +// CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32 +// CHECK: ret <2 x float> %{{.*}} +float2 test_sqrt_float2(float2 p0) { return sqrt(p0); } +// CHECK: define noundef <3 x float> @ +// CHECK: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32 +// CHECK: ret <3 x float> %{{.*}} +float3 test_sqrt_float3(float3 p0) { return sqrt(p0); } +// CHECK: define noundef <4 x float> @ +// CHECK: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32 +// CHECK: ret <4 x float> %{{.*}} +float4 test_sqrt_float4(float4 p0) { return sqrt(p0); } diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index f7e69ebae15b6..572d3323ebe75 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -274,6 +274,10 @@ def Frac : DXILOpMapping<22, unary, int_dx_frac, "Returns a fraction from 0 to 1 that represents the " "decimal part of the input.", [llvm_halforfloat_ty, LLVMMatchType<0>]>; +def Sqrt : DXILOpMapping<24, unary, int_sqrt, + "Returns the square root of the specified floating-point" + "value, per component.", + [llvm_halforfloat_ty, LLVMMatchType<0>]>; def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt, "Returns the reciprocal of the square root of the specified value." "rsqrt(x) = 1 / sqrt(x).", diff --git a/llvm/test/CodeGen/DirectX/sqrt.ll b/llvm/test/CodeGen/DirectX/sqrt.ll new file mode 100644 index 0000000000000..76a572efd2055 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/sqrt.ll @@ -0,0 +1,20 @@ +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s + +; Make sure dxil operation function calls for sqrt are generated for float and half. + +define noundef float @sqrt_float(float noundef %a) #0 { +entry: +; CHECK:call float @dx.op.unary.f32(i32 24, float %{{.*}}) + %elt.sqrt = call float @llvm.sqrt.f32(float %a) + ret float %elt.sqrt +} + +define noundef half @sqrt_half(half noundef %a) #0 { +entry: +; CHECK:call half @dx.op.unary.f16(i32 24, half %{{.*}}) + %elt.sqrt = call half @llvm.sqrt.f16(half %a) + ret half %elt.sqrt +} + +declare half @llvm.sqrt.f16(half) +declare float @llvm.sqrt.f32(float) diff --git a/llvm/test/CodeGen/DirectX/sqrt_error.ll b/llvm/test/CodeGen/DirectX/sqrt_error.ll new file mode 100644 index 0000000000000..fffa2e19b80fa --- /dev/null +++ b/llvm/test/CodeGen/DirectX/sqrt_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation sqrt does not support double overload type +; CHECK: LLVM ERROR: Invalid Overload Type + +define noundef double @sqrt_double(double noundef %a) { +entry: + %elt.sqrt = call double @llvm.sqrt.f64(double %a) + ret double %elt.sqrt +}