Skip to content

Commit

Permalink
[DAGCombiner] Require ninf for sqrt recip estimation
Browse files Browse the repository at this point in the history
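Currently, the DAG combiner uses (fmul (rsqrt x) x) to estimate the square
root of x. However, this method returns NaN if x is +Inf (rsqrt(+Inf) * +Inf =
0 * +Inf = NaN), which is incorrect because sqrt(+Inf) = +Inf. This patch
therefore only uses the estimate when infinities are excluded, i.e. when the
'ninf' flag or the NoInfsFPMath option is set.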

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D76853
  • Loading branch information
ecnelises committed Apr 1, 2020
1 parent 862766e commit 95bcab8
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 92 deletions.
8 changes: 6 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -13109,8 +13109,12 @@ SDValue DAGCombiner::visitFREM(SDNode *N) {

SDValue DAGCombiner::visitFSQRT(SDNode *N) {
SDNodeFlags Flags = N->getFlags();
if (!DAG.getTarget().Options.UnsafeFPMath &&
!Flags.hasApproximateFuncs())
const TargetOptions &Options = DAG.getTarget().Options;

// Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
// sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
if ((!Options.UnsafeFPMath && !Flags.hasApproximateFuncs()) ||
(!Options.NoInfsFPMath && !Flags.hasNoInfs()))
return SDValue();

SDValue N0 = N->getOperand(0);
Expand Down
35 changes: 31 additions & 4 deletions llvm/test/CodeGen/NVPTX/fast-math.ll
Expand Up @@ -13,14 +13,23 @@ define float @sqrt_div(float %a, float %b) {
}

; Without 'ninf' the sqrt is not replaced by an estimate (the estimate would
; yield NaN for +Inf), so the precise sqrt.rn form is expected. The divide is
; still approximated.
; CHECK-LABEL: sqrt_div_fast(
; CHECK: sqrt.rn.f32
; CHECK: div.approx.f32
define float @sqrt_div_fast(float %a, float %b) #0 {
  %t1 = tail call float @llvm.sqrt.f32(float %a)
  %t2 = fdiv float %t1, %b
  ret float %t2
}

; The sqrt call carries 'ninf', so both the sqrt and the divide are expected
; in their approximate forms.
; CHECK-LABEL: sqrt_div_fast_ninf(
; CHECK: sqrt.approx.f32
; CHECK: div.approx.f32
define float @sqrt_div_fast_ninf(float %a, float %b) #0 {
  %root = tail call ninf float @llvm.sqrt.f32(float %a)
  %quot = fdiv float %root, %b
  ret float %quot
}

; CHECK-LABEL: sqrt_div_ftz(
; CHECK: sqrt.rn.ftz.f32
; CHECK: div.rn.ftz.f32
Expand All @@ -31,27 +40,45 @@ define float @sqrt_div_ftz(float %a, float %b) #1 {
}

; Without 'ninf' the sqrt is not replaced by an estimate, so the precise
; flush-to-zero sqrt is expected. The divide is still approximated.
; CHECK-LABEL: sqrt_div_fast_ftz(
; CHECK: sqrt.rn.ftz.f32
; CHECK: div.approx.ftz.f32
define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 {
  %t1 = tail call float @llvm.sqrt.f32(float %a)
  %t2 = fdiv float %t1, %b
  ret float %t2
}

; The sqrt call carries 'ninf', so the approximate flush-to-zero sqrt and
; divide are expected.
; CHECK-LABEL: sqrt_div_fast_ftz_ninf(
; CHECK: sqrt.approx.ftz.f32
; CHECK: div.approx.ftz.f32
define float @sqrt_div_fast_ftz_ninf(float %a, float %b) #0 #1 {
  %root = tail call ninf float @llvm.sqrt.f32(float %a)
  %quot = fdiv float %root, %b
  ret float %quot
}

; There are no fast-math or ftz versions of sqrt and div for f64. When the sqrt
; carries 'ninf' we use reciprocal(rsqrt(x)) for sqrt(x); otherwise the precise
; sqrt.rn.f64 is kept. In both cases we emit a vanilla divide.

; Without 'ninf' the f64 sqrt is not lowered through rsqrt, so the precise
; sqrt is expected, followed by a vanilla divide.
; CHECK-LABEL: sqrt_div_fast_ftz_f64(
; CHECK: sqrt.rn.f64
; CHECK: div.rn.f64
define double @sqrt_div_fast_ftz_f64(double %a, double %b) #0 #1 {
  %t1 = tail call double @llvm.sqrt.f64(double %a)
  %t2 = fdiv double %t1, %b
  ret double %t2
}

; The sqrt call carries 'ninf', so the f64 sqrt is expected as an approximate
; rsqrt followed by an approximate reciprocal; the divide stays precise.
; CHECK-LABEL: sqrt_div_fast_ftz_f64_ninf(
; CHECK: rsqrt.approx.f64
; CHECK: rcp.approx.ftz.f64
; CHECK: div.rn.f64
define double @sqrt_div_fast_ftz_f64_ninf(double %a, double %b) #0 #1 {
  %root = tail call ninf double @llvm.sqrt.f64(double %a)
  %quot = fdiv double %root, %b
  ret double %quot
}

; CHECK-LABEL: rsqrt(
; CHECK-NOT: rsqrt.approx
; CHECK: sqrt.rn.f32
Expand Down
72 changes: 64 additions & 8 deletions llvm/test/CodeGen/NVPTX/sqrt-approx.ll
Expand Up @@ -45,35 +45,63 @@ define double @test_rsqrt64_ftz(double %a) #0 #1 {

; Without 'ninf' no estimate is used, so the precise sqrt is expected.
; CHECK-LABEL: test_sqrt32
define float @test_sqrt32(float %a) #0 {
; CHECK: sqrt.rn.f32
  %ret = tail call float @llvm.sqrt.f32(float %a)
  ret float %ret
}

; The sqrt call carries 'ninf', so the approximate sqrt is expected.
; CHECK-LABEL: test_sqrt32_ninf
define float @test_sqrt32_ninf(float %a) #0 {
; CHECK: sqrt.approx.f32
  %ret = tail call ninf float @llvm.sqrt.f32(float %a)
  ret float %ret
}

; Without 'ninf' no estimate is used, so the precise flush-to-zero sqrt is
; expected.
; CHECK-LABEL: test_sqrt_ftz
define float @test_sqrt_ftz(float %a) #0 #1 {
; CHECK: sqrt.rn.ftz.f32
  %ret = tail call float @llvm.sqrt.f32(float %a)
  ret float %ret
}

; The sqrt call carries 'ninf', so the approximate flush-to-zero sqrt is
; expected.
; CHECK-LABEL: test_sqrt_ftz_ninf
define float @test_sqrt_ftz_ninf(float %a) #0 #1 {
; CHECK: sqrt.approx.ftz.f32
  %ret = tail call ninf float @llvm.sqrt.f32(float %a)
  ret float %ret
}

; Without 'ninf' the f64 sqrt is not lowered through rsqrt; the precise sqrt is
; expected.
; CHECK-LABEL: test_sqrt64
define double @test_sqrt64(double %a) #0 {
; CHECK: sqrt.rn.f64
  %ret = tail call double @llvm.sqrt.f64(double %a)
  ret double %ret
}

; The sqrt call carries 'ninf', so the rsqrt-based lowering is expected.
; CHECK-LABEL: test_sqrt64_ninf
define double @test_sqrt64_ninf(double %a) #0 {
; There's no sqrt.approx.f64 instruction; we emit
; reciprocal(rsqrt.approx.f64(x)). There's no non-ftz approximate reciprocal,
; so we just use the ftz version.
; CHECK: rsqrt.approx.f64
; CHECK: rcp.approx.ftz.f64
  %ret = tail call ninf double @llvm.sqrt.f64(double %a)
  ret double %ret
}

; Without 'ninf' the f64 sqrt is not lowered through rsqrt; the precise sqrt is
; expected.
; CHECK-LABEL: test_sqrt64_ftz
define double @test_sqrt64_ftz(double %a) #0 #1 {
; CHECK: sqrt.rn.f64
  %ret = tail call double @llvm.sqrt.f64(double %a)
  ret double %ret
}

; The sqrt call carries 'ninf', so the rsqrt-based lowering is expected.
; CHECK-LABEL: test_sqrt64_ftz_ninf
define double @test_sqrt64_ftz_ninf(double %a) #0 #1 {
; There's no sqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
; CHECK: rsqrt.approx.f64
; CHECK: rcp.approx.ftz.f64
  %ret = tail call ninf double @llvm.sqrt.f64(double %a)
  ret double %ret
}

Expand All @@ -92,11 +120,18 @@ define float @test_rsqrt32_refined(float %a) #0 #2 {

; Without 'ninf' the refined rsqrt estimate is not used; the precise sqrt is
; expected.
; CHECK-LABEL: test_sqrt32_refined
define float @test_sqrt32_refined(float %a) #0 #2 {
; CHECK: sqrt.rn.f32
  %ret = tail call float @llvm.sqrt.f32(float %a)
  ret float %ret
}

; The sqrt call carries 'ninf', so the rsqrt-based estimate is expected.
; CHECK-LABEL: test_sqrt32_refined_ninf
define float @test_sqrt32_refined_ninf(float %a) #0 #2 {
; CHECK: rsqrt.approx.f32
  %root = tail call ninf float @llvm.sqrt.f32(float %a)
  ret float %root
}

; CHECK-LABEL: test_rsqrt64_refined
define double @test_rsqrt64_refined(double %a) #0 #2 {
; CHECK: rsqrt.approx.f64
Expand All @@ -107,11 +142,18 @@ define double @test_rsqrt64_refined(double %a) #0 #2 {

; Without 'ninf' the refined rsqrt estimate is not used; the precise sqrt is
; expected.
; CHECK-LABEL: test_sqrt64_refined
define double @test_sqrt64_refined(double %a) #0 #2 {
; CHECK: sqrt.rn.f64
  %ret = tail call double @llvm.sqrt.f64(double %a)
  ret double %ret
}

; The sqrt call carries 'ninf', so the rsqrt-based estimate is expected.
; CHECK-LABEL: test_sqrt64_refined_ninf
define double @test_sqrt64_refined_ninf(double %a) #0 #2 {
; CHECK: rsqrt.approx.f64
  %root = tail call ninf double @llvm.sqrt.f64(double %a)
  ret double %root
}

; -- refined sqrt and rsqrt with ftz enabled --

; CHECK-LABEL: test_rsqrt32_refined_ftz
Expand All @@ -124,11 +166,18 @@ define float @test_rsqrt32_refined_ftz(float %a) #0 #1 #2 {

; Without 'ninf' the refined rsqrt estimate is not used; the precise
; flush-to-zero sqrt is expected.
; CHECK-LABEL: test_sqrt32_refined_ftz
define float @test_sqrt32_refined_ftz(float %a) #0 #1 #2 {
; CHECK: sqrt.rn.ftz.f32
  %ret = tail call float @llvm.sqrt.f32(float %a)
  ret float %ret
}

; The sqrt call carries 'ninf', so the flush-to-zero rsqrt-based estimate is
; expected.
; CHECK-LABEL: test_sqrt32_refined_ftz_ninf
define float @test_sqrt32_refined_ftz_ninf(float %a) #0 #1 #2 {
; CHECK: rsqrt.approx.ftz.f32
  %root = tail call ninf float @llvm.sqrt.f32(float %a)
  ret float %root
}

; CHECK-LABEL: test_rsqrt64_refined_ftz
define double @test_rsqrt64_refined_ftz(double %a) #0 #1 #2 {
; There's no rsqrt.approx.ftz.f64, so we just use the non-ftz version.
Expand All @@ -140,11 +189,18 @@ define double @test_rsqrt64_refined_ftz(double %a) #0 #1 #2 {

; Without 'ninf' the refined rsqrt estimate is not used; the precise sqrt is
; expected.
; CHECK-LABEL: test_sqrt64_refined_ftz
define double @test_sqrt64_refined_ftz(double %a) #0 #1 #2 {
; CHECK: sqrt.rn.f64
  %ret = tail call double @llvm.sqrt.f64(double %a)
  ret double %ret
}

; The sqrt call carries 'ninf', so the rsqrt-based estimate is expected.
; CHECK-LABEL: test_sqrt64_refined_ftz_ninf
define double @test_sqrt64_refined_ftz_ninf(double %a) #0 #1 #2 {
; CHECK: rsqrt.approx.f64
  %root = tail call ninf double @llvm.sqrt.f64(double %a)
  ret double %root
}

; #0: marks the function as compiled with unsafe FP math.
attributes #0 = { "unsafe-fp-math" = "true" }
; #1: f32 denormal mode set to preserve-sign; the tests using it expect the
; ".ftz" instruction variants.
attributes #1 = { "denormal-fp-math-f32" = "preserve-sign,preserve-sign" }
; #2: reciprocal-estimate settings for sqrt/rsqrt on float and double, used by
; the "*_refined" tests (the ":1" is presumably the refinement step count --
; confirm against the LLVM attribute documentation).
attributes #2 = { "reciprocal-estimates" = "rsqrtf:1,rsqrtd:1,sqrtf:1,sqrtd:1" }

0 comments on commit 95bcab8

Please sign in to comment.